mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-28 11:51:53 +01:00
The default rule should runs only when no non-default pre rule matches, instead of running first as a baseline. This follows the old behavior as before the pr is established.: - Default rules act as fallback handlers that execute only when no matching non-default rule exists in the pre phase - IfElseBlockCommand now returns early when a condition matches with a nil Do block, instead of falling through to else blocks - Add nil check for auth handler to allow requests when no auth is configured - Fix unterminated environment variable parsing to preserve input Updates tests to verify the new fallback behavior where special rules suppress default rule execution.
452 lines
11 KiB
Go
452 lines
11 KiB
Go
package rules
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/quic-go/quic-go/http3"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/yusing/godoxy/internal/serialization"
|
|
gperr "github.com/yusing/goutils/errs"
|
|
httputils "github.com/yusing/goutils/http"
|
|
"golang.org/x/net/http2"
|
|
|
|
_ "unsafe"
|
|
)
|
|
|
|
type (
|
|
/*
|
|
Rules is a list of rules.
|
|
|
|
Example:
|
|
|
|
proxy.app1.rules: |
|
|
- name: default
|
|
do: |
|
|
rewrite / /index.html
|
|
serve /var/www/goaccess
|
|
- name: ws
|
|
on: |
|
|
header Connection Upgrade
|
|
header Upgrade websocket
|
|
do: bypass
|
|
|
|
proxy.app2.rules: |
|
|
- name: default
|
|
do: bypass
|
|
- name: block POST and PUT
|
|
on: method POST | method PUT
|
|
do: error 403 Forbidden
|
|
*/
|
|
//nolint:recvcheck
|
|
Rules []Rule
|
|
// Rule represents a reverse proxy rule.
|
|
// The `Do` field is executed when `On` matches.
|
|
//
|
|
// - A rule may have multiple lines in the `On` section.
|
|
// - All `On` lines must match for the rule to trigger.
|
|
// - Each line can have several checks—one match per line is enough for that line.
|
|
Rule struct {
|
|
Name string `json:"name"`
|
|
On RuleOn `json:"on" swaggertype:"string"`
|
|
Do Command `json:"do" swaggertype:"string"`
|
|
}
|
|
)
|
|
|
|
func isDefaultRule(rule Rule) bool {
|
|
return rule.Name == "default" || rule.On.raw == OnDefault
|
|
}
|
|
|
|
func (rules Rules) Validate() gperr.Error {
|
|
var defaultRulesFound []int
|
|
for i := range rules {
|
|
rule := rules[i]
|
|
if isDefaultRule(rule) {
|
|
defaultRulesFound = append(defaultRulesFound, i)
|
|
}
|
|
if rules[i].Name == "" {
|
|
// set name to index if name is empty
|
|
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
|
}
|
|
}
|
|
if len(defaultRulesFound) > 1 {
|
|
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
|
}
|
|
for i := range rules {
|
|
r1 := rules[i]
|
|
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
|
continue
|
|
}
|
|
sig1, ok := matcherSignature(r1.On.raw)
|
|
if !ok {
|
|
continue
|
|
}
|
|
for j := i + 1; j < len(rules); j++ {
|
|
r2 := rules[j]
|
|
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
|
continue
|
|
}
|
|
sig2, ok := matcherSignature(r2.On.raw)
|
|
if !ok || sig1 != sig2 {
|
|
continue
|
|
}
|
|
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (rule Rule) doesTerminateInPre() bool {
|
|
for _, cmd := range rule.Do.pre {
|
|
handler, ok := cmd.(Handler)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if handler.Terminates() {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func matcherSignature(raw string) (string, bool) {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw == "" {
|
|
return "", false
|
|
}
|
|
|
|
andParts := splitAnd(raw)
|
|
if len(andParts) == 0 {
|
|
return "", false
|
|
}
|
|
|
|
canonAnd := make([]string, 0, len(andParts))
|
|
for _, andPart := range andParts {
|
|
orParts := splitPipe(andPart)
|
|
if len(orParts) == 0 {
|
|
continue
|
|
}
|
|
canonOr := make([]string, 0, len(orParts))
|
|
for _, atom := range orParts {
|
|
subject, args, err := parse(strings.TrimSpace(atom))
|
|
if err != nil || subject == "" {
|
|
return "", false
|
|
}
|
|
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
|
}
|
|
slices.Sort(canonOr)
|
|
canonOr = slices.Compact(canonOr)
|
|
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
|
}
|
|
|
|
slices.Sort(canonAnd)
|
|
canonAnd = slices.Compact(canonAnd)
|
|
if len(canonAnd) == 0 {
|
|
return "", false
|
|
}
|
|
return strings.Join(canonAnd, "&"), true
|
|
}
|
|
|
|
// Parse parses a rule configuration string.
|
|
// It first tries the block syntax (if the string contains a top-level '{'),
|
|
// then falls back to YAML syntax.
|
|
func (rules *Rules) Parse(config string) error {
|
|
config = strings.TrimSpace(config)
|
|
if config == "" {
|
|
return nil
|
|
}
|
|
|
|
// Prefer block syntax if it looks like block syntax.
|
|
if hasTopLevelLBrace(config) {
|
|
blockRules, err := parseBlockRules(config)
|
|
if err == nil {
|
|
*rules = blockRules
|
|
return nil
|
|
}
|
|
// Fall through to YAML (backward compatibility).
|
|
}
|
|
|
|
// YAML fallback
|
|
var anySlice []any
|
|
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
|
if yamlErr == nil {
|
|
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
|
}
|
|
|
|
// If YAML fails and we didn't try block syntax yet, try it now.
|
|
blockRules, err := parseBlockRules(config)
|
|
if err == nil {
|
|
*rules = blockRules
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments.
|
|
// Used to decide whether to prioritize the block syntax.
|
|
func hasTopLevelLBrace(s string) bool {
|
|
quote := rune(0)
|
|
inLine := false
|
|
inBlock := false
|
|
|
|
for i := 0; i < len(s); i++ {
|
|
c := s[i]
|
|
|
|
if inLine {
|
|
if c == '\n' {
|
|
inLine = false
|
|
}
|
|
continue
|
|
}
|
|
if inBlock {
|
|
if c == '*' && i+1 < len(s) && s[i+1] == '/' {
|
|
inBlock = false
|
|
i++
|
|
}
|
|
continue
|
|
}
|
|
|
|
if quote != 0 {
|
|
if quote != '`' && c == '\\' && i+1 < len(s) {
|
|
i++
|
|
continue
|
|
}
|
|
if rune(c) == quote {
|
|
quote = 0
|
|
}
|
|
continue
|
|
}
|
|
|
|
switch c {
|
|
case '\'', '"', '`':
|
|
quote = rune(c)
|
|
continue
|
|
case '{':
|
|
return true
|
|
case '#':
|
|
inLine = true
|
|
continue
|
|
case '/':
|
|
if i+1 < len(s) && s[i+1] == '/' {
|
|
inLine = true
|
|
i++
|
|
continue
|
|
}
|
|
if i+1 < len(s) && s[i+1] == '*' {
|
|
inBlock = true
|
|
i++
|
|
continue
|
|
}
|
|
default:
|
|
if unicode.IsSpace(rune(c)) {
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
|
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|
if len(rules) == 0 {
|
|
return up
|
|
}
|
|
|
|
var defaultRule *Rule
|
|
|
|
var nonDefaultRules Rules
|
|
for _, rule := range rules {
|
|
if isDefaultRule(rule) {
|
|
r := rule
|
|
defaultRule = &r
|
|
} else {
|
|
nonDefaultRules = append(nonDefaultRules, rule)
|
|
}
|
|
}
|
|
|
|
if len(nonDefaultRules) == 0 {
|
|
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
|
return up
|
|
}
|
|
}
|
|
|
|
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
|
return cmd.pre.ServeHTTP(w, r, up)
|
|
}
|
|
|
|
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
|
return cmd.post.ServeHTTP(w, r, up)
|
|
}
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
rm := httputils.NewResponseModifier(w)
|
|
defer func() {
|
|
if _, err := rm.FlushRelease(); err != nil {
|
|
logFlushError(err, r)
|
|
}
|
|
}()
|
|
|
|
var hasError bool
|
|
|
|
executedPre := make([]bool, len(nonDefaultRules))
|
|
terminatedInPre := make([]bool, len(nonDefaultRules))
|
|
matchedNonDefaultPre := false
|
|
preTerminated := false
|
|
for i, rule := range nonDefaultRules {
|
|
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
|
continue
|
|
}
|
|
matchedNonDefaultPre = true
|
|
if preTerminated {
|
|
// Preserve post-only commands (e.g. logging) even after
|
|
// pre-phase termination.
|
|
if len(rule.Do.pre) == 0 {
|
|
executedPre[i] = true
|
|
}
|
|
continue
|
|
}
|
|
|
|
executedPre[i] = true
|
|
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
terminatedInPre[i] = true
|
|
preTerminated = true
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
hasError = true
|
|
}
|
|
}
|
|
|
|
// Default rule is a fallback: run only when no non-default pre rule matched.
|
|
defaultExecutedPre := false
|
|
defaultTerminatedInPre := false
|
|
if defaultRule != nil && !matchedNonDefaultPre && !defaultRule.On.phase.IsPostRule() && defaultRule.On.Check(rm, r) {
|
|
defaultExecutedPre = true
|
|
if err := execPreCommand(defaultRule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
defaultTerminatedInPre = true
|
|
} else {
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing pre rule (%s): %w", defaultRule.Do.raw, err)
|
|
}
|
|
hasError = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if !rm.HasStatus() {
|
|
if hasError {
|
|
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
|
up(rm, r)
|
|
}
|
|
}
|
|
|
|
// Run post commands for rules that actually executed in pre phase,
|
|
// unless that same rule terminated in pre phase.
|
|
for i, rule := range nonDefaultRules {
|
|
if !executedPre[i] || terminatedInPre[i] {
|
|
continue
|
|
}
|
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
}
|
|
}
|
|
if defaultExecutedPre && !defaultTerminatedInPre {
|
|
if err := execPostCommand(defaultRule.Do, rm, r); err != nil {
|
|
if !errors.Is(err, errTerminateRule) && isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing post rule (%s): %w", defaultRule.Do.raw, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Run true post-matcher rules after response is available.
|
|
for _, rule := range nonDefaultRules {
|
|
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
|
continue
|
|
}
|
|
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
|
// as "pre" for requirement purposes still need to run in this phase.
|
|
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
}
|
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
if isUnexpectedError(err) {
|
|
// will logged by logFlushError after FlushRelease
|
|
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rule *Rule) String() string {
|
|
return rule.Name
|
|
}
|
|
|
|
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
|
if rule.On.checker == nil {
|
|
return true
|
|
}
|
|
return rule.On.Check(w, r)
|
|
}
|
|
|
|
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
|
|
var errStreamClosed error
|
|
|
|
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
|
|
var errClientDisconnected error
|
|
|
|
func isUnexpectedError(err error) bool {
|
|
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
|
return false
|
|
}
|
|
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
|
// ignore these errors
|
|
if h2Err.Code == http2.ErrCodeStreamClosed {
|
|
return false
|
|
}
|
|
}
|
|
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
|
// ignore these errors
|
|
switch h3Err.ErrorCode {
|
|
case
|
|
http3.ErrCodeNoError,
|
|
http3.ErrCodeRequestCanceled:
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func logFlushError(err error, r *http.Request) {
|
|
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
|
}
|