diff --git a/internal/route/rules/http_flow_block_test.go b/internal/route/rules/http_flow_block_test.go index 3f998afe..7be1c5eb 100644 --- a/internal/route/rules/http_flow_block_test.go +++ b/internal/route/rules/http_flow_block_test.go @@ -420,7 +420,7 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) { var rules Rules err := parseRules(` -default { +{ remove resp_header X-Secret add resp_header X-Custom-Header custom-value } diff --git a/internal/route/rules/http_flow_yaml_test.go b/internal/route/rules/http_flow_yaml_test.go index 7cced60f..0cfcb143 100644 --- a/internal/route/rules/http_flow_yaml_test.go +++ b/internal/route/rules/http_flow_yaml_test.go @@ -895,7 +895,7 @@ func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMat upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upstreamCalled = true w.WriteHeader(http.StatusOK) - w.Write([]byte("upstream")) + fmt.Fprint(w, "upstream") }) var rules Rules @@ -903,7 +903,7 @@ func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMat - on: path / do: error 403 blocked - on: path / - do: set resp_header X-Late should-not-run + do: set resp_header X-Late should-run - on: status 4xx do: set resp_header X-Post true `, &rules) @@ -918,7 +918,7 @@ func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMat assert.False(t, upstreamCalled) assert.Equal(t, 403, w.Code) assert.Equal(t, "blocked\n", w.Body.String()) - assert.Equal(t, "should-not-run", w.Header().Get("X-Late")) + assert.Equal(t, "should-run", w.Header().Get("X-Late")) assert.Equal(t, "true", w.Header().Get("X-Post")) } diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index e41b514d..bb6d6925 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -655,7 +655,9 @@ func forEachPipePart(s string, fn func(part string)) { } } if start < len(s) { - fn(strings.TrimSpace(s[start:])) + if part := strings.TrimSpace(s[start:]); part != "" { + fn(part) + } } } diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index c4c8b4b4..f8b7e93e 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -103,16 +103,53 @@ func (rules Rules) Validate() gperr.Error { } func (rule Rule) doesTerminateInPre() bool { - for _, cmd := range rule.Do.pre { - handler, ok := cmd.(Handler) - if !ok { - continue + return commandsTerminateInPre(rule.Do.pre) +} + +func commandsTerminateInPre(cmds []CommandHandler) bool { + return slices.ContainsFunc(cmds, commandTerminatesInPre) +} + +func commandTerminatesInPre(cmd CommandHandler) bool { + switch c := cmd.(type) { + case Handler: + return c.Terminates() + case *Handler: + return c.Terminates() + case IfBlockCommand: + return ruleOnAlwaysTrue(c.On) && commandsTerminateInPre(c.Do) + case *IfBlockCommand: + return c != nil && ruleOnAlwaysTrue(c.On) && commandsTerminateInPre(c.Do) + case IfElseBlockCommand: + return ifElseBlockTerminatesInPre(c) + case *IfElseBlockCommand: + return c != nil && ifElseBlockTerminatesInPre(*c) + default: + return false + } +} + +func ifElseBlockTerminatesInPre(cmd IfElseBlockCommand) bool { + hasFallback := len(cmd.Else) > 0 + for _, br := range cmd.Ifs { + if !commandsTerminateInPre(br.Do) { + return false } - if handler.Terminates() { - return true + if ruleOnAlwaysTrue(br.On) { + hasFallback = true } } - return false + if !hasFallback { + return false + } + if len(cmd.Else) > 0 && !commandsTerminateInPre(cmd.Else) { + return false + } + return true +} + +func ruleOnAlwaysTrue(on RuleOn) bool { + return strings.TrimSpace(on.raw) == OnDefault || on.checker == nil } func matcherSignature(raw string) (string, bool) { @@ -162,14 +199,18 @@ func (rules *Rules) Parse(config string) error { return nil } + blockTried := false + var blockErr gperr.Error + // Prefer block syntax if it looks like block syntax. if hasTopLevelLBrace(config) { + blockTried = true blockRules, err := parseBlockRules(config) if err == nil { *rules = blockRules return nil } - // Fall through to YAML (backward compatibility). + blockErr = err } // YAML fallback @@ -179,13 +220,16 @@ func (rules *Rules) Parse(config string) error { return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false) } - // If YAML fails and we didn't try block syntax yet, try it now. - blockRules, err := parseBlockRules(config) - if err == nil { - *rules = blockRules - return nil + // If YAML fails and we haven't tried block syntax yet, try it now. + if !blockTried { + blockRules, err := parseBlockRules(config) + if err == nil { + *rules = blockRules + return nil + } + blockErr = err } - return err + return blockErr } // hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments. diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go index fc590e0d..b1f923fc 100644 --- a/internal/route/rules/rules_test.go +++ b/internal/route/rules/rules_test.go @@ -69,6 +69,55 @@ header Host example.com { set resp_header X-Test first } +header Host example.com { + error 403 "forbidden" +} +`, + want: nil, + }, + { + name: "same condition with terminating handler inside if block", + rules: ` +header Host example.com { + @default { + error 404 "not found" + } +} + +header Host example.com { + error 403 "forbidden" +} +`, + want: ErrDeadRule, + }, + { + name: "same condition with terminating handler across if else block", + rules: ` +header Host example.com { + @method GET { + error 404 "not found" + } else { + redirect https://example.com + } +} + +header Host example.com { + error 403 "forbidden" +} +`, + want: ErrDeadRule, + }, + { + name: "same condition with non terminating if branch in if else block", + rules: ` +header Host example.com { + @method GET { + set resp_header X-Test first + } else { + error 404 "not found" + } +} + header Host example.com { error 403 "forbidden" } @@ -128,3 +177,15 @@ func TestHasTopLevelLBrace(t *testing.T) { }) } } + +func TestRulesParse_BlockTriedThenYAMLFails_ReturnsBlockError(t *testing.T) { + input := `default {` + + _, blockErr := parseBlockRules(input) + require.Error(t, blockErr) + + var rules Rules + err := rules.Parse(input) + require.Error(t, err) + assert.Equal(t, blockErr.Error(), err.Error()) +}