mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-19 07:54:48 +01:00
Split error handling into isUnexpectedError predicate and logFlushError function. Use rm.AppendError() to collect unexpected errors during rule execution, then log after FlushRelease completes rather than immediately. Also updates goutils dependency for AppendError method availability.
430 lines
9.9 KiB
Go
430 lines
9.9 KiB
Go
package rules
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/quic-go/quic-go/http3"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/yusing/godoxy/internal/serialization"
|
|
gperr "github.com/yusing/goutils/errs"
|
|
httputils "github.com/yusing/goutils/http"
|
|
"golang.org/x/net/http2"
|
|
|
|
_ "unsafe"
|
|
)
|
|
|
|
type (
|
|
/*
|
|
Rules is a list of rules.
|
|
|
|
Example:
|
|
|
|
proxy.app1.rules: |
|
|
- name: default
|
|
do: |
|
|
rewrite / /index.html
|
|
serve /var/www/goaccess
|
|
- name: ws
|
|
on: |
|
|
header Connection Upgrade
|
|
header Upgrade websocket
|
|
do: bypass
|
|
|
|
proxy.app2.rules: |
|
|
- name: default
|
|
do: bypass
|
|
- name: block POST and PUT
|
|
on: method POST | method PUT
|
|
do: error 403 Forbidden
|
|
*/
|
|
//nolint:recvcheck
|
|
Rules []Rule
|
|
// Rule represents a reverse proxy rule.
|
|
// The `Do` field is executed when `On` matches.
|
|
//
|
|
// - A rule may have multiple lines in the `On` section.
|
|
// - All `On` lines must match for the rule to trigger.
|
|
// - Each line can have several checks—one match per line is enough for that line.
|
|
Rule struct {
|
|
Name string `json:"name"`
|
|
On RuleOn `json:"on" swaggertype:"string"`
|
|
Do Command `json:"do" swaggertype:"string"`
|
|
}
|
|
)
|
|
|
|
func isDefaultRule(rule Rule) bool {
|
|
return rule.Name == "default" || rule.On.raw == OnDefault
|
|
}
|
|
|
|
func (rules Rules) Validate() gperr.Error {
|
|
var defaultRulesFound []int
|
|
for i := range rules {
|
|
rule := rules[i]
|
|
if isDefaultRule(rule) {
|
|
defaultRulesFound = append(defaultRulesFound, i)
|
|
}
|
|
if rules[i].Name == "" {
|
|
// set name to index if name is empty
|
|
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
|
}
|
|
}
|
|
if len(defaultRulesFound) > 1 {
|
|
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
|
}
|
|
for i := range rules {
|
|
r1 := rules[i]
|
|
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
|
continue
|
|
}
|
|
sig1, ok := matcherSignature(r1.On.raw)
|
|
if !ok {
|
|
continue
|
|
}
|
|
for j := i + 1; j < len(rules); j++ {
|
|
r2 := rules[j]
|
|
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
|
continue
|
|
}
|
|
sig2, ok := matcherSignature(r2.On.raw)
|
|
if !ok || sig1 != sig2 {
|
|
continue
|
|
}
|
|
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (rule Rule) doesTerminateInPre() bool {
|
|
for _, cmd := range rule.Do.pre {
|
|
handler, ok := cmd.(Handler)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if handler.Terminates() {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func matcherSignature(raw string) (string, bool) {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw == "" {
|
|
return "", false
|
|
}
|
|
|
|
andParts := splitAnd(raw)
|
|
if len(andParts) == 0 {
|
|
return "", false
|
|
}
|
|
|
|
canonAnd := make([]string, 0, len(andParts))
|
|
for _, andPart := range andParts {
|
|
orParts := splitPipe(andPart)
|
|
if len(orParts) == 0 {
|
|
continue
|
|
}
|
|
canonOr := make([]string, 0, len(orParts))
|
|
for _, atom := range orParts {
|
|
subject, args, err := parse(strings.TrimSpace(atom))
|
|
if err != nil || subject == "" {
|
|
return "", false
|
|
}
|
|
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
|
}
|
|
slices.Sort(canonOr)
|
|
canonOr = slices.Compact(canonOr)
|
|
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
|
}
|
|
|
|
slices.Sort(canonAnd)
|
|
canonAnd = slices.Compact(canonAnd)
|
|
if len(canonAnd) == 0 {
|
|
return "", false
|
|
}
|
|
return strings.Join(canonAnd, "&"), true
|
|
}
|
|
|
|
// Parse parses a rule configuration string.
|
|
// It first tries the block syntax (if the string contains a top-level '{'),
|
|
// then falls back to YAML syntax.
|
|
func (rules *Rules) Parse(config string) error {
|
|
config = strings.TrimSpace(config)
|
|
if config == "" {
|
|
return nil
|
|
}
|
|
|
|
// Prefer block syntax if it looks like block syntax.
|
|
if hasTopLevelLBrace(config) {
|
|
blockRules, err := parseBlockRules(config)
|
|
if err == nil {
|
|
*rules = blockRules
|
|
return nil
|
|
}
|
|
// Fall through to YAML (backward compatibility).
|
|
}
|
|
|
|
// YAML fallback
|
|
var anySlice []any
|
|
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
|
if yamlErr == nil {
|
|
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
|
}
|
|
|
|
// If YAML fails and we didn't try block syntax yet, try it now.
|
|
blockRules, err := parseBlockRules(config)
|
|
if err == nil {
|
|
*rules = blockRules
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// hasTopLevelLBrace reports whether s contains a '{' that is outside quoted
// text ('...', "...", `...`) and outside comments (# ..., // ..., /* ... */).
// Parse uses it to decide whether to prioritize the block syntax.
//
// Backslash escapes are honoured inside single and double quotes but not
// inside backticks. An unterminated quote or block comment hides the rest of s.
func hasTopLevelLBrace(s string) bool {
	const (
		inCode = iota
		inLineComment
		inBlockComment
		inQuote
	)
	state := inCode
	var closer byte // quote character that ends the current quoted section

	for i := 0; i < len(s); i++ {
		c := s[i]
		hasNext := i+1 < len(s)

		switch state {
		case inLineComment:
			if c == '\n' {
				state = inCode
			}
		case inBlockComment:
			if c == '*' && hasNext && s[i+1] == '/' {
				state = inCode
				i++ // consume the closing '/'
			}
		case inQuote:
			switch {
			case closer != '`' && c == '\\' && hasNext:
				i++ // skip the escaped byte
			case c == closer:
				state = inCode
			}
		default: // inCode
			if unicode.IsSpace(rune(c)) {
				continue
			}
			switch {
			case c == '{':
				return true
			case c == '\'' || c == '"' || c == '`':
				state, closer = inQuote, c
			case c == '#':
				state = inLineComment
			case c == '/' && hasNext && s[i+1] == '/':
				state = inLineComment
				i++
			case c == '/' && hasNext && s[i+1] == '*':
				state = inBlockComment
				i++
			}
		}
	}
	return false
}
|
|
|
|
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
|
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|
if len(rules) == 0 {
|
|
return up
|
|
}
|
|
|
|
var defaultRule *Rule
|
|
|
|
var nonDefaultRules Rules
|
|
for _, rule := range rules {
|
|
if isDefaultRule(rule) {
|
|
r := rule
|
|
defaultRule = &r
|
|
} else {
|
|
nonDefaultRules = append(nonDefaultRules, rule)
|
|
}
|
|
}
|
|
|
|
if len(nonDefaultRules) == 0 {
|
|
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
|
return up
|
|
}
|
|
}
|
|
|
|
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
|
return cmd.pre.ServeHTTP(w, r, up)
|
|
}
|
|
|
|
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
|
return cmd.post.ServeHTTP(w, r, up)
|
|
}
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
rm := httputils.NewResponseModifier(w)
|
|
defer func() {
|
|
if _, err := rm.FlushRelease(); err != nil {
|
|
logFlushError(err, r)
|
|
}
|
|
}()
|
|
|
|
var hasError bool
|
|
|
|
preRules := make(Rules, 0, len(nonDefaultRules)+1)
|
|
if defaultRule != nil {
|
|
preRules = append(preRules, *defaultRule)
|
|
}
|
|
preRules = append(preRules, nonDefaultRules...)
|
|
|
|
executedPre := make([]bool, len(preRules))
|
|
terminatedInPre := make([]bool, len(preRules))
|
|
preTerminated := false
|
|
for i, rule := range preRules {
|
|
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
|
continue
|
|
}
|
|
if preTerminated {
|
|
// Preserve post-only commands (e.g. logging) even after
|
|
// pre-phase termination.
|
|
if len(rule.Do.pre) == 0 {
|
|
executedPre[i] = true
|
|
}
|
|
continue
|
|
}
|
|
|
|
executedPre[i] = true
|
|
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
terminatedInPre[i] = true
|
|
preTerminated = true
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
hasError = true
|
|
}
|
|
}
|
|
|
|
if !rm.HasStatus() {
|
|
if hasError {
|
|
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
|
up(rm, r)
|
|
}
|
|
}
|
|
|
|
// Run post commands for rules that actually executed in pre phase,
|
|
// unless that same rule terminated in pre phase.
|
|
for i, rule := range preRules {
|
|
if !executedPre[i] || terminatedInPre[i] {
|
|
continue
|
|
}
|
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Run true post-matcher rules after response is available.
|
|
for _, rule := range nonDefaultRules {
|
|
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
|
continue
|
|
}
|
|
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
|
// as "pre" for requirement purposes still need to run in this phase.
|
|
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
}
|
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rule *Rule) String() string {
|
|
return rule.Name
|
|
}
|
|
|
|
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
|
if rule.On.checker == nil {
|
|
return true
|
|
}
|
|
return rule.On.Check(w, r)
|
|
}
|
|
|
|
// errStreamClosed is bound via go:linkname to the unexported variable
// golang.org/x/net/http2.errStreamClosed, so isUnexpectedError can match it
// with errors.Is. go:linkname requires the file's blank "unsafe" import.
//
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
var errStreamClosed error

// errClientDisconnected is bound via go:linkname to the unexported variable
// golang.org/x/net/http2.errClientDisconnected, so isUnexpectedError can
// match it with errors.Is.
//
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
var errClientDisconnected error
|
|
|
|
func isUnexpectedError(err error) bool {
|
|
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
|
return false
|
|
}
|
|
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
|
// ignore these errors
|
|
if h2Err.Code == http2.ErrCodeStreamClosed {
|
|
return false
|
|
}
|
|
}
|
|
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
|
// ignore these errors
|
|
switch h3Err.ErrorCode {
|
|
case
|
|
http3.ErrCodeNoError,
|
|
http3.ErrCodeRequestCanceled:
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func logFlushError(err error, r *http.Request) {
|
|
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
|
}
|