diff --git a/internal/route/rules/README.md b/internal/route/rules/README.md index 2507b2d5..ba6c0ae1 100644 --- a/internal/route/rules/README.md +++ b/internal/route/rules/README.md @@ -145,7 +145,7 @@ Rules run in two phases: 1. **Pre phase** - Evaluate only request-based matchers (`path`, `method`, `header`, `remote`, etc.) in declaration order. - Execute matched rule `do` pre-commands in order. - - If a default rule exists (`name: default` or `on: default`), it is evaluated first as a baseline rule. + - If a default rule exists (`name: default` or `on: default`), it is a fallback and runs only when no non-default pre rule matches. - If a terminating action runs, stop: - remaining commands in that rule - all later pre-phase commands. @@ -552,16 +552,16 @@ Log context includes: `rule`, `alias`, `match_result` do: bypass ``` -### Default Rule (Baseline) +### Default Rule (Fallback) ```yaml -# Default runs first and can provide baseline behavior +# Default runs only if no non-default pre rule matches - name: default do: | remove resp_header X-Internal add resp_header X-Powered-By godoxy -# Specific rules can override or add to baseline behavior +# Matching rules suppress default - name: api routes on: path glob("/api/*") do: proxy http://api:8080 @@ -571,7 +571,7 @@ Log context includes: `rule`, `alias`, `match_result` do: set resp_header X-API true ``` -Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors. +Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors and both behave as fallback-only. ## Testing Notes diff --git a/internal/route/rules/block_parser.go b/internal/route/rules/block_parser.go index 6a6445ba..ee187a9b 100644 --- a/internal/route/rules/block_parser.go +++ b/internal/route/rules/block_parser.go @@ -77,6 +77,9 @@ func expandEnvVarsRaw(v string) (string, gperr.Error) { var err gperr.Error if inEnvVar { + // Write back the unterminated ${...} so the output matches the input. + buf.WriteString("${") + buf.WriteString(envVar.String()) err = ErrUnterminatedEnvVar } if len(missingEnvVars) > 0 { diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 0e69834c..4dbf7d70 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -103,6 +103,9 @@ var commands = map[string]struct { }, build: func(args any) HandlerFunc { return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + if authHandler == nil { // no auth handler configured, allow request to proceed + return nil + } if proceed := authHandler(w, r); !proceed { return errTerminateRule } diff --git a/internal/route/rules/do_blocks.go b/internal/route/rules/do_blocks.go index ada5fb3b..32df3d4d 100644 --- a/internal/route/rules/do_blocks.go +++ b/internal/route/rules/do_blocks.go @@ -67,14 +67,14 @@ func (c IfElseBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Req // If On.checker is nil, treat as unconditional. if br.On.checker == nil { if br.Do == nil { - continue + return nil } return Commands(br.Do).ServeHTTP(w, r, upstream) } - if br.Do == nil { - continue - } if br.On.checker.Check(w, r) { + if br.Do == nil { + return nil + } return Commands(br.Do).ServeHTTP(w, r, upstream) } } diff --git a/internal/route/rules/do_blocks_test.go b/internal/route/rules/do_blocks_test.go index 14396ee0..4f740cb7 100644 --- a/internal/route/rules/do_blocks_test.go +++ b/internal/route/rules/do_blocks_test.go @@ -10,7 +10,7 @@ import ( httputils "github.com/yusing/goutils/http" ) -func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoFallsThrough(t *testing.T) { +func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoNotFallsThrough(t *testing.T) { elseCalled := false cmd := IfElseBlockCommand{ Ifs: []IfBlockCommand{ @@ -36,10 +36,10 @@ func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoFallsThrough(t *testing.T err := cmd.ServeHTTP(rm, req, nil) require.NoError(t, err) - assert.True(t, elseCalled) + assert.False(t, elseCalled) } -func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoFallsThrough(t *testing.T) { +func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoNotFallsThrough(t *testing.T) { elseCalled := false cmd := IfElseBlockCommand{ Ifs: []IfBlockCommand{ @@ -69,5 +69,5 @@ func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoFallsThrough(t *test err := cmd.ServeHTTP(rm, req, nil) require.NoError(t, err) - assert.True(t, elseCalled) + assert.False(t, elseCalled) } diff --git a/internal/route/rules/http_flow_block_test.go b/internal/route/rules/http_flow_block_test.go index e4dcb83f..3f998afe 100644 --- a/internal/route/rules/http_flow_block_test.go +++ b/internal/route/rules/http_flow_block_test.go @@ -370,16 +370,45 @@ path /special { assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) assert.Empty(t, w1.Header().Get("X-Special-Handled")) - // Test special rule + default rule + // Test special rule (default should not run) req2 := httptest.NewRequest(http.MethodGet, "/special", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) - assert.Equal(t, "true", w2.Header().Get("X-Default-Applied")) + assert.Empty(t, w2.Header().Get("X-Default-Applied")) assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) } +func TestHTTPFlow_UnconditionalRuleSuppressesDefaultRule(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "upstream response") + + var rules Rules + err := parseRules(` +{ + set resp_header X-Unconditional true +} +default { + set resp_header X-Default-Applied true +} +path /never-match { + set resp_header X-Never-Match true +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/special", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "true", w.Header().Get("X-Unconditional")) + assert.Empty(t, w.Header().Get("X-Default-Applied")) + assert.Empty(t, w.Header().Get("X-Never-Match")) +} + func TestHTTPFlow_HeaderManipulation(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Echo back a header @@ -449,7 +478,7 @@ header X-Test-Header { // Public IP => public req2 := httptest.NewRequest(http.MethodGet, "/", nil) req2.Header.Set("X-Test-Header", "1") - req2.RemoteAddr = "10.0.0.1:12345" + req2.RemoteAddr = "1.1.1.1:12345" w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) diff --git a/internal/route/rules/http_flow_yaml_test.go b/internal/route/rules/http_flow_yaml_test.go index dcd9be2c..7cced60f 100644 --- a/internal/route/rules/http_flow_yaml_test.go +++ b/internal/route/rules/http_flow_yaml_test.go @@ -380,7 +380,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRulesYAML(t *testing.T) { handler.ServeHTTP(w2, req2) assert.Equal(t, http.StatusUnauthorized, w2.Code) - assert.Equal(t, w2.Body.String(), "Unauthorized\n") + assert.Equal(t, "Unauthorized\n", w2.Body.String()) // Test authorized protected request req3 := httptest.NewRequest(http.MethodGet, "/protected", nil) @@ -432,13 +432,48 @@ func TestHTTPFlow_DefaultRuleYAML(t *testing.T) { assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) assert.Empty(t, w1.Header().Get("X-Special-Handled")) - // Test special rule + default rule + // Test special rule (default should not run) req2 := httptest.NewRequest(http.MethodGet, "/special", nil) w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) - assert.Equal(t, "true", w2.Header().Get("X-Default-Applied")) + assert.Empty(t, w2.Header().Get("X-Default-Applied")) + assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) +} + +func TestHTTPFlow_DefaultRuleWithOnDefaultYAML(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "upstream response") + + var rules Rules + err := parseRules(` +- name: default-on-rule + on: default + do: set resp_header X-Default-Applied true +- name: special-rule + on: path /special + do: set resp_header X-Special-Handled true +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test default rule on regular request + req1 := httptest.NewRequest(http.MethodGet, "/regular", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) + assert.Empty(t, w1.Header().Get("X-Special-Handled")) + + // Test special rule on matching request (default should not run) + req2 := httptest.NewRequest(http.MethodGet, "/special", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Empty(t, w2.Header().Get("X-Default-Applied")) assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) } diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index e050c4e7..e41b514d 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -56,7 +56,7 @@ var checkers = map[string]struct { help: Help{ command: OnDefault, description: makeLines( - "Select the default (baseline) rule.", + "Select the default (fallback) rule.", ), args: map[string]string{}, }, @@ -67,8 +67,8 @@ var checkers = map[string]struct { return phase, nil, nil }, builder: func(args any) CheckFunc { - return func(w *httputils.ResponseModifier, r *http.Request) bool { return false } - }, // this should never be called + return func(w *httputils.ResponseModifier, r *http.Request) bool { return true } + }, }, OnHeader: { help: Help{ diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index f69f47bd..c4c8b4b4 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -294,19 +294,15 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { var hasError bool - preRules := make(Rules, 0, len(nonDefaultRules)+1) - if defaultRule != nil { - preRules = append(preRules, *defaultRule) - } - preRules = append(preRules, nonDefaultRules...) - - executedPre := make([]bool, len(preRules)) - terminatedInPre := make([]bool, len(preRules)) + executedPre := make([]bool, len(nonDefaultRules)) + terminatedInPre := make([]bool, len(nonDefaultRules)) + matchedNonDefaultPre := false preTerminated := false - for i, rule := range preRules { + for i, rule := range nonDefaultRules { if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) { continue } + matchedNonDefaultPre = true if preTerminated { // Preserve post-only commands (e.g. logging) even after // pre-phase termination. @@ -331,6 +327,24 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { } } + // Default rule is a fallback: run only when no non-default pre rule matched. + defaultExecutedPre := false + defaultTerminatedInPre := false + if defaultRule != nil && !matchedNonDefaultPre && !defaultRule.On.phase.IsPostRule() && defaultRule.On.Check(rm, r) { + defaultExecutedPre = true + if err := execPreCommand(defaultRule.Do, rm, r); err != nil { + if errors.Is(err, errTerminateRule) { + defaultTerminatedInPre = true + } else { + if isUnexpectedError(err) { + // will logged by logFlushError after FlushRelease + rm.AppendError("executing pre rule (%s): %w", defaultRule.Do.raw, err) + } + hasError = true + } + } + } + if !rm.HasStatus() { if hasError { http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -341,7 +355,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { // Run post commands for rules that actually executed in pre phase, // unless that same rule terminated in pre phase. - for i, rule := range preRules { + for i, rule := range nonDefaultRules { if !executedPre[i] || terminatedInPre[i] { continue } @@ -355,6 +369,14 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { } } } + if defaultExecutedPre && !defaultTerminatedInPre { + if err := execPostCommand(defaultRule.Do, rm, r); err != nil { + if !errors.Is(err, errTerminateRule) && isUnexpectedError(err) { + // will logged by logFlushError after FlushRelease + rm.AppendError("executing post rule (%s): %w", defaultRule.Do.raw, err) + } + } + } // Run true post-matcher rules after response is available. for _, rule := range nonDefaultRules {