package rules import ( "errors" "fmt" "net/http" "reflect" "slices" "strings" "unicode" "github.com/goccy/go-yaml" "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog/log" "github.com/yusing/godoxy/internal/serialization" gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" "golang.org/x/net/http2" _ "unsafe" ) type ( /* Rules is a list of rules. Example: proxy.app1.rules: | - name: default do: | rewrite / /index.html serve /var/www/goaccess - name: ws on: | header Connection Upgrade header Upgrade websocket do: bypass proxy.app2.rules: | - name: default do: bypass - name: block POST and PUT on: method POST | method PUT do: error 403 Forbidden */ //nolint:recvcheck Rules []Rule // Rule represents a reverse proxy rule. // The `Do` field is executed when `On` matches. // // - A rule may have multiple lines in the `On` section. // - All `On` lines must match for the rule to trigger. // - Each line can have several checks—one match per line is enough for that line. Rule struct { Name string `json:"name"` On RuleOn `json:"on" swaggertype:"string"` Do Command `json:"do" swaggertype:"string"` } ) func isDefaultRule(rule Rule) bool { return rule.Name == "default" || rule.On.raw == OnDefault } func (rules Rules) Validate() gperr.Error { var defaultRulesFound []int for i := range rules { rule := rules[i] if isDefaultRule(rule) { defaultRulesFound = append(defaultRulesFound, i) } if rules[i].Name == "" { // set name to index if name is empty rules[i].Name = fmt.Sprintf("rule[%d]", i) } } if len(defaultRulesFound) > 1 { return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound)) } for i := range rules { r1 := rules[i] if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() { continue } sig1, ok := matcherSignature(r1.On.raw) if !ok { continue } for j := i + 1; j < len(rules); j++ { r2 := rules[j] if isDefaultRule(r2) || r2.On.phase.IsPostRule() { continue } sig2, ok := matcherSignature(r2.On.raw) if !ok || sig1 != sig2 { continue } return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j) } } return nil } func (rule Rule) doesTerminateInPre() bool { for _, cmd := range rule.Do.pre { handler, ok := cmd.(Handler) if !ok { continue } if handler.Terminates() { return true } } return false } func matcherSignature(raw string) (string, bool) { raw = strings.TrimSpace(raw) if raw == "" { return "", false } andParts := splitAnd(raw) if len(andParts) == 0 { return "", false } canonAnd := make([]string, 0, len(andParts)) for _, andPart := range andParts { orParts := splitPipe(andPart) if len(orParts) == 0 { continue } canonOr := make([]string, 0, len(orParts)) for _, atom := range orParts { subject, args, err := parse(strings.TrimSpace(atom)) if err != nil || subject == "" { return "", false } canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00")) } slices.Sort(canonOr) canonOr = slices.Compact(canonOr) canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")") } slices.Sort(canonAnd) canonAnd = slices.Compact(canonAnd) if len(canonAnd) == 0 { return "", false } return strings.Join(canonAnd, "&"), true } // Parse parses a rule configuration string. // It first tries the block syntax (if the string contains a top-level '{'), // then falls back to YAML syntax. func (rules *Rules) Parse(config string) error { config = strings.TrimSpace(config) if config == "" { return nil } // Prefer block syntax if it looks like block syntax. if hasTopLevelLBrace(config) { blockRules, err := parseBlockRules(config) if err == nil { *rules = blockRules return nil } // Fall through to YAML (backward compatibility). } // YAML fallback var anySlice []any yamlErr := yaml.Unmarshal([]byte(config), &anySlice) if yamlErr == nil { return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false) } // If YAML fails and we didn't try block syntax yet, try it now. blockRules, err := parseBlockRules(config) if err == nil { *rules = blockRules return nil } return err } // hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments. // Used to decide whether to prioritize the block syntax. func hasTopLevelLBrace(s string) bool { quote := rune(0) inLine := false inBlock := false for i := 0; i < len(s); i++ { c := s[i] if inLine { if c == '\n' { inLine = false } continue } if inBlock { if c == '*' && i+1 < len(s) && s[i+1] == '/' { inBlock = false i++ } continue } if quote != 0 { if quote != '`' && c == '\\' && i+1 < len(s) { i++ continue } if rune(c) == quote { quote = 0 } continue } switch c { case '\'', '"', '`': quote = rune(c) continue case '{': return true case '#': inLine = true continue case '/': if i+1 < len(s) && s[i+1] == '/' { inLine = true i++ continue } if i+1 < len(s) && s[i+1] == '*' { inBlock = true i++ continue } default: if unicode.IsSpace(rune(c)) { continue } } } return false } // BuildHandler returns a http.HandlerFunc that implements the rules. func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { if len(rules) == 0 { return up } var defaultRule *Rule var nonDefaultRules Rules for _, rule := range rules { if isDefaultRule(rule) { r := rule defaultRule = &r } else { nonDefaultRules = append(nonDefaultRules, rule) } } if len(nonDefaultRules) == 0 { if defaultRule == nil || defaultRule.Do.raw == CommandUpstream { return up } } execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error { return cmd.pre.ServeHTTP(w, r, up) } execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error { return cmd.post.ServeHTTP(w, r, up) } return func(w http.ResponseWriter, r *http.Request) { rm := httputils.NewResponseModifier(w) defer func() { if _, err := rm.FlushRelease(); err != nil { logError(err, r) } }() var hasError bool preRules := make(Rules, 0, len(nonDefaultRules)+1) if defaultRule != nil { preRules = append(preRules, *defaultRule) } preRules = append(preRules, nonDefaultRules...) executedPre := make([]bool, len(preRules)) terminatedInPre := make([]bool, len(preRules)) preTerminated := false for i, rule := range preRules { if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) { continue } if preTerminated { // Preserve post-only commands (e.g. logging) even after // pre-phase termination. if len(rule.Do.pre) == 0 { executedPre[i] = true } continue } executedPre[i] = true if err := execPreCommand(rule.Do, rm, r); err != nil { if errors.Is(err, errTerminateRule) { terminatedInPre[i] = true preTerminated = true continue } logError(err, r) hasError = true } } if !rm.HasStatus() { if hasError { http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } else { // call upstream if no WriteHeader or Write was called and no error occurred up(rm, r) } } // Run post commands for rules that actually executed in pre phase, // unless that same rule terminated in pre phase. for i, rule := range preRules { if !executedPre[i] || terminatedInPre[i] { continue } if err := execPostCommand(rule.Do, rm, r); err != nil { if errors.Is(err, errTerminateRule) { continue } logError(err, r) } } // Run true post-matcher rules after response is available. for _, rule := range nonDefaultRules { if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) { continue } // Post-rule matchers are only evaluated after upstream, so commands parsed // as "pre" for requirement purposes still need to run in this phase. if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil { if errors.Is(err, errTerminateRule) { continue } logError(err, r) } if err := execPostCommand(rule.Do, rm, r); err != nil { if errors.Is(err, errTerminateRule) { continue } logError(err, r) } } } } func (rule *Rule) String() string { return rule.Name } func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool { if rule.On.checker == nil { return true } return rule.On.Check(w, r) } //go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed var errStreamClosed error //go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected var errClientDisconnected error func logError(err error, r *http.Request) { if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) { return } if h2Err, ok := errors.AsType[http2.StreamError](err); ok { // ignore these errors if h2Err.Code == http2.ErrCodeStreamClosed { return } } if h3Err, ok := errors.AsType[*http3.Error](err); ok { // ignore these errors switch h3Err.ErrorCode { case http3.ErrCodeNoError, http3.ErrCodeRequestCanceled: return } } log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules") }