diff --git a/internal/api/v1/route/playground.go b/internal/api/v1/route/playground.go index 39d40549..24f2d120 100644 --- a/internal/api/v1/route/playground.go +++ b/internal/api/v1/route/playground.go @@ -298,7 +298,7 @@ func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) { On: onStr, Do: doStr, ValidationError: validationErr, - IsResponseRule: rule.IsResponseRule(), + // IsResponseRule: rule.Requirement()&rules.RequirementFlagResponse != 0, }) // Only add valid rules to execution list diff --git a/internal/route/rules/README.md b/internal/route/rules/README.md index dbb4b358..2507b2d5 100644 --- a/internal/route/rules/README.md +++ b/internal/route/rules/README.md @@ -36,15 +36,15 @@ type Rule struct { } type RuleOn struct { - raw string - checker Checker - isResponseChecker bool + raw string + checker Checker + phase PhaseFlag } type Command struct { - raw string - exec CommandHandler - isResponseHandler bool + raw string + pre Commands + post Commands } ``` @@ -59,6 +59,9 @@ func ParseRules(config string) (Rules, error) // ValidateRules validates rule syntax func ValidateRules(config string) error + +// Validate validates rule semantics (e.g., prevents multiple default rules) +func (rules Rules) Validate() gperr.Error ``` ## Architecture @@ -122,16 +125,52 @@ sequenceDiagram Pre->>Pre: Execute handler alt Terminating action Pre-->>Req: Response - Return-->>Req: Return immediately + Note right of Pre: Stop remaining pre commands end end - Req->>Proxy: Forward request - Proxy-->>Req: Response - Req->>Post: Check post-rules - Post->>Post: Execute handlers - Post-->>Req: Modified response + opt No pre termination + Req->>Proxy: Forward request + Proxy-->>Req: Response + end + Req->>Post: Run scheduled post commands + Req->>Post: Evaluate response matchers + Post->>Post: Execute matched post handlers + Post-->>Req: Final response ``` +### Execution Model (Authoritative) + +Rules run in two phases: + +1. **Pre phase** + - Evaluate only request-based matchers (`path`, `method`, `header`, `remote`, etc.) in declaration order. + - Execute matched rule `do` pre-commands in order. + - If a default rule exists (`name: default` or `on: default`), it is evaluated first as a baseline rule. + - If a terminating action runs, stop: + - remaining commands in that rule + - all later pre-phase commands. + - Exception: rules that only contain post commands (no pre commands) are still scheduled for post phase. + +2. **Upstream phase** + - Upstream is called only if pre phase did not terminate. + +3. **Post phase** + - Run post-commands for rules whose pre phase executed, except rules that terminated in pre. + - Then evaluate response-based matchers (`status`, `resp_header`) and execute their `do` commands. + - Response-based rules run even when the response was produced in pre phase. + +**Important:** termination is explicit by command semantics, not inferred from status-code mutation. + +### Phase Flags + +Rule and command parsing tracks phase requirements via `PhaseFlag`: + +- `PhasePre` +- `PhasePost` +- `PhasePre | PhasePost` (combined) + +Combined flags are expected for nested/compound commands and variable templates that may need both request and response context. + ### Condition Matchers | Matcher | Type | Description | @@ -166,22 +205,22 @@ path regex("/api/v[0-9]+/.*") // regex pattern **Terminating Actions** (stop processing): -| Command | Description | -| ------------------------ | ---------------------- | -| `error ` | Return HTTP error | -| `redirect ` | Redirect to URL | -| `serve ` | Serve local files | -| `route ` | Route to another route | -| `proxy ` | Proxy to upstream | +| Command | Description | +| ------------------------------ | ------------------------------------- | +| `upstream` / `bypass` / `pass` | Call upstream and terminate pre-phase | +| `error ` | Return HTTP error | +| `redirect ` | Redirect to URL | +| `serve ` | Serve local files | +| `route ` | Route to another route | +| `proxy ` | Proxy to upstream | +| `require_basic_auth ` | Return 401 challenge | **Non-Terminating Actions** (modify and continue): | Command | Description | | ------------------------------ | ---------------------- | -| `pass` / `bypass` | Pass through unchanged | | `rewrite ` | Rewrite request path | | `require_auth` | Require authentication | -| `require_basic_auth ` | Basic auth challenge | | `set ` | Set header/variable | | `add ` | Add header/variable | | `remove ` | Remove header/variable | @@ -208,6 +247,166 @@ rules: action2 ``` +### Rule Configuration (Block Syntax) + +This is an alternative (and will eventually be the primary) syntax for rules that avoids YAML. +It keeps the **inner** `on` and `do` DSLs exactly the same (same matchers, same commands, same optional quotes), but wraps each rule in a `{ ... }` block. + +#### Key ideas + +- A rule is: + - `default { }`, + - `{ }`, or + - ` { }` +- Comments are supported: + - line comment: `// ...` (to end of line) + - line comment: `# ...` (to end of line, for YAML familiarity) + - block comment: `/* ... */` (may span multiple lines) + - Comments are ignored **only when outside quotes** (`"`, `'` or backticks). + - Environment variable syntax: `${NAME}` is supported by the inner DSL parser in [`parse()`](internal/route/rules/parser.go:34). + Block-syntax rule: + - In `on` (rule header): `${...}` must be inside quotes/backticks. + - In `do` (rule body): `${...}` may be unquoted; the outer parser must treat `${...}` as an opaque token so braces inside it are not structural. + +#### Grammar sketch (EBNF-ish) + +```text +file := { ws | comment | rule } +rule := default_rule | unconditional_rule | conditional_rule + +default_rule := 'default' ws* block +unconditional_rule := ws* block +conditional_rule := on_expr ws* block + +block := '{' do_body '}' + +// on_expr and do_body are raw text regions. +// The outer parser only needs to: +// - find the top-level '{' to start a rule block +// - find the matching top-level '}' to end it +// while respecting quotes and comments. +``` + +#### Elif/Else Chain Grammar + +```text +// Elif/Else chains can appear in do_body +do_stmt := command_line | nested_block | elif_else_chain +elif_else_chain := nested_block { elif_clause } [else_clause] +elif_clause := 'elif' ws* on_expr ws* '{' do_body '}' +else_clause := 'else' ws* '{' do_body '}' +``` + +#### Nested blocks (inline conditionals inside `do`) + +Inside a rule body (`do_body`), you can write **nested blocks** that start with `@`: + +```text +do_stmt := command_line | nested_block | elif_else_chain + +nested_block := '@' on_expr ws* '{' do_body '}' +``` + +Notes: + +- A nested block is only recognized when `@` is the **first non-space character on a line**. +- `on_expr` uses the same syntax as rule `on` (supports `|`, `&`, quoting/backticks, matcher functions, etc.). +- The nested block executes **in sequence**, at the point where it appears in the parent `do` list. +- Nested blocks are evaluated in the same phase the parent rule runs (no special phase promotion). +- Nested blocks can be chained with `elif`/`else` for conditional execution (see Elif/Else Chains section). + +Example: + +```go +default { + remove resp_header X-Secret + add resp_header X-Custom-Header custom-value +} + +header X-Test-Header { + set header X-Remote-Type public + @remote 127.0.0.1 | remote 192.168.0.0/16 { + set header X-Remote-Type private + } +} +``` + +#### Elif/Else Chains + +You can chain multiple conditions using `elif` and provide a fallback with `else`. +The `elif`/`else` keywords must appear on the same line as the preceding closing brace (`}`). + +```go +header X-Test-Header { + @method GET { + set header X-Mode get + } elif method POST { + set header X-Mode post + } else { + set header X-Mode other + } +} +``` + +Notes: + +- `elif` and `else` must be on the same line as the preceding `}`. +- Multiple `elif` branches are allowed; only one `else` is allowed. +- The entire chain is evaluated in sequence; the first matching branch executes. +- Elif/else chains can only be used within nested blocks (starting with `@`). +- Each `elif` clause must have its own condition expression and block. +- The `else` clause is optional and provides a default action when no conditions match. + +#### Examples + +Basic default rule: + +```go +default { + bypass +} +``` + +WebSocket upgrade routing: + +```bash +# WebSocket requests +header Connection Upgrade & +header Upgrade websocket { + route ws-api + log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name" +} +``` + +Block comments: + +```go +/* protect admin area */ +path glob("/admin/*") { + require_auth +} +``` + +Always log the request + +```bash +{ + log info /dev/stdout "Request $req_method $req_path" +} +``` + +#### Notes and constraints + +- The block syntax uses `{` and `}` as structure delimiters at **top-level** (outside quotes/comments). + - Braces inside quoted strings (including backticks) are not structural. + - `${...}` handling: + - `on`: must be quoted/backticked + - `do`: may be unquoted + Preferred style: always write env vars as `${NAME}` rather than a bare `$NAME`. + - If you need literal `{` or `}` outside quotes/backticks (for example unquoted templates like `{{ ... }}`), wrap that argument in quotes/backticks so the outer parser does not treat it as structure. +- Rule naming remains minimal: if no explicit name is provided by the syntax, it will behave like the current YAML behavior (empty name becomes `rule[index]` in [`Rules.BuildHandler()`](internal/route/rules/rules.go:75)). +- YAML remains supported as a fallback for backward compatibility. + ### Condition Syntax ```yaml @@ -215,12 +414,13 @@ rules: on: path /api/users # Multiple conditions (AND) -on: | - header Authorization Bearer - & path /api/admin/* +on: header Authorization Bearer & path glob("/api/admin/*") # Negation -on: !path /public/* +on: !path glob("/public/*") + +# Negation on matcher +on: path !glob("/public/*") # OR within a line on: method GET | method POST @@ -228,21 +428,21 @@ on: method GET | method POST ### Variable Substitution -```go -// Static variables -$req_method // Request method -$req_host // Request host -$req_path // Request path -$status_code // Response status -$remote_host // Client IP +```bash +# Static variables +$req_method # Request method +$req_host # Request host +$req_path # Request path +$status_code # Response status +$remote_host # Client IP -// Dynamic variables -$header(Name) // Request header -$header(Name, index) // Header at index -$arg(Name) // Query argument -$form(Name) // Form field +# Dynamic variables +$header(Name) # Request header +$header(Name, index) # Header at index +$arg(Name) # Query argument +$form(Name) # Form field -// Environment variables +# Environment variables ${ENV_VAR} ``` @@ -277,12 +477,13 @@ Log context includes: `rule`, `alias`, `match_result` ## Failure Modes and Recovery -| Failure | Behavior | Recovery | -| ------------------- | ------------------------- | ---------------------------------- | -| Invalid rule syntax | Route validation fails | Fix YAML syntax | -| Missing variables | Variable renders as empty | Check variable sources | -| Rule timeout | Request times out | Increase timeout or simplify rules | -| Auth failure | Returns 401/403 | Fix credentials | +| Failure | Behavior | Recovery | +| ---------------------- | ------------------------- | ---------------------------------- | +| Invalid rule syntax | Route validation fails | Fix YAML syntax | +| Multiple default rules | Route validation fails | Remove duplicate default rules | +| Missing variables | Variable renders as empty | Check variable sources | +| Rule timeout | Request times out | Increase timeout or simplify rules | +| Auth failure | Returns 401/403 | Fix credentials | ## Usage Examples @@ -297,11 +498,11 @@ Log context includes: `rule`, `alias`, `match_result` ```yaml - name: api proxy - on: path /api/* + on: path glob("/api/*") do: proxy http://api-backend:8080 - name: static files - on: path /static/* + on: path glob("/static/*") do: serve /var/www/static ``` @@ -309,11 +510,11 @@ Log context includes: `rule`, `alias`, `match_result` ```yaml - name: admin protection - on: path /admin/* + on: path glob("/admin/*") do: require_auth - name: basic auth for API - on: path /api/* + on: path glob("/api/*") do: require_basic_auth "API Access" ``` @@ -321,7 +522,7 @@ Log context includes: `rule`, `alias`, `match_result` ```yaml - name: rewrite API v1 - on: path /v1/* + on: path glob("/v1/*") do: | rewrite /v1 /api/v1 proxy http://backend:8080 @@ -351,6 +552,27 @@ Log context includes: `rule`, `alias`, `match_result` do: bypass ``` +### Default Rule (Baseline) + +```yaml +# Default runs first and can provide baseline behavior +- name: default + do: | + remove resp_header X-Internal + add resp_header X-Powered-By godoxy + +# Specific rules can override or add to baseline behavior +- name: api routes + on: path glob("/api/*") + do: proxy http://api:8080 + +- name: api marker + on: path glob("/api/*") + do: set resp_header X-API true +``` + +Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors. + ## Testing Notes - Unit tests for all matchers and actions diff --git a/internal/route/rules/block_parser.go b/internal/route/rules/block_parser.go new file mode 100644 index 00000000..6a6445ba --- /dev/null +++ b/internal/route/rules/block_parser.go @@ -0,0 +1,406 @@ +package rules + +import ( + "strings" + "unicode" + + "github.com/yusing/goutils/env" + gperr "github.com/yusing/goutils/errs" +) + +func getStringBuffer(size int) *strings.Builder { + var buf strings.Builder + if size > 0 { + buf.Grow(size) + } + return &buf +} + +// expandEnvVarsRaw expands ${NAME} in-place using env.LookupEnv (prefix-aware). +func expandEnvVarsRaw(v string) (string, gperr.Error) { + buf := getStringBuffer(len(v)) + envVar := getStringBuffer(0) + + var missingEnvVars []string + inEnvVar := false + expectingBrace := false + + for _, r := range v { + if expectingBrace && r != '{' && r != '$' { + buf.WriteRune('$') + expectingBrace = false + } + switch r { + case '$': + if expectingBrace { + buf.WriteRune('$') + expectingBrace = false + } else { + expectingBrace = true + } + case '{': + if expectingBrace { + inEnvVar = true + expectingBrace = false + envVar.Reset() + } else { + buf.WriteRune(r) + } + case '}': + if inEnvVar { + envValue, ok := env.LookupEnv(envVar.String()) + if !ok { + missingEnvVars = append(missingEnvVars, envVar.String()) + } else { + buf.WriteString(envValue) + } + inEnvVar = false + } else { + buf.WriteRune(r) + } + default: + if expectingBrace { + buf.WriteRune('$') + expectingBrace = false + } + if inEnvVar { + envVar.WriteRune(r) + } else { + buf.WriteRune(r) + } + } + } + + if expectingBrace { + buf.WriteRune('$') + } + + var err gperr.Error + if inEnvVar { + err = ErrUnterminatedEnvVar + } + if len(missingEnvVars) > 0 { + err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...))) + } + return buf.String(), err +} + +// parseBlockRules parses the block-syntax rule format. +// Grammar: +// +// file := { ws | comment | rule } +// rule := default_rule | conditional_rule +// default_rule := 'default' ws* block +// conditional_rule := on_expr ws* block +// block := '{' do_body '}' +// +// Where: +// - on_expr is passed verbatim to RuleOn.Parse() +// - do_body is passed verbatim to Command.Parse() +// +// Comments (ignored outside quotes/backticks): +// - line comment: // ... or # ... +// - block comment: /* ... */ +// +// Brace handling: +// - Braces inside quotes/backticks are ignored +// - Braces inside ${...} (env vars) are ignored in do_body +// - Braces in on_expr are not ignored (env vars must be quoted in on_expr) +// +//nolint:dupword +func parseBlockRules(src string) (Rules, gperr.Error) { + var rules Rules + var errs gperr.Builder + + pos := 0 + length := len(src) + t := newTokenizer(src) + + for pos < length { + // Skip whitespace/comments between rules. + newPos, skipErr := t.skipComments(pos, true, true) + if skipErr != nil { + return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos) + } + pos = newPos + if pos >= length { + break + } + + // Stray closing brace at top-level: keep parsing but mark invalid so Rules.Validate() fails. + if src[pos] == '}' { + return nil, ErrInvalidBlockSyntax.Withf("unmatched '}' at position %d", pos) + } + + // Parse rule header (default, unconditional, or on_expr) + headerStart := pos + header := parseRuleHeader(&t, src, &pos, length) + headerStr := src[headerStart:pos] + + // Skip whitespace/comments before '{' (default header may end before '{'). + newPos, skipErr = t.skipComments(pos, false, true) + if skipErr != nil { + return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos) + } + pos = newPos + + if pos >= length || src[pos] != '{' { + errs.AddSubjectf(ErrInvalidBlockSyntax, "expected '{' after rule header %q", headerStr) + return nil, errs.Error() + } + + // Find matching '}' (respecting quotes and env vars in do_body) + bodyStart := pos + 1 + bodyEnd, err := t.findMatchingBrace(bodyStart) + if err != nil { + errs.AddSubjectf(err, "rule header %q", headerStr) + return nil, errs.Error() + } + pos = bodyEnd + 1 + + onExpr := header + + doBody := "" + if bodyStart < bodyEnd { + doBody = src[bodyStart:bodyEnd] + } + // Normalize do body for the inner DSL parser: + // - strip comments (outside quotes/backticks) + // - trim block whitespace/indentation + // - expand ${ENV} in-place so cmd.raw is usable/debuggable + doBody, err = preprocessDoBody(doBody) + if err != nil { + errs.AddSubjectf(err, "rule header %q", headerStr) + return nil, errs.Error() + } + + rule := Rule{ + Name: "", // auto-generate if empty + On: RuleOn{}, + Do: Command{}, + } + + // Header semantics: + // - "default" => default rule (matched when no other rules are matched) + // - "" => unconditional rule (always matches) + // - otherwise => conditional rule (on expression) + switch onExpr { + case "default": + rule.On.raw = OnDefault + case "": + // leave rule.On as zero value => checker=nil => always matches + default: + if parseErr := rule.On.Parse(onExpr); parseErr != nil { + errs.AddSubjectf(parseErr, "on") + } + } + + if doBody != "" { + if parseErr := rule.Do.Parse(doBody); parseErr != nil { + errs.AddSubjectf(parseErr, "do") + } + } + + if errs.HasError() { + return nil, errs.Error() + } + + rules = append(rules, rule) + } + + return rules, nil +} + +func preprocessDoBody(doBody string) (string, gperr.Error) { + doBody = strings.TrimSpace(doBody) + if doBody == "" { + return "", nil + } + + normalized := doBody + // If comments are possible, strip them first while preserving line breaks. + if strings.ContainsAny(normalized, "#/") { + stripped, err := stripCommentsPreserveNewlines(normalized) + if err != nil { + return "", err + } + normalized = stripped + } + + // Drop lines that are empty after trimming, while preserving indentation of non-empty lines. + out := getStringBuffer(len(normalized)) + + lineStart := 0 + wroteLine := false + for i := 0; i <= len(normalized); i++ { + if i < len(normalized) && normalized[i] != '\n' { + continue + } + line := normalized[lineStart:i] + if strings.TrimSpace(line) != "" { + if wroteLine { + out.WriteByte('\n') + } + out.WriteString(line) + wroteLine = true + } + lineStart = i + 1 + } + + if !wroteLine { + return "", nil + } + normalized = out.String() + + // Expand env vars to keep Command.raw consistent with parsed semantics. + if !strings.Contains(normalized, "${") { + return normalized, nil + } + expanded, err := expandEnvVarsRaw(normalized) + if err != nil { + return "", err + } + return expanded, nil +} + +// stripCommentsPreserveNewlines removes //, #, and /* */ comments outside quotes/backticks. +// It preserves newlines so command line boundaries remain intact. +func stripCommentsPreserveNewlines(src string) (string, gperr.Error) { + if !strings.ContainsAny(src, "#/") { + return src, nil + } + + out := getStringBuffer(len(src)) + + quote := rune(0) + inLine := false + inBlock := false + atLineStart := true + prevIsSpace := true + + for i := 0; i < len(src); { + c := src[i] + + if inLine { + if c == '\n' { + inLine = false + out.WriteByte('\n') + atLineStart = true + prevIsSpace = true + } + i++ + continue + } + if inBlock { + if c == '\n' { + out.WriteByte('\n') + atLineStart = true + prevIsSpace = true + i++ + continue + } + if c == '*' && i+1 < len(src) && src[i+1] == '/' { + inBlock = false + i += 2 + continue + } + i++ + continue + } + + if quote != 0 { + out.WriteByte(c) + if c == '\\' && i+1 < len(src) { + // Write next char and skip it (escape sequence) + i++ + out.WriteByte(src[i]) + atLineStart = false + prevIsSpace = false + i++ + continue + } + if rune(c) == quote { + quote = 0 + } + if c == '\n' { + atLineStart = true + prevIsSpace = true + } else { + atLineStart = false + prevIsSpace = unicode.IsSpace(rune(c)) + } + i++ + continue + } + + // Not in quote/comment. + switch c { + case '\'', '"', '`': + quote = rune(c) + out.WriteByte(c) + atLineStart = false + prevIsSpace = false + i++ + continue + case '#': + if atLineStart || prevIsSpace { + inLine = true + i++ + continue + } + case '/': + if i+1 < len(src) { + n := src[i+1] + if (atLineStart || prevIsSpace) && n == '/' { + inLine = true + i += 2 + continue + } + if (atLineStart || prevIsSpace) && n == '*' { + inBlock = true + i += 2 + continue + } + } + } + + out.WriteByte(c) + if c == '\n' { + atLineStart = true + prevIsSpace = true + } else { + atLineStart = false + prevIsSpace = unicode.IsSpace(rune(c)) + } + i++ + } + + if inBlock { + return "", ErrInvalidBlockSyntax.Withf("unterminated block comment") + } + return out.String(), nil +} + +// parseRuleHeader parses the rule header (default or on expression). +// Returns the header string, or "" if parsing failed. +func parseRuleHeader(t *Tokenizer, src string, pos *int, length int) string { + start := *pos + + // Check for 'default' keyword + if *pos+7 <= length && src[*pos:*pos+7] == "default" { + next := *pos + 7 + if next >= length || unicode.IsSpace(rune(src[next])) { + *pos = next + return "default" + } + } + + // Parse on expression until we hit '{' outside quotes. + bracePos, err := t.scanToBrace(*pos) + if err != nil { + *pos = length + return strings.TrimSpace(src[start:*pos]) + } + *pos = bracePos + return strings.TrimSpace(src[start:*pos]) +} diff --git a/internal/route/rules/block_parser_bench_test.go b/internal/route/rules/block_parser_bench_test.go new file mode 100644 index 00000000..1550a1c5 --- /dev/null +++ b/internal/route/rules/block_parser_bench_test.go @@ -0,0 +1,48 @@ +package rules + +import "testing" + +func BenchmarkParseBlockRules(b *testing.B) { + const rulesString = ` +default { + remove resp_header X-Secret + add resp_header X-Custom-Header custom-value +} + +header X-Test-Header { + set header X-Remote-Type public + @remote 127.0.0.1 | remote 192.168.0.0/16 { + set header X-Remote-Type private + } +} + +path glob(/api/admin/*) { + @cookie session-id { + set header X-Session-ID $cookie(session-id) + } +} + +!remote 192.168.0.0/16 { + @!header X-User-Role admin & !header X-User-Role user { + error 403 "Access denied" + } elif remote 127.0.0.1 { + @header X-User-Role staff { + set header X-User-Role staff + } + } else { + error 403 "Access denied" + } +} +` + + var rules Rules + err := rules.Parse(rulesString) + if err != nil { + b.Fatal(err) + } + + for b.Loop() { + var rules Rules + _ = rules.Parse(rulesString) + } +} diff --git a/internal/route/rules/block_parser_test.go b/internal/route/rules/block_parser_test.go new file mode 100644 index 00000000..30a9de67 --- /dev/null +++ b/internal/route/rules/block_parser_test.go @@ -0,0 +1,339 @@ +package rules + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/serialization" + httputils "github.com/yusing/goutils/http" +) + +func testParseRules(t *testing.T, data string) Rules { + t.Helper() + + var rules Rules + convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules)) + require.True(t, convertible) + require.NoError(t, err) + return rules +} + +func testParseRulesError(t *testing.T, data string) error { + t.Helper() + + var rules Rules + convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules)) + require.True(t, convertible) + return err +} + +func TestParseBlockRules_DefaultRule(t *testing.T) { + rules := testParseRules(t, `default { + upstream +}`) + require.Len(t, rules, 1) + assert.Equal(t, OnDefault, rules[0].On.raw) + assert.Equal(t, "upstream", rules[0].Do.raw) + assert.True(t, rules[0].Do.raw == CommandUpstream) +} + +func TestParseBlockRules_ConditionalRule(t *testing.T) { + rules := testParseRules(t, `path glob(/api/*) { + proxy http://localhost:8080 +}`) + require.Len(t, rules, 1) + assert.Equal(t, "path glob(/api/*)", rules[0].On.raw) + assert.Equal(t, "proxy http://localhost:8080", rules[0].Do.raw) + require.Len(t, rules[0].Do.pre, 1) + _, ok := rules[0].Do.pre[0].(Handler) + require.True(t, ok) + require.Len(t, rules[0].Do.post, 0) +} + +func TestParseBlockRules_MultipleRules(t *testing.T) { + rules := testParseRules(t, `default { + bypass +} + +path /api/* { + proxy http://localhost:8080 +} + +header Connection Upgrade & +header Upgrade websocket { + route ws-api + log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name" +}`) + require.Len(t, rules, 3) + + // Default rule + assert.Equal(t, OnDefault, rules[0].On.raw) + assert.Equal(t, "bypass", rules[0].Do.raw) + + // API rule + assert.Equal(t, "path /api/*", rules[1].On.raw) + assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw) + + // WebSocket rule + assert.Equal(t, "header Connection Upgrade &\nheader Upgrade websocket", rules[2].On.raw) + assert.Equal(t, `route ws-api + log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"`, rules[2].Do.raw) + require.Len(t, rules[2].Do.pre, 2) + _, ok := rules[2].Do.pre[0].(Handler) + require.True(t, ok) + _, ok = rules[2].Do.pre[1].(Handler) + require.True(t, ok) + require.Len(t, rules[2].Do.post, 0) +} + +func TestParseBlockRules_Comments(t *testing.T) { + rules := testParseRules(t, `// This is a comment +default { + bypass // inline comment +} + +/* Block comment + spanning multiple lines */ +path /admin/* { + require_auth +}`) + require.Len(t, rules, 2) + assert.Equal(t, OnDefault, rules[0].On.raw) + assert.Equal(t, "path /admin/*", rules[1].On.raw) + assert.Equal(t, "require_auth", rules[1].Do.raw) +} + +func TestParseBlockRules_HashComment(t *testing.T) { + rules := testParseRules(t, `# YAML-style comment +default { + bypass +}`) + require.Len(t, rules, 1) + assert.Equal(t, OnDefault, rules[0].On.raw) + assert.Equal(t, "bypass", rules[0].Do.raw) +} + +func TestParseBlockRules_EnvVars(t *testing.T) { + t.Setenv("CUSTOM_HEADER", "test-header") + + rules := testParseRules(t, `path /api/* { + set header X-Custom "${CUSTOM_HEADER}" +}`) + require.Len(t, rules, 1) + assert.Equal(t, "path /api/*", rules[0].On.raw) + assert.Equal(t, `set header X-Custom "test-header"`, rules[0].Do.raw) + require.Len(t, rules[0].Do.pre, 1) + _, ok := rules[0].Do.pre[0].(Handler) + require.True(t, ok) + require.Len(t, rules[0].Do.post, 0) +} + +func TestParseBlockRules_YAMLFallback(t *testing.T) { + rules := testParseRules(t, `- name: default + do: bypass +- name: api + on: path glob(/api/*) + do: proxy http://localhost:8080`) + require.Len(t, rules, 2) + assert.Equal(t, "default", rules[0].Name) + assert.Equal(t, "bypass", rules[0].Do.raw) + assert.Equal(t, "api", rules[1].Name) + assert.Equal(t, "path glob(/api/*)", rules[1].On.raw) + assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw) + require.Len(t, rules[1].Do.pre, 1) + _, ok := rules[1].Do.pre[0].(Handler) + require.True(t, ok) + require.Len(t, rules[1].Do.post, 0) +} + +func TestParseBlockRules_UnmatchedBrace(t *testing.T) { + t.Run("unquoted", func(t *testing.T) { + err := testParseRulesError(t, `path /api/* { + proxy http://localhost:8080} +}`) + require.Error(t, err) + }) + t.Run("quoted", func(t *testing.T) { + _ = testParseRules(t, `path /api/* { + error 403 "some message}" + }`) + }) +} + +func TestParseBlockRules_UnterminatedBlockComment(t *testing.T) { + err := testParseRulesError(t, `/* unterminated block comment +default { + bypass +}`) + require.Error(t, err) +} + +func TestParseBlockRules_NestedBlocks(t *testing.T) { + rules := testParseRules(t, ` +header X-Test-Header { + set header X-Remote-Type public + @remote 127.0.0.1 | remote 192.168.0.0/16 { + set header X-Remote-Type private + } +}`) + + require.Len(t, rules, 1) + assert.Equal(t, "header X-Test-Header", rules[0].On.raw) + + require.Len(t, rules[0].Do.pre, 2) + _, ok := rules[0].Do.pre[0].(Handler) + require.True(t, ok) + require.Len(t, rules[0].Do.post, 0) + + ifCmd, ok := rules[0].Do.pre[1].(IfBlockCommand) + require.True(t, ok) + assert.Equal(t, "remote 127.0.0.1 | remote 192.168.0.0/16", ifCmd.On.raw) + + require.Len(t, ifCmd.Do, 1) + + upstream := func(http.ResponseWriter, *http.Request) {} + + t.Run("condition matched executes nested content", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Test-Header", "1") + req.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + rm := httputils.NewResponseModifier(w) + + err := rules[0].Do.pre.ServeHTTP(rm, req, upstream) + require.NoError(t, err) + assert.Equal(t, "private", req.Header.Get("X-Remote-Type")) + }) + + t.Run("condition not matched skips nested content", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Test-Header", "1") + req.RemoteAddr = "10.0.0.1:12345" + w := httptest.NewRecorder() + rm := httputils.NewResponseModifier(w) + + err := rules[0].Do.pre.ServeHTTP(rm, req, upstream) + require.NoError(t, err) + assert.Equal(t, "public", req.Header.Get("X-Remote-Type")) + }) +} + +func TestParseBlockRules_NestedBlocks_ElifElse(t *testing.T) { + rules := testParseRules(t, ` +header X-Test-Header { + set header X-Mode outer + @method GET { + set header X-Mode get + } elif method POST { + set header X-Mode post + } else { + set header X-Mode other + } +}`) + + require.Len(t, rules, 1) + + require.Len(t, rules[0].Do.pre, 2) + + ifCmd, ok := rules[0].Do.pre[1].(IfElseBlockCommand) + require.True(t, ok) + require.Len(t, ifCmd.Ifs, 2) + assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw) + assert.Equal(t, "method POST", ifCmd.Ifs[1].On.raw) + require.NotNil(t, ifCmd.Else) + + upstream := func(http.ResponseWriter, *http.Request) {} + cases := []struct { + name string + method string + want string + }{ + {name: "get branch", method: http.MethodGet, want: "get"}, + {name: "post branch", method: http.MethodPost, want: "post"}, + {name: "else branch", method: http.MethodPut, want: "other"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, "/", nil) + req.Header.Set("X-Test-Header", "1") + w := httptest.NewRecorder() + rm := httputils.NewResponseModifier(w) + + err := rules[0].Do.pre.ServeHTTP(rm, req, upstream) + require.NoError(t, err) + assert.Equal(t, tc.want, req.Header.Get("X-Mode")) + }) + } +} + +func TestParseBlockRules_DefaultRule_CommentBeforeBrace(t *testing.T) { + rules := testParseRules(t, `default /* comment between header and brace */ { + bypass +}`) + require.Len(t, rules, 1) + assert.Equal(t, OnDefault, rules[0].On.raw) + assert.Equal(t, "bypass", rules[0].Do.raw) +} + +func TestParseBlockRules_StrayClosingBraceAtTopLevel(t *testing.T) { + err := testParseRulesError(t, `} +default { + bypass +}`) + require.Error(t, err) +} + +func TestParseBlockRules_NestedBlocks_ElifMustBeSameLine(t *testing.T) { + err := testParseRulesError(t, `header X-Test-Header { + @method GET { + set header X-Mode get + } + elif method POST { + set header X-Mode post + } +}`) + require.Error(t, err) +} + +func TestParseBlockRules_NestedBlocks_ElseMustBeLastOnLine(t *testing.T) { + err := testParseRulesError(t, `header X-Test-Header { + @method GET { + set header X-Mode get + } else { + set header X-Mode other + } set header X-After else +}`) + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected token after else block") +} + +func TestParseBlockRules_NestedBlocks_MultipleElse(t *testing.T) { + err := testParseRulesError(t, `header X-Test-Header { + @method GET { + set header X-Mode get + } else { + set header X-Mode other + } else { + set header X-Mode other2 + } +}`) + require.Error(t, err) + // assert.Contains(t, err.Error(), "multiple 'else' branches") + assert.Contains(t, err.Error(), "unexpected token after else block") +} + +func TestParseBlockRules_NestedBlocks_ElifMissingOnExpr(t *testing.T) { + err := testParseRulesError(t, `header X-Test-Header { + @method GET { + set header X-Mode get + } elif { + set header X-Mode post + } +}`) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected on-expr after 'elif'") +} diff --git a/internal/route/rules/check_on.go b/internal/route/rules/check_on.go index 750a6211..739e79d2 100644 --- a/internal/route/rules/check_on.go +++ b/internal/route/rules/check_on.go @@ -1,21 +1,25 @@ package rules -import "net/http" +import ( + "net/http" + + httputils "github.com/yusing/goutils/http" +) type ( - CheckFunc func(w http.ResponseWriter, r *http.Request) bool + CheckFunc func(w *httputils.ResponseModifier, r *http.Request) bool Checker interface { - Check(w http.ResponseWriter, r *http.Request) bool + Check(w *httputils.ResponseModifier, r *http.Request) bool } CheckMatchSingle []Checker CheckMatchAll []Checker ) -func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool { +func (checker CheckFunc) Check(w *httputils.ResponseModifier, r *http.Request) bool { return checker(w, r) } -func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool { +func (checkers CheckMatchSingle) Check(w *httputils.ResponseModifier, r *http.Request) bool { for _, check := range checkers { if check.Check(w, r) { return true @@ -24,7 +28,7 @@ func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) b return false } -func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool { +func (checkers CheckMatchAll) Check(w *httputils.ResponseModifier, r *http.Request) bool { for _, check := range checkers { if !check.Check(w, r) { return false diff --git a/internal/route/rules/command.go b/internal/route/rules/command.go index a856edaa..4d8cda6a 100644 --- a/internal/route/rules/command.go +++ b/internal/route/rules/command.go @@ -1,79 +1,62 @@ package rules -import "net/http" +import ( + "errors" + "net/http" + + httputils "github.com/yusing/goutils/http" +) + +var errTerminateRule = errors.New("terminate rule") type ( - handlerFunc func(w http.ResponseWriter, r *http.Request) error + HandlerFunc func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error + Handler struct { + fn HandlerFunc + phase PhaseFlag + terminate bool + } CommandHandler interface { // CommandHandler can read and modify the values // then handle the request // finally proceed to next command (or return) base on situation - Handle(w http.ResponseWriter, r *http.Request) error - IsResponseHandler() bool + ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error + Phase() PhaseFlag } - // NonTerminatingCommand will run then proceed to next command or reverse proxy. - NonTerminatingCommand handlerFunc - // TerminatingCommand will run then return immediately. - TerminatingCommand handlerFunc - // OnResponseCommand will run then return based on the response. - OnResponseCommand handlerFunc - // BypassCommand will skip all the following commands - // and directly return to reverse proxy. - BypassCommand struct{} + // Commands is a slice of CommandHandler. Commands []CommandHandler ) -func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error { - return c(w, r) +func (h Handler) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + return h.fn(w, r, upstream) } -func (c NonTerminatingCommand) IsResponseHandler() bool { - return false +func (h Handler) Phase() PhaseFlag { + return h.phase } -func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error { - if err := c(w, r); err != nil { - return err - } - return errTerminated +func (h Handler) Terminates() bool { + return h.terminate } -func (c TerminatingCommand) IsResponseHandler() bool { - return false -} - -func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error { - return c(w, r) -} - -func (c OnResponseCommand) IsResponseHandler() bool { - return true -} - -func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error { - return errTerminated -} - -func (c BypassCommand) IsResponseHandler() bool { - return false -} - -func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error { +func (c Commands) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { for _, cmd := range c { - if err := cmd.Handle(w, r); err != nil { + err := cmd.ServeHTTP(w, r, upstream) + if err != nil { + // Terminating actions stop the command chain immediately. + // Will be handled by the caller. return err } } return nil } -func (c Commands) IsResponseHandler() bool { +func (c Commands) Phase() PhaseFlag { + req := PhaseNone for _, cmd := range c { - if cmd.IsResponseHandler() { - return true - } + req |= cmd.Phase() } - return false + return req } diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 11e98412..0e69834c 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -1,7 +1,6 @@ package rules import ( - "bytes" "fmt" "io" "net/http" @@ -24,17 +23,17 @@ import ( type ( Command struct { - raw string - exec CommandHandler - isResponseHandler bool + raw string + pre Commands // runs before w.WriteHeader + post Commands } ) -func (cmd *Command) IsResponseHandler() bool { - return cmd.isResponseHandler -} - const ( + CommandUpstream = "upstream" + CommandUpstreamOld = "bypass" + CommandUpstreamOld2 = "pass" + CommandRequireAuth = "require_auth" CommandRewrite = "rewrite" CommandServe = "serve" @@ -48,8 +47,6 @@ const ( CommandRemove = "remove" CommandLog = "log" CommandNotify = "notify" - CommandPass = "pass" - CommandPassAlt = "bypass" ) type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool) @@ -60,36 +57,57 @@ func InitAuthHandler(handler AuthHandler) { authHandler = handler } +func init() { + commands[CommandUpstreamOld] = commands[CommandUpstream] + commands[CommandUpstreamOld2] = commands[CommandUpstream] +} + var commands = map[string]struct { - help Help - validate ValidateFunc - build func(args any) CommandHandler - isResponseHandler bool + help Help + validate ValidateFunc + build func(args any) HandlerFunc + terminate bool }{ + CommandUpstream: { + help: Help{ + command: CommandUpstream, + description: makeLines("Pass the request to the upstream"), + args: map[string]string{}, + }, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + if len(args) != 0 { + return phase, nil, ErrExpectNoArg + } + return phase, nil, nil + }, + build: func(args any) HandlerFunc { + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + upstream(w, r) + return errTerminateRule + } + }, + terminate: true, + }, CommandRequireAuth: { help: Help{ command: CommandRequireAuth, description: makeLines("Require HTTP authentication for incoming requests"), args: map[string]string{}, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre if len(args) != 0 { - return nil, ErrExpectNoArg + return phase, nil, ErrExpectNoArg } - //nolint:nilnil - return nil, nil + return phase, nil, nil }, - build: func(args any) CommandHandler { - return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - if authHandler == nil { - http.Error(w, "Auth handler not initialized", http.StatusInternalServerError) - return errTerminated - } - if !authHandler(w, r) { - return errTerminated + build: func(args any) HandlerFunc { + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + if proceed := authHandler(w, r); !proceed { + return errTerminateRule } return nil - }) + } }, }, CommandRewrite: { @@ -104,26 +122,27 @@ var commands = map[string]struct { "to": "the path to rewrite to, must start with /", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre if len(args) != 2 { - return nil, ErrExpectTwoArgs + return phase, nil, ErrExpectTwoArgs } path1, err1 := validateURLPath(args[:1]) path2, err2 := validateURLPath(args[1:]) if err1 != nil { - err1 = gperr.PrependSubject(err1, "from") + err1 = gperr.Errorf("from: %w", err1) } if err2 != nil { - err2 = gperr.PrependSubject(err2, "to") + err2 = gperr.Errorf("to: %w", err2) } if err1 != nil || err2 != nil { - return nil, gperr.Join(err1, err2) + return phase, nil, gperr.Join(err1, err2) } - return &StrTuple{path1.(string), path2.(string)}, nil + return phase, &StrTuple{path1.(string), path2.(string)}, nil }, - build: func(args any) CommandHandler { + build: func(args any) HandlerFunc { orig, repl := args.(*StrTuple).Unpack() - return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { path := r.URL.Path if len(path) > 0 && path[0] != '/' { path = "/" + path @@ -133,10 +152,10 @@ var commands = map[string]struct { } path = repl + path[len(orig):] r.URL.Path = path - r.URL.RawPath = r.URL.EscapedPath() - r.RequestURI = r.URL.RequestURI() + r.URL.RawPath = "" + r.RequestURI = "" return nil - }) + } }, }, CommandServe: { @@ -150,14 +169,19 @@ var commands = map[string]struct { "root": "the file system path to serve, must be an existing directory", }, }, - validate: validateFSPath, - build: func(args any) CommandHandler { - root := args.(string) - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) - return nil - }) + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre + parsedArgs, err = validateFSPath(args) + return }, + build: func(args any) HandlerFunc { + root := args.(string) + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) + return errTerminateRule + } + }, + terminate: true, }, CommandRedirect: { help: Help{ @@ -170,14 +194,19 @@ var commands = map[string]struct { "to": "the url to redirect to, can be relative or absolute URL", }, }, - validate: validateURL, - build: func(args any) CommandHandler { - target := args.(*nettypes.URL).String() - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - http.Redirect(w, r, target, http.StatusTemporaryRedirect) - return nil - }) + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre + parsedArgs, err = validateURL(args) + return }, + build: func(args any) HandlerFunc { + target := args.(*nettypes.URL).String() + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + http.Redirect(w, r, target, http.StatusTemporaryRedirect) + return errTerminateRule + } + }, + terminate: true, }, CommandRoute: { help: Help{ @@ -190,15 +219,16 @@ var commands = map[string]struct { "route": "the route to route to", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre if len(args) != 1 { - return nil, ErrExpectOneArg + return phase, nil, ErrExpectOneArg } - return args[0], nil + return phase, args[0], nil }, - build: func(args any) CommandHandler { + build: func(args any) HandlerFunc { route := args.(string) - return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error { + return func(w *httputils.ResponseModifier, req *http.Request, upstream http.HandlerFunc) error { ep := entrypoint.FromCtx(req.Context()) r, ok := ep.HTTPRoutes().Get(route) if !ok { @@ -212,9 +242,10 @@ var commands = map[string]struct { } else { http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound) } - return nil - }) + return errTerminateRule + } }, + terminate: true, }, CommandError: { help: Help{ @@ -228,34 +259,40 @@ var commands = map[string]struct { "text": "the error message to return", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre if len(args) != 2 { - return nil, ErrExpectTwoArgs + return phase, nil, ErrExpectTwoArgs } codeStr, text := args[0], args[1] code, err := strconv.Atoi(codeStr) if err != nil { - return nil, ErrInvalidArguments.With(err) + return phase, nil, ErrInvalidArguments.With(err) } if !httputils.IsStatusCodeValid(code) { - return nil, ErrInvalidArguments.Subject(codeStr) + return phase, nil, ErrInvalidArguments.Subject(codeStr) } - textTmpl, err := validateTemplate(text, true) + tmplReq, textTmpl, err := validateTemplate(text, true) if err != nil { - return nil, ErrInvalidArguments.With(err) + return phase, nil, ErrInvalidArguments.With(err) } - return &Tuple[int, templateString]{code, textTmpl}, nil + phase |= tmplReq + return phase, &Tuple[int, templateString]{code, textTmpl}, nil }, - build: func(args any) CommandHandler { + build: func(args any) HandlerFunc { code, textTmpl := args.(*Tuple[int, templateString]).Unpack() - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { // error command should overwrite the response body - httputils.GetInitResponseModifier(w).ResetBody() + w.ResetBody() w.WriteHeader(code) - err := textTmpl.ExpandVars(w, r, w) - return err - }) + _, err := textTmpl.ExpandVars(w, r, w.BodyBuffer()) + if err != nil { + return err + } + return errTerminateRule + } }, + terminate: true, }, CommandRequireBasicAuth: { help: Help{ @@ -268,20 +305,22 @@ var commands = map[string]struct { "realm": "the authentication realm", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre if len(args) == 1 { - return args[0], nil + return phase, args[0], nil } - return nil, ErrExpectOneArg + return phase, nil, ErrExpectOneArg }, - build: func(args any) CommandHandler { + build: func(args any) HandlerFunc { realm := args.(string) - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`) + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, realm)) http.Error(w, "Unauthorized", http.StatusUnauthorized) - return nil - }) + return errTerminateRule + } }, + terminate: true, }, CommandProxy: { help: Help{ @@ -294,14 +333,19 @@ var commands = map[string]struct { "to": "the url to proxy to, must be an absolute URL", }, }, - validate: validateURL, - build: func(args any) CommandHandler { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePre + parsedArgs, err = validateURL(args) + return + }, + build: func(args any) HandlerFunc { target := args.(*nettypes.URL) if target.Scheme == "" { target.Scheme = "http" } if target.Host == "" { - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + rawPath := target.EscapedPath() + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { url := target.URL url.Host = routes.TryGetUpstreamHostPort(r) if url.Host == "" { @@ -309,18 +353,19 @@ var commands = map[string]struct { } rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport()) r.URL.Path = target.Path - r.URL.RawPath = r.URL.EscapedPath() - r.RequestURI = r.URL.RequestURI() + r.URL.RawPath = rawPath + r.RequestURI = "" rp.ServeHTTP(w, r) - return nil - }) + return errTerminateRule + } } rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport()) - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { rp.ServeHTTP(w, r) - return nil - }) + return errTerminateRule + } }, + terminate: true, }, CommandSet: { help: Help{ @@ -335,11 +380,11 @@ var commands = map[string]struct { "value": "the value to set", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { return validateModField(ModFieldSet, args) }, - build: func(args any) CommandHandler { - return args.(CommandHandler) + build: func(args any) HandlerFunc { + return args.(HandlerFunc) }, }, CommandAdd: { @@ -355,11 +400,11 @@ var commands = map[string]struct { "value": "the value to add", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { return validateModField(ModFieldAdd, args) }, - build: func(args any) CommandHandler { - return args.(CommandHandler) + build: func(args any) HandlerFunc { + return args.(HandlerFunc) }, }, CommandRemove: { @@ -374,15 +419,14 @@ var commands = map[string]struct { "field": "the field to remove", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { return validateModField(ModFieldRemove, args) }, - build: func(args any) CommandHandler { - return args.(CommandHandler) + build: func(args any) HandlerFunc { + return args.(HandlerFunc) }, }, CommandLog: { - isResponseHandler: true, help: Help{ command: CommandLog, description: makeLines( @@ -399,28 +443,28 @@ var commands = map[string]struct { "template": "the template to log", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 3 { - return nil, ErrExpectThreeArgs + return phase, nil, ErrExpectThreeArgs } - tmpl, err := validateTemplate(args[2], true) + phase, tmpl, err := validateTemplate(args[2], true) if err != nil { - return nil, err + return phase, nil, err } level, err := validateLevel(args[0]) if err != nil { - return nil, err + return phase, nil, err } // NOTE: file will stay opened forever // it leverages accesslog.NewFileIO so // it will be opened only once for the same path f, err := openFile(args[1]) if err != nil { - return nil, err + return phase, nil, err } - return &onLogArgs{level, f, tmpl}, nil + return phase, &onLogArgs{level, f, tmpl}, nil }, - build: func(args any) CommandHandler { + build: func(args any) HandlerFunc { level, f, tmpl := args.(*onLogArgs).Unpack() var logger io.Writer if f == stdout || f == stderr { @@ -428,17 +472,16 @@ var commands = map[string]struct { } else { logger = f } - return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { - err := tmpl.ExpandVars(w, r, logger) + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + _, err := tmpl.ExpandVars(w, r, logger) if err != nil { return err } return nil - }) + } }, }, CommandNotify: { - isResponseHandler: true, help: Help{ command: CommandNotify, description: makeLines( @@ -456,22 +499,24 @@ var commands = map[string]struct { "body": "the body of the notification", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 4 { - return nil, ErrExpectFourArgs + return phase, nil, ErrExpectFourArgs } - titleTmpl, err := validateTemplate(args[2], false) + req1, titleTmpl, err := validateTemplate(args[2], false) if err != nil { - return nil, err + return phase, nil, err } - bodyTmpl, err := validateTemplate(args[3], false) + req2, bodyTmpl, err := validateTemplate(args[3], false) if err != nil { - return nil, err + return phase, nil, err } level, err := validateLevel(args[0]) if err != nil { - return nil, err + return phase, nil, err } + + phase |= req1 | req2 // TODO: validate provider // currently it is not possible, because rule validation happens on UnmarshalYAMLValidate // and we cannot call config.ActiveConfig.Load() because it will cause import cycle @@ -480,34 +525,34 @@ var commands = map[string]struct { // if err != nil { // return nil, err // } - return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil + return phase, &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil }, - build: func(args any) CommandHandler { + build: func(args any) HandlerFunc { level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack() to := []string{provider} - return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { - respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len())) + return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + var respBuf strings.Builder - err := titleTmpl.ExpandVars(w, r, respBuf) + _, err := titleTmpl.ExpandVars(w, r, &respBuf) if err != nil { return err } titleLen := respBuf.Len() - err = bodyTmpl.ExpandVars(w, r, respBuf) + _, err = bodyTmpl.ExpandVars(w, r, &respBuf) if err != nil { return err } - b := respBuf.Bytes() + s := respBuf.String() notif.Notify(¬if.LogMessage{ Level: level, - Title: string(b[:titleLen]), - Body: notif.MessageBodyBytes(b[titleLen:]), + Title: s[:titleLen], + Body: notif.MessageBodyBytes(s[titleLen:]), To: to, }) return nil - }) + } }, }, } @@ -519,121 +564,29 @@ type ( // Parse implements strutils.Parser. func (cmd *Command) Parse(v string) error { - executors := make([]CommandHandler, 0) - isResponseHandler := false - for line := range strings.SplitSeq(v, "\n") { - if line == "" { - continue - } - - directive, args, err := parse(line) - if err != nil { - return err - } - - if directive == CommandPass || directive == CommandPassAlt { - if len(args) != 0 { - return ErrExpectNoArg - } - executors = append(executors, BypassCommand{}) - continue - } - - builder, ok := commands[directive] - if !ok { - return ErrUnknownDirective.Subject(directive) - } - validArgs, err := builder.validate(args) - if err != nil { - // Only attach help for the directive that failed, avoid bringing in unrelated KV errors - return gperr.PrependSubject(err, directive).With(builder.help.Error()) - } - - handler := builder.build(validArgs) - executors = append(executors, handler) - if builder.isResponseHandler || handler.IsResponseHandler() { - isResponseHandler = true - } + executors, parseErr := parseDoWithBlocks(v) + if parseErr != nil { + return parseErr } if len(executors) == 0 { cmd.raw = v - cmd.exec = nil - cmd.isResponseHandler = false + cmd.pre = nil + cmd.post = nil return nil } - exec, err := buildCmd(executors) - if err != nil { - return err - } - cmd.raw = v - cmd.exec = exec - if exec.IsResponseHandler() { - isResponseHandler = true + for _, executor := range executors { + if executor.Phase().IsPostRule() { + cmd.post = append(cmd.post, executor) + } else { + cmd.pre = append(cmd.pre, executor) + } } - cmd.isResponseHandler = isResponseHandler return nil } -func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) { - // Validate the execution order. - // - // This allows sequences like: - // route ws-api - // log info /dev/stdout "..." - // where the first command is request-phase and the last is response-phase. - lastNonResp := -1 - seenResp := false - for i, exec := range executors { - if exec.IsResponseHandler() { - seenResp = true - continue - } - if seenResp { - return nil, ErrInvalidCommandSequence.Withf("response handlers must be the last commands") - } - lastNonResp = i - } - - for i, exec := range executors { - if i > lastNonResp { - break // response-handler tail - } - switch exec.(type) { - case TerminatingCommand, BypassCommand: - if i != lastNonResp { - return nil, ErrInvalidCommandSequence. - Withf("a response handler or terminating/bypass command must be the last command") - } - } - } - - return Commands(executors), nil -} - -// Command is purely "bypass" or empty. -func (cmd *Command) isBypass() bool { - if cmd == nil { - return true - } - switch cmd := cmd.exec.(type) { - case BypassCommand: - return true - case Commands: - // bypass command is always the last one - _, ok := cmd[len(cmd)-1].(BypassCommand) - return ok - default: - return false - } -} - -func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error { - return cmd.exec.Handle(w, r) -} - func (cmd *Command) String() string { return cmd.raw } diff --git a/internal/route/rules/do_blocks.go b/internal/route/rules/do_blocks.go new file mode 100644 index 00000000..ada5fb3b --- /dev/null +++ b/internal/route/rules/do_blocks.go @@ -0,0 +1,386 @@ +package rules + +import ( + "net/http" + "strings" + "unicode" + + gperr "github.com/yusing/goutils/errs" + httputils "github.com/yusing/goutils/http" +) + +// IfBlockCommand is an inline conditional block inside a do-body. +// +// Syntax (within a rule do block): +// +// @ { } +// +// Semantics: +// - Evaluated in the same phase the parent rule runs. +// - If matches, run the nested commands in-order. +// - Otherwise do nothing. +// +// NOTE: Per current design decision, we keep this permissive: +// nested blocks may use response matchers and response commands; no extra phase validation is performed. +type IfBlockCommand struct { + On RuleOn + Do []CommandHandler +} + +func (c IfBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + if c.Do == nil { + return nil + } + // If On.checker is nil, treat as unconditional (should not happen if parsed). + if c.On.checker == nil { + return Commands(c.Do).ServeHTTP(w, r, upstream) + } + if c.On.checker.Check(w, r) { + return Commands(c.Do).ServeHTTP(w, r, upstream) + } + return nil +} + +func (c IfBlockCommand) Phase() PhaseFlag { + phase := c.On.phase + for _, cmd := range c.Do { + phase |= cmd.Phase() + } + return phase +} + +// IfElseBlockCommand is a chained conditional block inside a do-body. +// +// Syntax (within a rule do block): +// +// @ { } elif { } ... else { } +// +// NOTE: `elif`/`else` must appear on the same line as the preceding closing brace (`}`), +// e.g. `} elif ... {` and `} else {`. +type IfElseBlockCommand struct { + Ifs []IfBlockCommand + Else []CommandHandler +} + +func (c IfElseBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + for _, br := range c.Ifs { + // If On.checker is nil, treat as unconditional. + if br.On.checker == nil { + if br.Do == nil { + continue + } + return Commands(br.Do).ServeHTTP(w, r, upstream) + } + if br.Do == nil { + continue + } + if br.On.checker.Check(w, r) { + return Commands(br.Do).ServeHTTP(w, r, upstream) + } + } + if len(c.Else) > 0 { + return Commands(c.Else).ServeHTTP(w, r, upstream) + } + return nil +} + +func (c IfElseBlockCommand) Phase() PhaseFlag { + phase := PhaseNone + for _, br := range c.Ifs { + phase |= br.Phase() + } + if len(c.Else) > 0 { + phase |= Commands(c.Else).Phase() + } + return phase +} + +func skipSameLineSpace(src string, pos int) int { + for pos < len(src) { + switch src[pos] { + case '\n': + return pos + case '\r': + pos++ + continue + case ' ', '\t': + pos++ + continue + default: + return pos + } + } + return pos +} + +func parseAtBlockChain(src string, atPos int) (CommandHandler, int, error) { + length := len(src) + headerStart := atPos + 1 + + parseBranch := func(onExpr string, bodyStart int, bodyEnd int) (RuleOn, []CommandHandler, error) { + var on RuleOn + if err := on.Parse(onExpr); err != nil { + return RuleOn{}, nil, err + } + innerSrc := "" + if bodyStart < bodyEnd { + innerSrc = src[bodyStart:bodyEnd] + } + inner, err := parseDoWithBlocks(innerSrc) + if err != nil { + return RuleOn{}, nil, err + } + if len(inner) == 0 { + return on, nil, nil + } + return on, inner, nil + } + + onExpr, bracePos, herr := parseHeaderToBrace(src, headerStart) + if herr != nil { + return nil, 0, herr + } + if bracePos >= length || src[bracePos] != '{' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after nested block header") + } + + // Parse first @ { ... } + p := bracePos + bodyStart := p + 1 + bodyEnd, ferr := findMatchingBrace(src, &p, bodyStart) + if ferr != nil { + return nil, 0, ferr + } + firstOn, firstDo, berr := parseBranch(onExpr, bodyStart, bodyEnd) + if berr != nil { + return nil, 0, berr + } + + ifs := []IfBlockCommand{{On: firstOn, Do: firstDo}} + var elseDo []CommandHandler + hasChain := false + hasElse := false + + for { + q := skipSameLineSpace(src, p) + if q >= length || src[q] == '\n' { + break + } + + // elif { ... } + if strings.HasPrefix(src[q:], "elif") { + next := q + len("elif") + if next >= length { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'") + } + if src[next] == '\n' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'") + } + if !unicode.IsSpace(rune(src[next])) { + if src[next] == '{' || src[next] == '}' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'") + } + return nil, 0, ErrInvalidBlockSyntax.Withf("expected whitespace after 'elif'") + } + next++ + for next < length { + c := src[next] + if c == '\n' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition") + } + if c == '\r' { + next++ + continue + } + if !unicode.IsSpace(rune(c)) { + break + } + next++ + } + + p2 := next + elifOnExpr, bracePos, herr := parseHeaderToBrace(src, p2) + if herr != nil { + return nil, 0, herr + } + if elifOnExpr == "" { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'") + } + if bracePos >= length || src[bracePos] != '{' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition") + } + p2 = bracePos + elifBodyStart := p2 + 1 + elifBodyEnd, ferr := findMatchingBrace(src, &p2, elifBodyStart) + if ferr != nil { + return nil, 0, ferr + } + elifOn, elifDo, berr := parseBranch(elifOnExpr, elifBodyStart, elifBodyEnd) + if berr != nil { + return nil, 0, berr + } + ifs = append(ifs, IfBlockCommand{On: elifOn, Do: elifDo}) + hasChain = true + p = p2 + continue + } + + // else { ... } + if strings.HasPrefix(src[q:], "else") { + if hasElse { + return nil, 0, ErrInvalidBlockSyntax.Withf("multiple 'else' branches") + } + next := q + len("else") + for next < length { + c := src[next] + if c == '\n' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'") + } + if c == '\r' { + next++ + continue + } + if !unicode.IsSpace(rune(c)) { + break + } + next++ + } + if next >= length || src[next] != '{' { + return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'") + } + + elseBodyStart := next + 1 + p2 := next + elseBodyEnd, ferr := findMatchingBrace(src, &p2, elseBodyStart) + if ferr != nil { + return nil, 0, ferr + } + innerSrc := "" + if elseBodyStart < elseBodyEnd { + innerSrc = src[elseBodyStart:elseBodyEnd] + } + inner, ierr := parseDoWithBlocks(innerSrc) + if ierr != nil { + return nil, 0, ierr + } + if len(inner) == 0 { + elseDo = nil + } else { + elseDo = inner + } + hasChain = true + hasElse = true + p = p2 + + // else must be the last branch on that line. + for q2 := skipSameLineSpace(src, p); q2 < length && src[q2] != '\n'; q2 = skipSameLineSpace(src, q2) { + return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after else block") + } + break + } + + return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after nested block; expected 'elif'/'else' or newline") + } + + if hasChain { + return IfElseBlockCommand{Ifs: ifs, Else: elseDo}, p, nil + } + return IfBlockCommand{On: ifs[0].On, Do: ifs[0].Do}, p, nil +} + +// parseDoWithBlocks parses a do-body containing plain command lines and nested @-blocks. +// It returns the outer command handlers and the require phase. +// +// A nested block is only recognized when '@' is the first non-space character on a line. +func parseDoWithBlocks(src string) (handlers []CommandHandler, err error) { + pos := 0 + length := len(src) + lineStart := true + handlers = make([]CommandHandler, 0, strings.Count(src, "\n")+1) + + appendLineCommand := func(line string) error { + line = strings.TrimSpace(line) + if line == "" { + return nil + } + + directive, args, err := parse(line) + if err != nil { + return err + } + + builder, ok := commands[directive] + if !ok { + return ErrUnknownDirective.Subject(directive) + } + + phase, validArgs, err := builder.validate(args) + if err != nil { + return gperr.PrependSubject(err, directive).With(builder.help.Error()) + } + + h := builder.build(validArgs) + handlers = append(handlers, Handler{fn: h, phase: phase, terminate: builder.terminate}) + return nil + } + + for pos < length { + // Handle newlines + switch src[pos] { + case '\n': + pos++ + lineStart = true + continue + case '\r': + // tolerate CRLF + pos++ + continue + } + + if lineStart { + // Find first non-space on the line. + linePos := pos + for linePos < length { + c := rune(src[linePos]) + if c == '\n' { + break + } + if !unicode.IsSpace(c) { + break + } + linePos++ + } + + if linePos < length && src[linePos] == '@' { + h, next, err := parseAtBlockChain(src, linePos) + if err != nil { + return nil, err + } + handlers = append(handlers, h) + pos = next + lineStart = false + continue + } + + // Not a nested block; parse the rest of this line as a command. + lineEnd := pos + for lineEnd < length && src[lineEnd] != '\n' { + lineEnd++ + } + if lerr := appendLineCommand(src[pos:lineEnd]); lerr != nil { + return nil, lerr + } + pos = lineEnd + lineStart = true + continue + } + + // Not at line start; advance to the next line boundary. + for pos < length && src[pos] != '\n' { + pos++ + } + lineStart = true + } + + return handlers, nil +} diff --git a/internal/route/rules/do_blocks_test.go b/internal/route/rules/do_blocks_test.go new file mode 100644 index 00000000..14396ee0 --- /dev/null +++ b/internal/route/rules/do_blocks_test.go @@ -0,0 +1,73 @@ +package rules + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + httputils "github.com/yusing/goutils/http" +) + +func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoFallsThrough(t *testing.T) { + elseCalled := false + cmd := IfElseBlockCommand{ + Ifs: []IfBlockCommand{ + { + On: RuleOn{}, + Do: nil, + }, + }, + Else: []CommandHandler{ + Handler{ + fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error { + elseCalled = true + return nil + }, + phase: PhaseNone, + }, + }, + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + rm := httputils.NewResponseModifier(w) + + err := cmd.ServeHTTP(rm, req, nil) + require.NoError(t, err) + assert.True(t, elseCalled) +} + +func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoFallsThrough(t *testing.T) { + elseCalled := false + cmd := IfElseBlockCommand{ + Ifs: []IfBlockCommand{ + { + On: RuleOn{ + checker: CheckFunc(func(_ *httputils.ResponseModifier, _ *http.Request) bool { + return true + }), + }, + Do: nil, + }, + }, + Else: []CommandHandler{ + Handler{ + fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error { + elseCalled = true + return nil + }, + phase: PhaseNone, + }, + }, + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + rm := httputils.NewResponseModifier(w) + + err := cmd.ServeHTTP(rm, req, nil) + require.NoError(t, err) + assert.True(t, elseCalled) +} diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go index 79b39f87..191a0732 100644 --- a/internal/route/rules/do_log_test.go +++ b/internal/route/rules/do_log_test.go @@ -37,7 +37,7 @@ func parseRules(data string, target *Rules) error { } func TestLogCommand_TemporaryFile(t *testing.T) { - upstream := mockUpstreamWithHeaders(200, "success response", http.Header{ + upstream := mockUpstreamWithHeaders(http.StatusOK, "success response", http.Header{ "Content-Type": []string{"application/json"}, }) @@ -45,10 +45,9 @@ func TestLogCommand_TemporaryFile(t *testing.T) { var rules Rules err := parseRules(fmt.Sprintf(` -- name: log-request-response - do: | - log info %q '$req_method $req_url $status_code $resp_header(Content-Type)' -`, logFile), &rules) +default { + log info %q '$req_method $req_url $status_code $resp_header(Content-Type)' +}`, logFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -59,7 +58,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "success response", w.Body.String()) // Read and verify log content @@ -74,12 +73,10 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) { var rules Rules err := parseRules(` -- name: log-stdout - do: | - log info /dev/stdout "stdout: $req_method $status_code" -- name: log-stderr - do: | - log error /dev/stderr "stderr: $req_path $status_code" +default { + log info /dev/stdout "stdout: $req_method $status_code" + log error /dev/stderr "stderr: $req_path $status_code" +} `, &rules) require.NoError(t, err) @@ -90,7 +87,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, 200, w.Code) + assert.Equal(t, http.StatusOK, w.Code) // Note: We can't easily capture stdout/stderr in unit tests, // but we can verify no errors occurred and the handler completed } @@ -104,26 +101,22 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) { var rules Rules err := parseRules(fmt.Sprintf(` -- name: log-info - do: | - log info %s "INFO: $req_method $status_code" -- name: log-warn - do: | - log warn %s "WARN: $req_path $status_code" -- name: log-error - do: | - log error %s "ERROR: $req_method $req_path $status_code" +default { + log info %s "INFO: $req_method $status_code" + log warn %s "WARN: $req_path $status_code" + log error %s "ERROR: $req_method $req_path $status_code" +} `, infoFile, warnFile, errorFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("DELETE", "/api/resource/123", nil) + req := httptest.NewRequest(http.MethodDelete, "/api/resource/123", nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) // Verify each log file infoContent := TestFileContent(infoFile) @@ -148,22 +141,22 @@ func TestLogCommand_TemplateVariables(t *testing.T) { var rules Rules err := parseRules(fmt.Sprintf(` -- name: log-with-templates - do: | - log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)' +default { + log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)' +} `, tempFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) - req := httptest.NewRequest("PUT", "/api/resource", nil) + req := httptest.NewRequest(http.MethodPut, "/api/resource", nil) req.Header.Set("User-Agent", "test-client/1.0") req.Host = "example.com" w := httptest.NewRecorder() handler.ServeHTTP(w, req) - assert.Equal(t, 201, w.Code) + assert.Equal(t, http.StatusCreated, w.Code) // Verify log content content := TestFileContent(tempFile) @@ -192,14 +185,12 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { var rules Rules err := parseRules(fmt.Sprintf(` -- name: log-success - on: status 2xx - do: | - log info %q "SUCCESS: $req_method $req_path $status_code" -- name: log-error - on: status 4xx | status 5xx - do: | - log error %q "ERROR: $req_method $req_path $status_code" +status 2xx { + log info %q "SUCCESS: $req_method $req_path $status_code" +} +status 4xx | status 5xx { + log error %q "ERROR: $req_method $req_path $status_code" +} `, successFile, errorFile), &rules) require.NoError(t, err) @@ -244,9 +235,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) { var rules Rules err := parseRules(fmt.Sprintf(` -- name: log-multiple - do: | - log info %q "$req_method $req_path $status_code"`, tempFile), &rules) +default { + log info %q "$req_method $req_path $status_code" +} +`, tempFile), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -256,10 +248,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) { method string path string }{ - {"GET", "/users"}, - {"POST", "/users"}, - {"PUT", "/users/1"}, - {"DELETE", "/users/1"}, + {http.MethodGet, "/users"}, + {http.MethodPost, "/users"}, + {http.MethodPost, "/users/1"}, + {http.MethodDelete, "/users/1"}, } for _, reqInfo := range requests { @@ -287,8 +279,9 @@ func TestLogCommand_InvalidTemplate(t *testing.T) { // Test with invalid template syntax err := parseRules(` -- name: log-invalid - do: | - log info /dev/stdout "$invalid_var"`, &rules) - assert.ErrorIs(t, err, ErrUnexpectedVar) +default { + log info /dev/stdout "$invalid_var" +} +`, &rules) + require.ErrorIs(t, err, ErrUnexpectedVar) } diff --git a/internal/route/rules/do_set.go b/internal/route/rules/do_set.go index 025f9f1c..88f93a5f 100644 --- a/internal/route/rules/do_set.go +++ b/internal/route/rules/do_set.go @@ -12,7 +12,7 @@ import ( type ( FieldHandler struct { - set, add, remove CommandHandler + set, add, remove HandlerFunc } FieldModifier string ) @@ -49,30 +49,30 @@ var modFields = map[string]struct { "value": "the header template", }, }, - validate: toKeyValueTemplate, + validate: validatePreRequestKVTemplate, builder: func(args any) *FieldHandler { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } r.Header[k] = []string{v} return nil - }), - add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + }, + add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } r.Header[k] = append(r.Header[k], v) return nil - }), - remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + }, + remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { delete(r.Header, k) return nil - }), + }, } }, }, @@ -84,30 +84,30 @@ var modFields = map[string]struct { "value": "the response header template", }, }, - validate: toKeyValueTemplate, + validate: validatePostResponseKVTemplate, builder: func(args any) *FieldHandler { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } w.Header()[k] = []string{v} return nil - }), - add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + }, + add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } w.Header()[k] = append(w.Header()[k], v) return nil - }), - remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + }, + remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { delete(w.Header(), k) return nil - }), + }, } }, }, @@ -119,36 +119,36 @@ var modFields = map[string]struct { "value": "the query template", }, }, - validate: toKeyValueTemplate, + validate: validatePreRequestKVTemplate, builder: func(args any) *FieldHandler { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } - httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + w.SharedData().UpdateQueries(r, func(queries url.Values) { queries.Set(k, v) }) return nil - }), - add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + }, + add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } - httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + w.SharedData().UpdateQueries(r, func(queries url.Values) { queries.Add(k, v) }) return nil - }), - remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + }, + remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + w.SharedData().UpdateQueries(r, func(queries url.Values) { queries.Del(k) }) return nil - }), + }, } }, }, @@ -160,16 +160,16 @@ var modFields = map[string]struct { "value": "the cookie value", }, }, - validate: toKeyValueTemplate, + validate: validatePreRequestKVTemplate, builder: func(args any) *FieldHandler { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } - httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { for i, c := range cookies { if c.Name == k { cookies[i].Value = v @@ -179,19 +179,19 @@ var modFields = map[string]struct { return append(cookies, &http.Cookie{Name: k, Value: v}) }) return nil - }), - add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := tmpl.ExpandVarsToString(w, r) + }, + add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + v, _, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } - httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { return append(cookies, &http.Cookie{Name: k, Value: v}) }) return nil - }), - remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + }, + remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { index := -1 for i, c := range cookies { if c.Name == k { @@ -208,7 +208,7 @@ var modFields = map[string]struct { return cookies }) return nil - }), + }, } }, }, @@ -227,24 +227,27 @@ var modFields = map[string]struct { "template": "the body template", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 1 { - return nil, ErrExpectOneArg + return 0, nil, ErrExpectOneArg } - return validateTemplate(args[0], true) + phase = PhasePre + tmplReq, parsedArgs, err := validateTemplate(args[0], true) + phase |= tmplReq + return }, builder: func(args any) *FieldHandler { tmpl := args.(templateString) return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { if r.Body != nil { r.Body.Close() r.Body = nil } - bufPool := httputils.GetInitResponseModifier(w).BufPool() + bufPool := w.BufPool() b := bufPool.GetBuffer() - err := tmpl.ExpandVars(w, r, b) + _, err := tmpl.ExpandVars(w, r, b) if err != nil { return err } @@ -252,7 +255,7 @@ var modFields = map[string]struct { bufPool.PutBuffer(b) }) return nil - }), + }, } }, }, @@ -272,20 +275,26 @@ var modFields = map[string]struct { "template": "the response body template", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 1 { - return nil, ErrExpectOneArg + return 0, nil, ErrExpectOneArg } - return validateTemplate(args[0], true) + phase = PhasePost + tmplReq, parsedArgs, err := validateTemplate(args[0], true) + phase |= tmplReq + return }, builder: func(args any) *FieldHandler { tmpl := args.(templateString) return &FieldHandler{ - set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { - rm := httputils.GetInitResponseModifier(w) - rm.ResetBody() - return tmpl.ExpandVars(w, r, rm) - }), + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + w.ResetBody() + _, err := tmpl.ExpandVars(w, r, w) + if err != nil { + return err + } + return nil + }, } }, }, @@ -300,26 +309,27 @@ var modFields = map[string]struct { "code": "the status code", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 1 { - return nil, ErrExpectOneArg + return phase, nil, ErrExpectOneArg } + phase = PhasePost status, err := strconv.Atoi(args[0]) if err != nil { - return nil, ErrInvalidArguments.With(err) + return phase, nil, ErrInvalidArguments.With(err) } if status < 100 || status > 599 { - return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status) + return phase, nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status) } - return status, nil + return phase, status, nil }, builder: func(args any) *FieldHandler { status := args.(int) return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - httputils.GetInitResponseModifier(w).WriteHeader(status) + set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error { + w.WriteHeader(status) return nil - }), + }, } }, }, diff --git a/internal/route/rules/do_set_test.go b/internal/route/rules/do_set_test.go index 3d861e41..5146c2cd 100644 --- a/internal/route/rules/do_set_test.go +++ b/internal/route/rules/do_set_test.go @@ -5,7 +5,6 @@ import ( "io" "net/http" "net/http/httptest" - "slices" "strings" "testing" @@ -72,12 +71,12 @@ func TestFieldHandler_Header(t *testing.T) { tt.setup(req) w := httptest.NewRecorder() - tmpl, tErr := validateTemplate(tt.value, false) + _, tmpl, tErr := validateTemplate(tt.value, false) if tErr != nil { t.Fatalf("Failed to validate template: %v", tErr) } handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl}) - var cmd CommandHandler + var cmd HandlerFunc switch tt.modifier { case ModFieldSet: cmd = handler.set @@ -87,7 +86,7 @@ func TestFieldHandler_Header(t *testing.T) { cmd = handler.remove } - err := cmd.Handle(w, req) + err := cmd(httputils.NewResponseModifier(w), req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } @@ -150,12 +149,12 @@ func TestFieldHandler_ResponseHeader(t *testing.T) { tt.setup(w) } - tmpl, tErr := validateTemplate(tt.value, false) + _, tmpl, tErr := validateTemplate(tt.value, false) if tErr != nil { t.Fatalf("Failed to validate template: %v", tErr) } handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl}) - var cmd CommandHandler + var cmd HandlerFunc switch tt.modifier { case ModFieldSet: cmd = handler.set @@ -165,7 +164,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) { cmd = handler.remove } - err := cmd.Handle(w, req) + err := cmd(httputils.NewResponseModifier(w), req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } @@ -237,12 +236,12 @@ func TestFieldHandler_Query(t *testing.T) { tt.setup(req) w := httptest.NewRecorder() - tmpl, tErr := validateTemplate(tt.value, false) + _, tmpl, tErr := validateTemplate(tt.value, false) if tErr != nil { t.Fatalf("Failed to validate template: %v", tErr) } handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl}) - var cmd CommandHandler + var cmd HandlerFunc switch tt.modifier { case ModFieldSet: cmd = handler.set @@ -252,7 +251,7 @@ func TestFieldHandler_Query(t *testing.T) { cmd = handler.remove } - err := cmd.Handle(w, req) + err := cmd(httputils.NewResponseModifier(w), req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } @@ -335,12 +334,12 @@ func TestFieldHandler_Cookie(t *testing.T) { tt.setup(req) w := httptest.NewRecorder() - tmpl, tErr := validateTemplate(tt.value, false) + _, tmpl, tErr := validateTemplate(tt.value, false) if tErr != nil { t.Fatalf("Failed to validate template: %v", tErr) } handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl}) - var cmd CommandHandler + var cmd HandlerFunc switch tt.modifier { case ModFieldSet: cmd = handler.set @@ -350,7 +349,7 @@ func TestFieldHandler_Cookie(t *testing.T) { cmd = handler.remove } - err := cmd.Handle(w, req) + err := cmd(httputils.NewResponseModifier(w), req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } @@ -371,7 +370,7 @@ func TestFieldHandler_Body(t *testing.T) { name: "set body with template", template: "Hello $req_method $req_path", setup: func(r *http.Request) { - r.Method = "POST" + r.Method = http.MethodPost r.URL.Path = "/test" }, verify: func(r *http.Request) { @@ -399,15 +398,15 @@ func TestFieldHandler_Body(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) - w := httptest.NewRecorder() + w := httputils.NewResponseModifier(httptest.NewRecorder()) - tmpl, tErr := validateTemplate(tt.template, false) + _, tmpl, tErr := validateTemplate(tt.template, false) if tErr != nil { t.Fatalf("Failed to parse template: %v", tErr) } handler := modFields[FieldBody].builder(tmpl) - err := handler.set.Handle(w, req) + err := handler.set(w, req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } @@ -428,7 +427,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) { name: "set response body with template", template: "Response: $req_method $req_path", setup: func(r *http.Request) { - r.Method = "GET" + r.Method = http.MethodGet r.URL.Path = "/api/test" }, verify: func(rm *httputils.ResponseModifier) { @@ -443,23 +442,20 @@ func TestFieldHandler_ResponseBody(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) tt.setup(req) - w := httptest.NewRecorder() + w := httputils.NewResponseModifier(httptest.NewRecorder()) - // Create ResponseModifier wrapper - rm := httputils.NewResponseModifier(w) - - tmpl, tErr := validateTemplate(tt.template, false) + _, tmpl, tErr := validateTemplate(tt.template, false) if tErr != nil { t.Fatalf("Failed to parse template: %v", tErr) } handler := modFields[FieldResponseBody].builder(tmpl) - err := handler.set.Handle(rm, req) + err := handler.set(w, req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } - tt.verify(rm) + tt.verify(w) }) } } @@ -472,23 +468,23 @@ func TestFieldHandler_StatusCode(t *testing.T) { }{ { name: "set status code 200", - status: 200, + status: http.StatusOK, verify: func(w *httptest.ResponseRecorder) { - assert.Equal(t, 200, w.Code, "Expected status code 200") + assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200") }, }, { name: "set status code 404", - status: 404, + status: http.StatusNotFound, verify: func(w *httptest.ResponseRecorder) { - assert.Equal(t, 404, w.Code, "Expected status code 404") + assert.Equal(t, http.StatusNotFound, w.Code, "Expected status code 404") }, }, { name: "set status code 500", - status: 500, + status: http.StatusInternalServerError, verify: func(w *httptest.ResponseRecorder) { - assert.Equal(t, 500, w.Code, "Expected status code 500") + assert.Equal(t, http.StatusInternalServerError, w.Code, "Expected status code 500") }, }, } @@ -503,12 +499,11 @@ func TestFieldHandler_StatusCode(t *testing.T) { if err != nil { t.Fatalf("Handler returned error: %v", err) } - err = cmd.ServeHTTP(rm, req) + err = cmd.post.ServeHTTP(rm, req, nil) if err != nil { t.Fatalf("Handler returned error: %v", err) } rm.FlushRelease() - tt.verify(w) }) } @@ -600,7 +595,7 @@ func TestFieldValidation(t *testing.T) { field, exists := modFields[tt.field] assert.True(t, exists, "Field %s does not exist", tt.field) - _, err := field.validate(tt.args) + _, _, err := field.validate(tt.args) if tt.wantError { assert.Error(t, err, "Expected error but got none") } else { @@ -610,25 +605,6 @@ func TestFieldValidation(t *testing.T) { } } -func TestAllFields(t *testing.T) { - expectedFields := []string{ - FieldHeader, - FieldResponseHeader, - FieldQuery, - FieldCookie, - FieldBody, - FieldResponseBody, - FieldStatusCode, - } - - require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields)) - - for _, expected := range expectedFields { - found := slices.Contains(AllFields, expected) - assert.True(t, found, "Expected field %s not found in AllFields", expected) - } -} - func TestModFields(t *testing.T) { for fieldName, field := range modFields { // Test that each field has required components diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index 4d310559..e616bde5 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -14,8 +14,9 @@ var ( ErrEnvVarNotFound = gperr.New("env variable not found") ErrInvalidArguments = gperr.New("invalid arguments") ErrInvalidOnTarget = gperr.New("invalid `rule.on` target") - ErrInvalidCommandSequence = gperr.New("invalid command sequence") - ErrMultipleDefaultRules = gperr.New("multiple default rules") + + ErrMultipleDefaultRules = gperr.New("multiple default rules") + ErrDeadRule = gperr.New("dead rule") // vars errors ErrNoArgProvided = gperr.New("no argument provided") @@ -31,5 +32,5 @@ var ( ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args") ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'") - errTerminated = gperr.New("terminated") + ErrInvalidBlockSyntax = gperr.New("invalid block syntax") // TODO: struct this error ) diff --git a/internal/route/rules/help.go b/internal/route/rules/help.go index e8bfb5c7..7503ce3d 100644 --- a/internal/route/rules/help.go +++ b/internal/route/rules/help.go @@ -131,12 +131,13 @@ Error generates help string as error, e.g. from: the path to rewrite, must start with / to: the path to rewrite to, must start with / */ -func (h *Help) Error() error { - var lines gperr.MultilineError +func (h *Help) Error() gperr.Error { + help := gperr.New(ansi.WithANSI(h.command, ansi.HighlightGreen)) + for _, line := range h.description { + help = help.Withf("%s", line) + } - lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen)) - lines.AddStrings(h.description...) - lines.Adds(" args:") + args := gperr.New("args") argKeys := make([]string, 0, len(h.args)) longestArg := 0 @@ -151,7 +152,9 @@ func (h *Help) Error() error { slices.Sort(argKeys) for _, arg := range argKeys { desc := h.args[arg] - lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc) + paddedArg := fmt.Sprintf("%-"+strconv.Itoa(longestArg)+"s", arg) + args = args.Withf("%s%s", ansi.WithANSI(paddedArg, ansi.HighlightCyan)+": ", desc) } - return &lines + + return help.With(args) } diff --git a/internal/route/rules/http_flow_block_test.go b/internal/route/rules/http_flow_block_test.go new file mode 100644 index 00000000..e4dcb83f --- /dev/null +++ b/internal/route/rules/http_flow_block_test.go @@ -0,0 +1,1328 @@ +package rules_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/route/routes" + "golang.org/x/crypto/bcrypt" + + . "github.com/yusing/godoxy/internal/route/rules" +) + +func TestHTTPFlow_BasicPreRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header")) + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream response")) + }) + + var rules Rules + err := parseRules(` +- name: add-header + on: path / + do: set header X-Custom-Header test-value +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "upstream response", w.Body.String()) + assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header")) +} + +func TestHTTPFlow_TerminatingCommand(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "should not be called") + + var rules Rules + err := parseRules(` +path /error { + error 403 Forbidden +} +path /error { + set header X-Header ignored +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/error", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, "Forbidden\n", w.Body.String()) + assert.Empty(t, w.Header().Get("X-Header")) +} + +func TestHTTPFlow_RedirectFlow(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "should not be called") + + var rules Rules + err := parseRules(` +path /old-path { + redirect /new-path +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/old-path", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTemporaryRedirect, w.Code) + assert.Equal(t, "/new-path", w.Header().Get("Location")) +} + +func TestHTTPFlow_RewriteFlow(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("path: " + r.URL.Path)) + }) + + var rules Rules + err := parseRules(` +path glob(/api/*) { + rewrite /api/ /v1/ +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/api/users", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "path: /v1/users", w.Body.String()) +} + +func TestHTTPFlow_MultiplePreRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id"))) + }) + + var rules Rules + err := parseRules(` +path / { + set header X-Request-Id req-123 +} +path / { + set header X-Auth-Token token-456 +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "upstream: req-123", w.Body.String()) + assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token")) +} + +func TestHTTPFlow_PostResponseRule(t *testing.T) { + upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{ + "X-Upstream": []string{"upstream-value"}, + }) + + tempFile := TestRandomFileName() + + var rules Rules + err := parseRules(fmt.Sprintf(` +path /test { + log info %s "$req_method $status_code" +} +`, tempFile), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "success", w.Body.String()) + assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream")) + + // Check log file + content := TestFileContent(tempFile) + assert.Equal(t, "GET 200\n", string(content)) +} + +func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/success" { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + } else { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("not found")) + } + }) + + var rules Rules + + // Create a temporary file for logging + tempFile := TestRandomFileName() + + err := parseRules(fmt.Sprintf(` +status 4xx { + log error %s "$req_url returned $status_code" +} +`, tempFile), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test successful request (should not log) + req1 := httptest.NewRequest(http.MethodGet, "/success", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + + // Test error request (should log) + req2 := httptest.NewRequest(http.MethodGet, "/notfound", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusNotFound, w2.Code) + + // Check log file + content := TestFileContent(tempFile) + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + require.Len(t, lines, 1, "only 4xx requests should be logged") + assert.Equal(t, "/notfound returned 404", lines[0]) +} + +func TestHTTPFlow_ConditionalRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello " + r.Header.Get("X-Username"))) + }) + + var rules Rules + err := parseRules(` +header Authorization { + set header X-Username authenticated-user + set resp_header X-Username authenticated-user +} +default { + set header X-Username anonymous + set resp_header X-Username anonymous +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test with Authorization header + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("Authorization", "Bearer token") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "hello authenticated-user", w1.Body.String()) + assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username")) + + // Test without Authorization header + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "hello anonymous", w2.Body.String()) + assert.Equal(t, "anonymous", w2.Header().Get("X-Username")) +} + +func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate different responses based on path + if r.URL.Path == "/protected" { + if r.Header.Get("X-Auth") != "valid" { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, "unauthorized") + return + } + } + w.Header().Set("X-Response-Time", "100ms") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "success") + }) + + // Create temporary files for logging + logFile := TestRandomFileName() + errorLogFile := TestRandomFileName() + + var rules Rules + err := parseRules(fmt.Sprintf(` +{ + set resp_header X-Correlation-Id random_uuid +} +path /protected { + require_basic_auth "Protected Area" +} +{ + log info %q "$req_method $req_url -> $status_code" +} +status 4xx { + log error %q "ERROR: $req_method $req_url $status_code" +} +`, logFile, errorLogFile), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test successful request + req1 := httptest.NewRequest(http.MethodGet, "/public", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "success", w1.Body.String()) + assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id")) + assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time")) + + // Test unauthorized protected request + req2 := httptest.NewRequest(http.MethodGet, "/protected", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusUnauthorized, w2.Code) + assert.Equal(t, w2.Body.String(), "Unauthorized\n") + + // Test authorized protected request + req3 := httptest.NewRequest(http.MethodGet, "/protected", nil) + req3.SetBasicAuth("user", "pass") + w3 := httptest.NewRecorder() + handler.ServeHTTP(w3, req3) + + // This should fail because our simple upstream expects X-Auth: valid header + // but the basic auth requirement should add the appropriate header + assert.Equal(t, http.StatusUnauthorized, w3.Code) + + // Check log files + logContent := TestFileContent(logFile) + lines := strings.Split(strings.TrimSpace(string(logContent)), "\n") + require.Len(t, lines, 3, "all requests should be logged") + assert.Equal(t, "GET /public -> 200", lines[0]) + assert.Equal(t, "GET /protected -> 401", lines[1]) + assert.Equal(t, "GET /protected -> 401", lines[2]) + + errorLogContent := TestFileContent(errorLogFile) + // Should have at least one 401 error logged + lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n") + require.Len(t, lines, 2, "all errors should be logged") + assert.Equal(t, "ERROR: GET /protected 401", lines[0]) + assert.Equal(t, "ERROR: GET /protected 401", lines[1]) +} + +func TestHTTPFlow_DefaultRule(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "upstream response") + + var rules Rules + err := parseRules(` +default { + set resp_header X-Default-Applied true +} +path /special { + set resp_header X-Special-Handled true +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test default rule + req1 := httptest.NewRequest(http.MethodGet, "/regular", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) + assert.Empty(t, w1.Header().Get("X-Special-Handled")) + + // Test special rule + default rule + req2 := httptest.NewRequest(http.MethodGet, "/special", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "true", w2.Header().Get("X-Default-Applied")) + assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) +} + +func TestHTTPFlow_HeaderManipulation(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Echo back a header + headerValue := r.Header.Get("X-Test-Header") + w.Header().Set("X-Echoed-Header", headerValue) + w.WriteHeader(http.StatusOK) + w.Write([]byte("header echoed")) + }) + + var rules Rules + err := parseRules(` +default { + remove resp_header X-Secret + add resp_header X-Custom-Header custom-value +} +header X-Test-Header { + set header X-Test-Header modified-value +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Secret", "secret-value") + req.Header.Set("X-Test-Header", "original-value") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header")) + assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header")) + // Ensure the secret header was removed and not passed to upstream + // (we can't directly test this, but the upstream shouldn't see it) +} + +func TestHTTPFlow_NestedBlocks_RemoteOverride(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Remote-Type", r.Header.Get("X-Remote-Type")) + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + var rules Rules + err := parseRules(` +header X-Test-Header { + set header X-Remote-Type public + @remote 127.0.0.1 | remote 192.168.0.0/16 { + set header X-Remote-Type private + } +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Localhost => private + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Test-Header", "1") + req1.RemoteAddr = "127.0.0.1:12345" + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "private", w1.Header().Get("X-Remote-Type")) + + // Public IP => public + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Test-Header", "1") + req2.RemoteAddr = "10.0.0.1:12345" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "public", w2.Header().Get("X-Remote-Type")) +} + +func TestHTTPFlow_NestedBlocks_ElifElse(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Mode", r.Header.Get("X-Mode")) + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + var rules Rules + err := parseRules(` +header X-Test-Header { + @method GET { + set header X-Mode get + } elif method POST { + set header X-Mode post + } else { + set header X-Mode other + } +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // GET => get + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Test-Header", "1") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "get", w1.Header().Get("X-Mode")) + + // POST => post + req2 := httptest.NewRequest(http.MethodPost, "/", nil) + req2.Header.Set("X-Test-Header", "1") + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "post", w2.Header().Get("X-Mode")) + + // other methods => else branch + req3 := httptest.NewRequest(http.MethodPut, "/", nil) + req3.Header.Set("X-Test-Header", "1") + w3 := httptest.NewRecorder() + handler.ServeHTTP(w3, req3) + assert.Equal(t, http.StatusOK, w3.Code) + assert.Equal(t, "other", w3.Header().Get("X-Mode")) + + // no match + req4 := httptest.NewRequest(http.MethodDelete, "/", nil) + w4 := httptest.NewRecorder() + handler.ServeHTTP(w4, req4) + assert.Equal(t, http.StatusOK, w4.Code) + assert.Equal(t, "", w4.Header().Get("X-Mode")) +} + +func TestHTTPFlow_NestedBlocks_TerminatingActionStopsFlow(t *testing.T) { + called := false + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream")) + }) + + var rules Rules + err := parseRules(` +path / { + set header X-Pre pre + @header X-Block { + error 403 "blocked" + } + set resp_header X-After should-not-run +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Without X-Block => should reach upstream and execute non-terminating commands + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + assert.Equal(t, http.StatusOK, w1.Code) + assert.True(t, called) + assert.Equal(t, "should-not-run", w1.Header().Get("X-After")) + assert.Equal(t, "pre", req1.Header.Get("X-Pre")) + + // With X-Block => nested terminating action should stop processing before upstream + called = false + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Block", "1") + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, 403, w2.Code) + assert.Equal(t, "blocked\n", w2.Body.String()) + assert.False(t, called, "nested error should terminate before calling upstream") + assert.Empty(t, w2.Header().Get("X-After"), "commands after the nested block should not run") +} + +func TestHTTPFlow_NestedBlocks_InResponseRule_ModifiesResponseByRequestMethod(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream")) + }) + + var rules Rules + err := parseRules(` +{ + set header X-Method "should-be-overridden" + @method POST { + set header X-Method "post" + } elif method GET { + set header X-Method "get" + } else { + set header X-Method "other" + } +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + t.Run(http.MethodGet, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "get", req.Header.Get("X-Method")) + }) + + t.Run(http.MethodPost, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "post", req.Header.Get("X-Method")) + }) + + t.Run("other", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "other", req.Header.Get("X-Method")) + }) +} + +func TestHTTPFlow_QueryParameterHandling(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + w.WriteHeader(http.StatusOK) + w.Write([]byte("query: " + query.Get("param"))) + }) + + var rules Rules + err := parseRules(` +query param { + set query param added-value +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/path?param=original", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + // The set command should have modified the query parameter + assert.Equal(t, "query: added-value", w.Body.String()) +} + +func TestHTTPFlow_ServeCommand(t *testing.T) { + // Create a temporary directory with test files + tempDir, err := os.MkdirTemp("", "test-serve-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create test files directly in the temp directory + testFile := filepath.Join(tempDir, "index.html") + err = os.WriteFile(testFile, []byte("

Test Page

"), 0o644) + require.NoError(t, err) + + var rules Rules + err = parseRules(fmt.Sprintf(` +path glob(/files/*) { + serve %s +} +`, tempDir), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called")) + + // Test serving a file - serve command serves files relative to the root directory + // The path /files/index.html gets mapped to tempDir + "/files/index.html" + // We need to create the file at the expected path + filesDir := filepath.Join(tempDir, "files") + err = os.Mkdir(filesDir, 0o755) + require.NoError(t, err) + + filesIndexFile := filepath.Join(filesDir, "index.html") + err = os.WriteFile(filesIndexFile, []byte("

Test Page

"), 0o644) + require.NoError(t, err) + + req1 := httptest.NewRequest(http.MethodGet, "/files/index.html", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + // The serve command should work, but might redirect + // Let's just verify it doesn't call the upstream + assert.NotEqual(t, "should not be called", w1.Body.String()) + + // Test file not found + req2 := httptest.NewRequest(http.MethodGet, "/files/nonexistent.html", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusNotFound, w2.Code) +} + +func TestHTTPFlow_ProxyCommand(t *testing.T) { + // Create a mock upstream server + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream-Header", "upstream-value") + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream response")) + })) + defer upstreamServer.Close() + + var rules Rules + err := parseRules(fmt.Sprintf(` +path glob(/api/*) { + proxy %s +} +`, upstreamServer.URL), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called")) + + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // The proxy command should forward the request to the upstream server + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "upstream response", w.Body.String()) + assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header")) +} + +func TestHTTPFlow_NotifyCommand(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "ok") + + var rules Rules + err := parseRules(` +path /notify { + notify info test-provider "title $req_method" "body $req_url $status_code" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/notify", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "ok", w.Body.String()) +} + +func TestHTTPFlow_FormConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("form processed")) + }) + + var rules Rules + err := parseRules(` +form username { + set resp_header X-Username "$form(username)" +} +postform email { + set resp_header X-Email "$postform(email)" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test form condition + formData := url.Values{"username": {"john_doe"}} + req1 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(formData.Encode())) + req1.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "john_doe", w1.Header().Get("X-Username")) + + // Test postform condition + postFormData := url.Values{"email": {"john@example.com"}} + req2 := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(postFormData.Encode())) + req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "john@example.com", w2.Header().Get("X-Email")) +} + +func TestHTTPFlow_RemoteConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("remote processed")) + }) + + var rules Rules + err := parseRules(` +remote 127.0.0.1 { + set resp_header X-Access "local" +} +remote 192.168.0.0/16 { + error 403 "Private network blocked" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test localhost condition + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.RemoteAddr = "127.0.0.1:12345" + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "local", w1.Header().Get("X-Access")) + + // Test private network block + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.RemoteAddr = "192.168.1.100:12345" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 403, w2.Code) + assert.Equal(t, "Private network blocked\n", w2.Body.String()) +} + +func TestHTTPFlow_BasicAuthConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("auth processed")) + }) + + // Generate bcrypt hashes for passwords + adminHash, err := bcrypt.GenerateFromPassword([]byte("adminpass"), bcrypt.DefaultCost) + require.NoError(t, err) + guestHash, err := bcrypt.GenerateFromPassword([]byte("guestpass"), bcrypt.DefaultCost) + require.NoError(t, err) + + var rules Rules + err = parseRules(fmt.Sprintf(` +basic_auth admin %q { + set resp_header X-Auth-Status "admin" +} +basic_auth guest %q { + set resp_header X-Auth-Status "guest" +} +`, string(adminHash), string(guestHash)), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test admin user + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.SetBasicAuth("admin", "adminpass") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status")) + + // Test guest user + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.SetBasicAuth("guest", "guestpass") + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status")) +} + +func TestHTTPFlow_RouteConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("route processed")) + }) + + var rules Rules + err := parseRules(` +route backend { + set resp_header X-Route "backend" +} +route frontend { + set resp_header X-Route "frontend" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test API route + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1 = routes.WithRouteContext(req1, mockRoute("backend")) + + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "backend", w1.Header().Get("X-Route")) + + // Test admin route + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2 = routes.WithRouteContext(req2, mockRoute("frontend")) + + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "frontend", w2.Header().Get("X-Route")) +} + +func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusMethodNotAllowed) + fmt.Fprint(w, "method not allowed") + }) + + var rules Rules + err := parseRules(` +status 405 { + error 405 'error' +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) + assert.Equal(t, "error\n", w.Body.String()) +} + +func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Response-Header", "response header") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "processed") + }) + + t.Run("any_value", func(t *testing.T) { + var rules Rules + err := parseRules(` +resp_header X-Response-Header { + error 405 "error" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) + assert.Equal(t, "error\n", w.Body.String()) + }) + t.Run("with_value", func(t *testing.T) { + var rules Rules + err := parseRules(` +resp_header X-Response-Header "response header" { + error 405 "error" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Code) + assert.Equal(t, "error\n", w.Body.String()) + }) + + t.Run("with_value_not_matched", func(t *testing.T) { + var rules Rules + err := parseRules(` +resp_header X-Response-Header "not-matched" { + error 405 "error" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "processed", w.Body.String()) + }) +} + +func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "complex processed") + }) + + var rules Rules + err := parseRules(` +path glob(/api/admin/*) & +header Authorization & +method POST { + set resp_header X-Access-Level "admin" + set resp_header X-API-Version "v1" +} +path glob(/api/users/*) & method GET { + set resp_header X-Access-Level "user" + set resp_header X-API-Version "v1" +} +path glob(/api/public/*) & method GET { + set resp_header X-Access-Level "public" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test admin API (should match first rule) + req1 := httptest.NewRequest(http.MethodPost, "/api/admin/users", nil) + req1.Header.Set("Authorization", "Bearer token") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, http.StatusOK, w1.Code) + assert.Equal(t, "admin", w1.Header().Get("X-Access-Level")) + assert.Equal(t, "v1", w1.Header()["X-API-Version"][0]) + + // Test user API (should match second rule) + req2 := httptest.NewRequest(http.MethodGet, "/api/users/profile", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, http.StatusOK, w2.Code) + assert.Equal(t, "user", w2.Header().Get("X-Access-Level")) + assert.Equal(t, "v1", w2.Header()["X-API-Version"][0]) + + // Test public API (should match third rule) + req3 := httptest.NewRequest(http.MethodGet, "/api/public/info", nil) + w3 := httptest.NewRecorder() + handler.ServeHTTP(w3, req3) + + assert.Equal(t, http.StatusOK, w3.Code) + assert.Equal(t, "public", w3.Header().Get("X-Access-Level")) + assert.Empty(t, w3.Header()["X-API-Version"]) +} + +func TestHTTPFlow_ResponseModifier(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "original response") + }) + + var rules Rules + err := parseRules(`{ + set resp_header X-Modified "true" + set resp_body "Modified: $req_method $req_path" +}`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "true", w.Header().Get("X-Modified")) + assert.Equal(t, "Modified: GET /test\n", w.Body.String()) +} + +func TestHTTPFlow_RequireBasicAuth_Challenge(t *testing.T) { + called := false + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "upstream") + }) + + var rules Rules + err := parseRules(` +path /protected { + require_basic_auth "My Realm" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.False(t, called, "require_basic_auth should terminate before calling upstream") + assert.Equal(t, 401, w.Code) + assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Basic") + assert.Contains(t, w.Header().Get("WWW-Authenticate"), "My Realm") +} + +func TestHTTPFlow_NegationMatcher(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "ok") + + var rules Rules + err := parseRules(` +!path glob("/public/*") { + set resp_header X-Scope private +} +path glob("/public/*") { + set resp_header X-Scope public +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + t.Run("public", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/public/index.html", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "public", w.Header().Get("X-Scope")) + }) + + t.Run("private", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/admin", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "private", w.Header().Get("X-Scope")) + }) +} + +func TestHTTPFlow_BlockSyntaxCommentsAreIgnored(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "ok") + + var rules Rules + err := parseRules(` +path /comment { + // comment with braces { } should be ignored + set resp_header X-Commented ok # trailing comment should be ignored too + set resp_header X-Literal "//not-a-comment" // but this one is a real comment + /* block comment + spanning multiple lines { } */ +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/comment", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "ok", w.Header().Get("X-Commented")) + assert.Equal(t, "//not-a-comment", w.Header().Get("X-Literal")) +} + +func TestHTTPFlow_RemoveResponseHeader_RemovesUpstreamHeader(t *testing.T) { + upstream := mockUpstreamWithHeaders(http.StatusOK, "ok", http.Header{ + "X-Secret": []string{"top-secret"}, + "X-Keep": []string{"keep"}, + }) + + var rules Rules + err := parseRules(` +{ + remove resp_header X-Secret +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "keep", w.Header().Get("X-Keep")) + assert.Empty(t, w.Result().Header.Get("X-Secret")) +} + +func TestHTTPFlow_RemoveRequestHeader_BeforeUpstream(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Seen-Secret", r.Header.Get("X-Secret")) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok") + }) + + var rules Rules + err := parseRules(` +{ + remove header X-Secret +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Secret", "should-not-reach-upstream") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Empty(t, w.Header().Get("X-Seen-Secret"), "X-Secret should be removed before reaching upstream") +} + +func TestHTTPFlow_RewritePreservesQueryString(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "path=%s foo=%s bar=%s", r.URL.Path, r.URL.Query().Get("foo"), r.URL.Query().Get("bar")) + }) + + var rules Rules + err := parseRules(` +path glob("/api/*") { + rewrite /api/ /v1/ +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/api/users?foo=1&bar=2", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "path=/v1/users foo=1 bar=2", w.Body.String()) +} + +func TestHTTPFlow_ResponseModifier_PreservesUpstreamStatus(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, "created") + }) + + var rules Rules + err := parseRules(` +{ + set resp_body "overridden" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodPost, "/create", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusCreated, w.Code) + assert.Equal(t, "overridden\n", w.Body.String()) +} + +func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMatchers(t *testing.T) { + upstreamCalled := false + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalled = true + fmt.Fprint(w, "upstream") + }) + + var rules Rules + err := parseRules(` +path / { + error 403 blocked +} +path / { + set resp_header X-Late should-not-run +} +status 4xx { + set resp_header X-Post true +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.False(t, upstreamCalled) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, "blocked\n", w.Body.String()) + assert.Equal(t, "should-not-run", w.Header().Get("X-Late")) + assert.Equal(t, "true", w.Header().Get("X-Post")) +} + +func TestHTTPFlow_PostRuleTermination_StopsRemainingCommandsInRule(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "ok") + }) + + var rules Rules + err := parseRules(` +status 200 { + error 500 failed + set resp_header X-After should-not-run +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "failed\n", w.Body.String()) + assert.Empty(t, w.Header().Get("X-After")) +} + +func TestHTTPFlow_EnvVarExpansionInDoBody(t *testing.T) { + t.Setenv("GODOXY_TEST_ENV", "env-value") + + upstream := mockUpstream(http.StatusOK, "ok") + + var rules Rules + err := parseRules(` +{ + set resp_header X-From-Env "${GODOXY_TEST_ENV}" +} +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "env-value", w.Header().Get("X-From-Env")) +} diff --git a/internal/route/rules/http_flow_test.go b/internal/route/rules/http_flow_yaml_test.go similarity index 82% rename from internal/route/rules/http_flow_test.go rename to internal/route/rules/http_flow_yaml_test.go index 2476ae6c..dcd9be2c 100644 --- a/internal/route/rules/http_flow_test.go +++ b/internal/route/rules/http_flow_yaml_test.go @@ -23,8 +23,9 @@ import ( ) // mockUpstream creates a simple upstream handler for testing -func mockUpstream(body string) http.HandlerFunc { +func mockUpstream(status int, body string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(status) w.Write([]byte(body)) } } @@ -47,7 +48,7 @@ func parseRules(data string, target *Rules) error { return err } -func TestHTTPFlow_BasicPreRules(t *testing.T) { +func TestHTTPFlow_BasicPreRulesYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header")) w.WriteHeader(http.StatusOK) @@ -74,8 +75,8 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) { assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header")) } -func TestHTTPFlow_BypassRule(t *testing.T) { - upstream := mockUpstream("upstream response") +func TestHTTPFlow_BypassRuleYAML(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "upstream response") var rules Rules err := parseRules(` @@ -99,8 +100,8 @@ func TestHTTPFlow_BypassRule(t *testing.T) { assert.Equal(t, "upstream response", w.Body.String()) } -func TestHTTPFlow_TerminatingCommand(t *testing.T) { - upstream := mockUpstream("should not be called") +func TestHTTPFlow_TerminatingCommandYAML(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "should not be called") var rules Rules err := parseRules(` @@ -120,13 +121,13 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, 403, w.Code) assert.Equal(t, "Forbidden\n", w.Body.String()) assert.Empty(t, w.Header().Get("X-Header")) } -func TestHTTPFlow_RedirectFlow(t *testing.T) { - upstream := mockUpstream("should not be called") +func TestHTTPFlow_RedirectFlowYAML(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "should not be called") var rules Rules err := parseRules(` @@ -143,11 +144,11 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect + assert.Equal(t, 307, w.Code) // TemporaryRedirect assert.Equal(t, "/new-path", w.Header().Get("Location")) } -func TestHTTPFlow_RewriteFlow(t *testing.T) { +func TestHTTPFlow_RewriteFlowYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("path: " + r.URL.Path)) @@ -172,7 +173,7 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) { assert.Equal(t, "path: /v1/users", w.Body.String()) } -func TestHTTPFlow_MultiplePreRules(t *testing.T) { +func TestHTTPFlow_MultiplePreRulesYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id"))) @@ -201,7 +202,7 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) { assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token")) } -func TestHTTPFlow_PostResponseRule(t *testing.T) { +func TestHTTPFlow_PostResponseRuleYAML(t *testing.T) { upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{ "X-Upstream": []string{"upstream-value"}, }) @@ -229,11 +230,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) { // Check log file content := TestFileContent(tempFile) - require.NoError(t, err) assert.Equal(t, "GET 200\n", string(content)) } -func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { +func TestHTTPFlow_ResponseRuleWithStatusConditionYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/success" { w.WriteHeader(http.StatusOK) @@ -246,14 +246,17 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { var rules Rules - // Create a temporary file for logging - tempFile := TestRandomFileName() + errorLog := TestRandomFileName() + infoLog := TestRandomFileName() err := parseRules(fmt.Sprintf(` -- name: log-errors - on: status 4xx - do: log error %s "$req_url returned $status_code" -`, tempFile), &rules) + status 4xx { + log error %s "$req_url returned $status_code" + } + status 200 { + log info %s "$req_url returned $status_code" + } +`, errorLog, infoLog), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -273,14 +276,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { assert.Equal(t, http.StatusNotFound, w2.Code) // Check log file - content := TestFileContent(tempFile) - require.NoError(t, err) + content := TestFileContent(errorLog) lines := strings.Split(strings.TrimSpace(string(content)), "\n") require.Len(t, lines, 1, "only 4xx requests should be logged") assert.Equal(t, "/notfound returned 404", lines[0]) + + infoContent := TestFileContent(infoLog) + lines = strings.Split(strings.TrimSpace(string(infoContent)), "\n") + require.Len(t, lines, 1, "only 200 requests should be logged") + assert.Equal(t, "/success returned 200", lines[0]) } -func TestHTTPFlow_ConditionalRules(t *testing.T) { +func TestHTTPFlow_ConditionalRulesYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("hello " + r.Header.Get("X-Username"))) @@ -320,22 +327,21 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) { assert.Equal(t, "anonymous", w2.Header().Get("X-Username")) } -func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { +func TestHTTPFlow_ComplexFlowWithPreAndPostRulesYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Simulate different responses based on path if r.URL.Path == "/protected" { if r.Header.Get("X-Auth") != "valid" { w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("unauthorized")) + fmt.Fprint(w, "unauthorized") return } } w.Header().Set("X-Response-Time", "100ms") w.WriteHeader(http.StatusOK) - w.Write([]byte("success")) + fmt.Fprint(w, "success") }) - // Create temporary files for logging logFile := TestRandomFileName() errorLogFile := TestRandomFileName() @@ -374,7 +380,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { handler.ServeHTTP(w2, req2) assert.Equal(t, http.StatusUnauthorized, w2.Code) - assert.Equal(t, "Unauthorized\n", w2.Body.String()) + assert.Equal(t, w2.Body.String(), "Unauthorized\n") // Test authorized protected request req3 := httptest.NewRequest(http.MethodGet, "/protected", nil) @@ -402,8 +408,8 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { assert.Equal(t, "ERROR: GET /protected 401", lines[1]) } -func TestHTTPFlow_DefaultRule(t *testing.T) { - upstream := mockUpstream("upstream response") +func TestHTTPFlow_DefaultRuleYAML(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "upstream response") var rules Rules err := parseRules(` @@ -436,11 +442,12 @@ func TestHTTPFlow_DefaultRule(t *testing.T) { assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) } -func TestHTTPFlow_HeaderManipulation(t *testing.T) { +func TestHTTPFlow_HeaderManipulationYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Echo back a header headerValue := r.Header.Get("X-Test-Header") w.Header().Set("X-Echoed-Header", headerValue) + w.Header().Set("X-Secret", "sensitive-data") w.WriteHeader(http.StatusOK) w.Write([]byte("header echoed")) }) @@ -460,7 +467,6 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) { handler := rules.BuildHandler(upstream) req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("X-Secret", "secret-value") req.Header.Set("X-Test-Header", "original-value") w := httptest.NewRecorder() @@ -469,11 +475,10 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header")) assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header")) - // Ensure the secret header was removed and not passed to upstream - // (we can't directly test this, but the upstream shouldn't see it) + assert.Empty(t, w.Header().Get("X-Secret")) } -func TestHTTPFlow_QueryParameterHandling(t *testing.T) { +func TestHTTPFlow_QueryParameterHandlingYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() w.WriteHeader(http.StatusOK) @@ -500,13 +505,15 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) { assert.Equal(t, "query: added-value", w.Body.String()) } -func TestHTTPFlow_ServeCommand(t *testing.T) { +func TestHTTPFlow_ServeCommandYAML(t *testing.T) { // Create a temporary directory with test files - tempDir := t.TempDir() + tempDir, err := os.MkdirTemp("", "test-serve-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) // Create test files directly in the temp directory testFile := filepath.Join(tempDir, "index.html") - err := os.WriteFile(testFile, []byte("

Test Page

"), 0o644) + err = os.WriteFile(testFile, []byte("

Test Page

"), 0o644) require.NoError(t, err) var rules Rules @@ -517,7 +524,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) { `, tempDir), &rules) require.NoError(t, err) - handler := rules.BuildHandler(mockUpstream("should not be called")) + handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called")) // Test serving a file - serve command serves files relative to the root directory // The path /files/index.html gets mapped to tempDir + "/files/index.html" @@ -546,7 +553,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) { assert.Equal(t, http.StatusNotFound, w2.Code) } -func TestHTTPFlow_ProxyCommand(t *testing.T) { +func TestHTTPFlow_ProxyCommandYAML(t *testing.T) { // Create a mock upstream server upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Upstream-Header", "upstream-value") @@ -563,7 +570,7 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) { `, upstreamServer.URL), &rules) require.NoError(t, err) - handler := rules.BuildHandler(mockUpstream("should not be called")) + handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called")) req := httptest.NewRequest(http.MethodGet, "/api/test", nil) w := httptest.NewRecorder() @@ -576,11 +583,28 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) { assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header")) } -func TestHTTPFlow_NotifyCommand(t *testing.T) { - // TODO: +func TestHTTPFlow_NotifyCommandYAML(t *testing.T) { + upstream := mockUpstream(http.StatusOK, "ok") + + var rules Rules + err := parseRules(` +- name: notify-rule + on: path /notify + do: notify info test-provider "title $req_method" "body $req_url $status_code" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/notify", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "ok", w.Body.String()) } -func TestHTTPFlow_FormConditions(t *testing.T) { +func TestHTTPFlow_FormConditionsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("form processed")) @@ -620,7 +644,7 @@ func TestHTTPFlow_FormConditions(t *testing.T) { assert.Equal(t, "john@example.com", w2.Header().Get("X-Email")) } -func TestHTTPFlow_RemoteConditions(t *testing.T) { +func TestHTTPFlow_RemoteConditionsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("remote processed")) @@ -654,11 +678,11 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) { w2 := httptest.NewRecorder() handler.ServeHTTP(w2, req2) - assert.Equal(t, http.StatusForbidden, w2.Code) + assert.Equal(t, 403, w2.Code) assert.Equal(t, "Private network blocked\n", w2.Body.String()) } -func TestHTTPFlow_BasicAuthConditions(t *testing.T) { +func TestHTTPFlow_BasicAuthConditionsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("auth processed")) @@ -702,7 +726,7 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) { assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status")) } -func TestHTTPFlow_RouteConditions(t *testing.T) { +func TestHTTPFlow_RouteConditionsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("route processed")) @@ -742,10 +766,10 @@ func TestHTTPFlow_RouteConditions(t *testing.T) { assert.Equal(t, "frontend", w2.Header().Get("X-Route")) } -func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { +func TestHTTPFlow_ResponseStatusConditionsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusMethodNotAllowed) - w.Write([]byte("method not allowed")) + fmt.Fprint(w, "method not allowed") }) var rules Rules @@ -767,11 +791,11 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { assert.Equal(t, "error\n", w.Body.String()) } -func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { +func TestHTTPFlow_ResponseHeaderConditionsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Response-Header", "response header") w.WriteHeader(http.StatusOK) - w.Write([]byte("processed")) + fmt.Fprint(w, "processed") }) t.Run("any_value", func(t *testing.T) { @@ -831,7 +855,65 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { }) } -func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { +func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMatchersYAML(t *testing.T) { + upstreamCalled := false + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalled = true + w.WriteHeader(http.StatusOK) + w.Write([]byte("upstream")) + }) + + var rules Rules + err := parseRules(` +- on: path / + do: error 403 blocked +- on: path / + do: set resp_header X-Late should-not-run +- on: status 4xx + do: set resp_header X-Post true +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.False(t, upstreamCalled) + assert.Equal(t, 403, w.Code) + assert.Equal(t, "blocked\n", w.Body.String()) + assert.Equal(t, "should-not-run", w.Header().Get("X-Late")) + assert.Equal(t, "true", w.Header().Get("X-Post")) +} + +func TestHTTPFlow_PostRuleTermination_StopsRemainingCommandsInRuleYAML(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + + var rules Rules + err := parseRules(` +- on: status 200 + do: | + error 500 failed + set resp_header X-After should-not-run +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "failed\n", w.Body.String()) + assert.Empty(t, w.Header().Get("X-After")) +} + +func TestHTTPFlow_ComplexRuleCombinationsYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("complex processed")) @@ -887,12 +969,12 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { w3 := httptest.NewRecorder() handler.ServeHTTP(w3, req3) - assert.Equal(t, 200, w3.Code) + assert.Equal(t, http.StatusOK, w3.Code) assert.Equal(t, "public", w3.Header().Get("X-Access-Level")) assert.Empty(t, w3.Header()["X-API-Version"]) } -func TestHTTPFlow_ResponseModifier(t *testing.T) { +func TestHTTPFlow_ResponseModifierYAML(t *testing.T) { upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("original response")) diff --git a/internal/route/rules/io.go b/internal/route/rules/io.go index 6ce0e53c..67df0ac3 100644 --- a/internal/route/rules/io.go +++ b/internal/route/rules/io.go @@ -10,6 +10,7 @@ import ( "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/logging/accesslog" + gperr "github.com/yusing/goutils/errs" ) type noopWriteCloser struct { @@ -30,7 +31,7 @@ var ( testFilesLock sync.Mutex ) -func openFile(path string) (io.WriteCloser, error) { +func openFile(path string) (io.WriteCloser, gperr.Error) { switch path { case "/dev/stdout": return stdout, nil diff --git a/internal/route/rules/matcher.go b/internal/route/rules/matcher.go index 7c27d908..f29d0fab 100644 --- a/internal/route/rules/matcher.go +++ b/internal/route/rules/matcher.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/gobwas/glob" + "github.com/puzpuzpuz/xsync/v4" gperr "github.com/yusing/goutils/errs" ) @@ -13,6 +14,8 @@ type ( MatcherType string ) +var matcherCache = xsync.NewMap[string, Matcher]() // map[string]Matcher + const ( MatcherTypeString MatcherType = "string" MatcherTypeGlob MatcherType = "glob" @@ -59,7 +62,12 @@ func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Erro } func ParseMatcher(expr string) (Matcher, gperr.Error) { + if cached, ok := matcherCache.Load(expr); ok { + return cached, nil + } + negate := false + origExpr := expr if strings.HasPrefix(expr, "!") { negate = true expr = expr[1:] @@ -72,11 +80,23 @@ func ParseMatcher(expr string) (Matcher, gperr.Error) { switch t { case MatcherTypeString: - return StringMatcher(expr, negate) + m, err := StringMatcher(expr, negate) + if err == nil { + matcherCache.Store(origExpr, m) + } + return m, err case MatcherTypeGlob: - return GlobMatcher(expr, negate) + m, err := GlobMatcher(expr, negate) + if err == nil { + matcherCache.Store(origExpr, m) + } + return m, err case MatcherTypeRegex: - return RegexMatcher(expr, negate) + m, err := RegexMatcher(expr, negate) + if err == nil { + matcherCache.Store(origExpr, m) + } + return m, err } // won't reach here return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t) diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 79330268..e050c4e7 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -12,19 +12,19 @@ import ( ) type RuleOn struct { - raw string - checker Checker - isResponseChecker bool -} - -func (on *RuleOn) IsResponseChecker() bool { - return on.isResponseChecker + raw string + checker Checker + phase PhaseFlag } func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool { - return on.checker.Check(w, r) + if on.checker == nil { + return true + } + return on.checker.Check(httputils.GetInitResponseModifier(w), r) } +// on request const ( OnDefault = "default" OnHeader = "header" @@ -39,35 +39,36 @@ const ( OnRemote = "remote" OnBasicAuth = "basic_auth" OnRoute = "route" +) - // on response - +// on response +const ( OnResponseHeader = "resp_header" OnStatus = "status" ) var checkers = map[string]struct { - help Help - validate ValidateFunc - builder func(args any) CheckFunc - isResponseChecker bool + help Help + validate ValidateFunc + builder func(args any) CheckFunc }{ OnDefault: { help: Help{ command: OnDefault, description: makeLines( - "The default rule is matched when no other rules are matched.", + "Select the default (baseline) rule.", ), args: map[string]string{}, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 0 { - return nil, ErrExpectNoArg + return phase, nil, ErrExpectNoArg } - //nolint:nilnil - return nil, nil + return phase, nil, nil }, - builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called + builder: func(args any) CheckFunc { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return false } + }, // this should never be called }, OnHeader: { help: Help{ @@ -83,21 +84,23 @@ var checkers = map[string]struct { "[value]": "the header value", }, }, - validate: toKVOptionalVMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = toKVOptionalVMatcher(args) + return + }, builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return len(r.Header[k]) > 0 } } - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return slices.ContainsFunc(r.Header[k], matcher) } }, }, OnResponseHeader: { - isResponseChecker: true, help: Help{ command: OnResponseHeader, description: makeLines( @@ -111,16 +114,20 @@ var checkers = map[string]struct { "[value]": "the response header value", }, }, - validate: toKVOptionalVMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePost + parsedArgs, err = toKVOptionalVMatcher(args) + return + }, builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(w http.ResponseWriter, r *http.Request) bool { - return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0 + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return len(w.Header()[k]) > 0 } } - return func(w http.ResponseWriter, r *http.Request) bool { - return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return slices.ContainsFunc(w.Header()[k], matcher) } }, }, @@ -138,16 +145,19 @@ var checkers = map[string]struct { "[value]": "the query value", }, }, - validate: toKVOptionalVMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = toKVOptionalVMatcher(args) + return + }, builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(w http.ResponseWriter, r *http.Request) bool { - return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0 + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return len(w.SharedData().GetQueries(r)[k]) > 0 } } - return func(w http.ResponseWriter, r *http.Request) bool { - return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher) } }, }, @@ -165,12 +175,15 @@ var checkers = map[string]struct { "[value]": "the cookie value", }, }, - validate: toKVOptionalVMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = toKVOptionalVMatcher(args) + return + }, builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(w http.ResponseWriter, r *http.Request) bool { - cookies := httputils.GetSharedData(w).GetCookies(r) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + cookies := w.SharedData().GetCookies(r) for _, cookie := range cookies { if cookie.Name == k { return true @@ -179,8 +192,8 @@ var checkers = map[string]struct { return false } } - return func(w http.ResponseWriter, r *http.Request) bool { - cookies := httputils.GetSharedData(w).GetCookies(r) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + cookies := w.SharedData().GetCookies(r) for _, cookie := range cookies { if cookie.Name == k { if matcher(cookie.Value) { @@ -192,6 +205,7 @@ var checkers = map[string]struct { } }, }, + //nolint:dupl OnForm: { help: Help{ command: OnForm, @@ -206,15 +220,18 @@ var checkers = map[string]struct { "[value]": "the form value", }, }, - validate: toKVOptionalVMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = toKVOptionalVMatcher(args) + return + }, builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return r.FormValue(k) != "" } } - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return matcher(r.FormValue(k)) } }, @@ -233,15 +250,18 @@ var checkers = map[string]struct { "[value]": "the form value", }, }, - validate: toKVOptionalVMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = toKVOptionalVMatcher(args) + return + }, builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return r.PostFormValue(k) != "" } } - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return matcher(r.PostFormValue(k)) } }, @@ -250,32 +270,46 @@ var checkers = map[string]struct { help: Help{ command: OnProto, args: map[string]string{ - "proto": "the http protocol (http, https, h3)", + "proto": "the http protocol (http, https, h1, h2, h2c, h3)", }, }, - validate: func(args []string) (any, error) { + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { if len(args) != 1 { - return nil, ErrExpectOneArg + return phase, nil, ErrExpectOneArg } proto := args[0] - if proto != "http" && proto != "https" && proto != "h3" { - return nil, ErrInvalidArguments.Withf("proto: %q", proto) + switch proto { + case "http", "https", "h1", "h2", "h2c", "h3": + return phase, proto, nil + default: + return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto) } - return proto, nil }, builder: func(args any) CheckFunc { proto := args.(string) switch proto { case "http": - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return r.TLS == nil } case "https": - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return r.TLS != nil } + case "h1": + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return r.TLS == nil && r.ProtoMajor == 1 + } + case "h2": + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return r.TLS != nil && r.ProtoMajor == 2 + } + case "h2c": + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return r.TLS == nil && r.ProtoMajor == 2 + } default: // h3 - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return r.TLS != nil && r.ProtoMajor == 3 } } @@ -288,10 +322,13 @@ var checkers = map[string]struct { "method": "the http method", }, }, - validate: validateMethod, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = validateMethod(args) + return + }, builder: func(args any) CheckFunc { method := args.(string) - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return r.Method == method } }, @@ -310,10 +347,13 @@ var checkers = map[string]struct { "host": "the host name", }, }, - validate: validateSingleMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = validateSingleMatcher(args) + return + }, builder: func(args any) CheckFunc { matcher := args.(Matcher) - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return matcher(r.Host) } }, @@ -332,10 +372,13 @@ var checkers = map[string]struct { "path": "the request path", }, }, - validate: validateURLPathMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = validateURLPathMatcher(args) + return + }, builder: func(args any) CheckFunc { matcher := args.(Matcher) - return func(w http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { reqPath := r.URL.Path if len(reqPath) > 0 && reqPath[0] != '/' { reqPath = "/" + reqPath @@ -351,22 +394,25 @@ var checkers = map[string]struct { "ip|cidr": "the remote ip or cidr", }, }, - validate: validateCIDR, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = validateCIDR(args) + return + }, builder: func(args any) CheckFunc { ipnet := args.(*net.IPNet) // for /32 (IPv4) or /128 (IPv6), just compare the IP if ones, bits := ipnet.Mask.Size(); ones == bits { wantIP := ipnet.IP - return func(w http.ResponseWriter, r *http.Request) bool { - ip := httputils.GetSharedData(w).GetRemoteIP(r) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + ip := w.SharedData().GetRemoteIP(r) if ip == nil { return false } return ip.Equal(wantIP) } } - return func(w http.ResponseWriter, r *http.Request) bool { - ip := httputils.GetSharedData(w).GetRemoteIP(r) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + ip := w.SharedData().GetRemoteIP(r) if ip == nil { return false } @@ -382,11 +428,14 @@ var checkers = map[string]struct { "password": "the password encrypted with bcrypt", }, }, - validate: validateUserBCryptPassword, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = validateUserBCryptPassword(args) + return + }, builder: func(args any) CheckFunc { cred := args.(*HashedCrendentials) - return func(w http.ResponseWriter, r *http.Request) bool { - return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r)) + return func(w *httputils.ResponseModifier, r *http.Request) bool { + return cred.Match(w.SharedData().GetBasicAuth(r)) } }, }, @@ -403,16 +452,18 @@ var checkers = map[string]struct { "route": "the route name", }, }, - validate: validateSingleMatcher, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + parsedArgs, err = validateSingleMatcher(args) + return + }, builder: func(args any) CheckFunc { matcher := args.(Matcher) - return func(_ http.ResponseWriter, r *http.Request) bool { + return func(w *httputils.ResponseModifier, r *http.Request) bool { return matcher(routes.TryGetUpstreamName(r)) } }, }, OnStatus: { - isResponseChecker: true, help: Help{ command: OnStatus, description: makeLines( @@ -429,16 +480,20 @@ var checkers = map[string]struct { "status": "the status code range", }, }, - validate: validateStatusRange, + validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) { + phase = PhasePost + parsedArgs, err = validateStatusRange(args) + return + }, builder: func(args any) CheckFunc { beg, end := args.(*IntTuple).Unpack() if beg == end { - return func(w http.ResponseWriter, _ *http.Request) bool { - return httputils.GetInitResponseModifier(w).StatusCode() == beg + return func(w *httputils.ResponseModifier, _ *http.Request) bool { + return w.StatusCode() == beg } } - return func(w http.ResponseWriter, _ *http.Request) bool { - statusCode := httputils.GetInitResponseModifier(w).StatusCode() + return func(w *httputils.ResponseModifier, _ *http.Request) bool { + statusCode := w.StatusCode() return statusCode >= beg && statusCode <= end } }, @@ -515,85 +570,119 @@ func splitPipe(s string) []string { return []string{} } - var result []string - var current strings.Builder - escaped := false - quote := rune(0) + result := make([]string, 0, 2) + quote := byte(0) brackets := 0 + start := 0 - for _, r := range s { - if escaped { - current.WriteRune(r) - escaped = false - continue - } - - switch r { + for i := 0; i < len(s); i++ { + switch s[i] { case '\\': - escaped = true - current.WriteRune(r) + // Skip escaped character. + if i+1 < len(s) { + i++ + } case '"', '\'', '`': if quote == 0 && brackets == 0 { - quote = r - } else if r == quote { + quote = s[i] + } else if s[i] == quote { quote = 0 } - current.WriteRune(r) case '(': brackets++ - current.WriteRune(r) case ')': if brackets > 0 { brackets-- } - current.WriteRune(r) case '|': if quote == 0 && brackets == 0 { - // Found a pipe outside quotes/brackets, split here - result = append(result, strings.TrimSpace(current.String())) - current.Reset() - } else { - current.WriteRune(r) + result = append(result, strings.TrimSpace(s[start:i])) + start = i + 1 } - default: - current.WriteRune(r) } } - // Add the last part - if current.Len() > 0 { - result = append(result, strings.TrimSpace(current.String())) + // drop trailing empty part. + if start < len(s) { + result = append(result, strings.TrimSpace(s[start:])) } return result } +func forEachAndPart(s string, fn func(part string)) { + start := 0 + for i := 0; i <= len(s); i++ { + if i < len(s) && andSeps[s[i]] == 0 { + continue + } + part := strings.TrimSpace(s[start:i]) + if part != "" { + fn(part) + } + start = i + 1 + } +} + +func forEachPipePart(s string, fn func(part string)) { + quote := byte(0) + brackets := 0 + start := 0 + + for i := 0; i < len(s); i++ { + switch s[i] { + case '\\': + if i+1 < len(s) { + i++ + } + case '"', '\'', '`': + if quote == 0 && brackets == 0 { + quote = s[i] + } else if s[i] == quote { + quote = 0 + } + case '(': + brackets++ + case ')': + if brackets > 0 { + brackets-- + } + case '|': + if quote == 0 && brackets == 0 { + fn(strings.TrimSpace(s[start:i])) + start = i + 1 + } + } + } + if start < len(s) { + fn(strings.TrimSpace(s[start:])) + } +} + // Parse implements strutils.Parser. func (on *RuleOn) Parse(v string) error { on.raw = v - rules := splitAnd(v) - checkAnd := make(CheckMatchAll, 0, len(rules)) + ruleCount := 0 + forEachAndPart(v, func(_ string) { + ruleCount++ + }) + checkAnd := make(CheckMatchAll, 0, ruleCount) errs := gperr.NewBuilder("rule.on syntax errors") - isResponseChecker := false - for i, rule := range rules { - if rule == "" { - continue - } - parsed, isResp, err := parseOn(rule) + i := 0 + forEachAndPart(v, func(rule string) { + i++ + parsed, phase, err := parseOn(rule) if err != nil { - errs.AddSubjectf(err, "line %d", i+1) - continue - } - if isResp { - isResponseChecker = true + errs.AddSubjectf(err, "line %d", i) + return } + on.phase |= phase checkAnd = append(checkAnd, parsed) - } + }) on.checker = checkAnd - on.isResponseChecker = isResponseChecker return errs.Error() } @@ -605,33 +694,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) { return []byte(on.String()), nil } -func parseOn(line string) (Checker, bool, error) { - ors := splitPipe(line) - - if len(ors) > 1 { +func parseOn(line string) (Checker, PhaseFlag, error) { + orCount := 0 + forEachPipePart(line, func(_ string) { + orCount++ + }) + if orCount > 1 { + var phase PhaseFlag errs := gperr.NewBuilder("rule.on syntax errors") - checkOr := make(CheckMatchSingle, len(ors)) - isResponseChecker := false - for i, or := range ors { - curCheckers, isResp, err := parseOn(or) + checkOr := make(CheckMatchSingle, orCount) + i := 0 + forEachPipePart(line, func(or string) { + i++ + checkFunc, req, err := parseOnAtom(or) if err != nil { - errs.Add(err) - continue + errs.AddSubjectf(err, "or[%d]", i) + return } - if isResp { - isResponseChecker = true - } - checkOr[i] = curCheckers.(CheckFunc) - } + checkOr[i-1] = checkFunc + phase |= req + }) if err := errs.Error(); err != nil { - return nil, false, err + return nil, phase, err } - return checkOr, isResponseChecker, nil + return checkOr, phase, nil } + return parseOnAtom(line) +} + +func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) { + var phase PhaseFlag subject, args, err := parse(line) if err != nil { - return nil, false, err + return nil, phase, err } negate := false @@ -642,20 +738,21 @@ func parseOn(line string) (Checker, bool, error) { checker, ok := checkers[subject] if !ok { - return nil, false, ErrInvalidOnTarget.Subject(subject) + return nil, phase, ErrInvalidOnTarget.Subject(subject) } - validArgs, err := checker.validate(args) + req, validArgs, err := checker.validate(args) if err != nil { - return nil, false, gperr.Wrap(err).With(checker.help.Error()) + return nil, phase, gperr.Wrap(err).With(checker.help.Error()) } + phase |= req checkFunc := checker.builder(validArgs) if negate { origCheckFunc := checkFunc - checkFunc = func(w http.ResponseWriter, r *http.Request) bool { + checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool { return !origCheckFunc(w, r) } } - return checkFunc, checker.isResponseChecker, nil + return checkFunc, phase, nil } diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index 7b1492b3..99c19731 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -12,6 +12,7 @@ import ( "github.com/yusing/godoxy/internal/route" "github.com/yusing/godoxy/internal/route/routes" . "github.com/yusing/godoxy/internal/route/rules" + httputils "github.com/yusing/goutils/http" expect "github.com/yusing/goutils/testing" "golang.org/x/crypto/bcrypt" ) @@ -386,7 +387,7 @@ func TestOnCorrectness(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - w := httptest.NewRecorder() + w := httputils.NewResponseModifier(httptest.NewRecorder()) var on RuleOn err := on.Parse(tt.checker) expect.NoError(t, err) diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go index 174babe6..54e0ffb7 100644 --- a/internal/route/rules/parser.go +++ b/internal/route/rules/parser.go @@ -1,8 +1,7 @@ package rules import ( - "bytes" - "fmt" + "strings" "unicode" "github.com/yusing/goutils/env" @@ -25,6 +24,76 @@ var quoteChars = [256]bool{ '`': true, } +func parseSimple(v string) (subject string, args []string, err error, ok bool) { + brackets := 0 + for i := range len(v) { + switch v[i] { + case '\\', '$', '"', '\'', '`', '\t', '\r', '\n': + return "", nil, nil, false + case '(': + brackets++ + case ')': + if brackets == 0 { + return "", nil, ErrUnterminatedBrackets, true + } + brackets-- + } + } + if brackets != 0 { + return "", nil, ErrUnterminatedBrackets, true + } + + i := 0 + for i < len(v) && v[i] == ' ' { + i++ + } + if i >= len(v) { + return "", nil, nil, true + } + + start := i + for i < len(v) && v[i] != ' ' { + i++ + } + subject = v[start:i] + + if i >= len(v) { + return subject, nil, nil, true + } + + argCount := 0 + for j := i; j < len(v); { + for j < len(v) && v[j] == ' ' { + j++ + } + if j >= len(v) { + break + } + argCount++ + for j < len(v) && v[j] != ' ' { + j++ + } + } + if argCount == 0 { + return subject, nil, nil, true + } + args = make([]string, 0, argCount) + for i < len(v) { + for i < len(v) && v[i] == ' ' { + i++ + } + if i >= len(v) { + break + } + start = i + for i < len(v) && v[i] != ' ' { + i++ + } + args = append(args, v[start:i]) + } + return subject, args, nil, true +} + // parse expression to subject and args // with support for quotes, escaped chars, and env substitution, e.g. // @@ -32,14 +101,21 @@ var quoteChars = [256]bool{ // error 403 Forbidden\ \"foo\"\ \"bar\". // error 403 "Message: ${CLOUDFLARE_API_KEY}" func parse(v string) (subject string, args []string, err error) { - buf := bytes.NewBuffer(make([]byte, 0, len(v))) + if subject, args, err, ok := parseSimple(v); ok { + return subject, args, err + } + + buf := getStringBuffer(len(v)) + args = make([]string, 0, 4) escaped := false quote := rune(0) brackets := 0 - var envVar bytes.Buffer - var missingEnvVars []string + var ( + envVar strings.Builder + missingEnvVars []string + ) inEnvVar := false expectingBrace := false @@ -71,7 +147,8 @@ func parse(v string) (subject string, args []string, err error) { if ch, ok := escapedChars[r]; ok { buf.WriteRune(ch) } else { - fmt.Fprintf(buf, `\%c`, r) + buf.WriteRune('\\') + buf.WriteRune(r) } escaped = false continue diff --git a/internal/route/rules/parser_test.go b/internal/route/rules/parser_test.go index c8451c86..9249fa7c 100644 --- a/internal/route/rules/parser_test.go +++ b/internal/route/rules/parser_test.go @@ -4,6 +4,7 @@ import ( "strconv" "testing" + gperr "github.com/yusing/goutils/errs" expect "github.com/yusing/goutils/testing" ) @@ -13,6 +14,7 @@ func TestParser(t *testing.T) { input string subject string args []string + wantErr gperr.Error }{ { name: "basic", @@ -90,6 +92,10 @@ func TestParser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { subject, args, err := parse(tt.input) + if tt.wantErr != nil { + expect.ErrorIs(t, tt.wantErr, err) + return + } // t.Log(subject, args, err) expect.NoError(t, err) expect.Equal(t, subject, tt.subject) diff --git a/internal/route/rules/phase.go b/internal/route/rules/phase.go new file mode 100644 index 00000000..d9bd7ae1 --- /dev/null +++ b/internal/route/rules/phase.go @@ -0,0 +1,29 @@ +package rules + +import "strings" + +type PhaseFlag uint8 + +const ( + PhaseNone PhaseFlag = 0 + PhasePre PhaseFlag = 1 << (iota - 1) + PhasePost +) + +func (phase PhaseFlag) IsPostRule() bool { + return phase&PhasePost != 0 +} + +func (phase PhaseFlag) String() string { + if phase == PhaseNone { + return "none" + } + var flags []string + if phase&PhasePre != 0 { + flags = append(flags, "PhasePre") + } + if phase&PhasePost != 0 { + flags = append(flags, "PhasePost") + } + return strings.Join(flags, ",") +} diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index 01b4640d..04cf8280 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -4,9 +4,16 @@ import ( "errors" "fmt" "net/http" + "reflect" + "slices" + "strings" + "unicode" + "github.com/goccy/go-yaml" "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog/log" + "github.com/yusing/godoxy/internal/serialization" + gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" "golang.org/x/net/http2" @@ -15,37 +22,36 @@ import ( type ( /* - Example: + Rules is a list of rules. - proxy.app1.rules: | - - name: default - do: | - rewrite / /index.html - serve /var/www/goaccess - - name: ws - on: | - header Connection Upgrade - header Upgrade websocket - do: bypass + Example: - proxy.app2.rules: | - - name: default - do: bypass - - name: block POST and PUT - on: method POST | method PUT - do: error 403 Forbidden + proxy.app1.rules: | + - name: default + do: | + rewrite / /index.html + serve /var/www/goaccess + - name: ws + on: | + header Connection Upgrade + header Upgrade websocket + do: bypass + + proxy.app2.rules: | + - name: default + do: bypass + - name: block POST and PUT + on: method POST | method PUT + do: error 403 Forbidden */ + //nolint:recvcheck Rules []Rule - /* - Rule is a rule for a reverse proxy. - It do `Do` when `On` matches. - - A rule can have multiple lines of on. - - All lines of on must match, - but each line can have multiple checks that - one match means this line is matched. - */ + // Rule represents a reverse proxy rule. + // The `Do` field is executed when `On` matches. + // + // - A rule may have multiple lines in the `On` section. + // - All `On` lines must match for the rule to trigger. + // - Each line can have several checks—one match per line is enough for that line. Rule struct { Name string `json:"name"` On RuleOn `json:"on" swaggertype:"string"` @@ -53,103 +59,230 @@ type ( } ) -func (rule *Rule) IsResponseRule() bool { - return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler() +func isDefaultRule(rule Rule) bool { + return rule.Name == "default" || rule.On.raw == OnDefault } -func (rules Rules) Validate() error { +func (rules Rules) Validate() gperr.Error { var defaultRulesFound []int - for i, rule := range rules { - if rule.Name == "default" || rule.On.raw == OnDefault { + for i := range rules { + rule := rules[i] + if isDefaultRule(rule) { defaultRulesFound = append(defaultRulesFound, i) } + if rules[i].Name == "" { + // set name to index if name is empty + rules[i].Name = fmt.Sprintf("rule[%d]", i) + } } if len(defaultRulesFound) > 1 { return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound)) } + for i := range rules { + r1 := rules[i] + if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() { + continue + } + sig1, ok := matcherSignature(r1.On.raw) + if !ok { + continue + } + for j := i + 1; j < len(rules); j++ { + r2 := rules[j] + if isDefaultRule(r2) || r2.On.phase.IsPostRule() { + continue + } + sig2, ok := matcherSignature(r2.On.raw) + if !ok || sig1 != sig2 { + continue + } + return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j) + } + } return nil } +func (rule Rule) doesTerminateInPre() bool { + for _, cmd := range rule.Do.pre { + handler, ok := cmd.(Handler) + if !ok { + continue + } + if handler.Terminates() { + return true + } + } + return false +} + +func matcherSignature(raw string) (string, bool) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", false + } + + andParts := splitAnd(raw) + if len(andParts) == 0 { + return "", false + } + + canonAnd := make([]string, 0, len(andParts)) + for _, andPart := range andParts { + orParts := splitPipe(andPart) + if len(orParts) == 0 { + continue + } + canonOr := make([]string, 0, len(orParts)) + for _, atom := range orParts { + subject, args, err := parse(strings.TrimSpace(atom)) + if err != nil || subject == "" { + return "", false + } + canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00")) + } + slices.Sort(canonOr) + canonOr = slices.Compact(canonOr) + canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")") + } + + slices.Sort(canonAnd) + canonAnd = slices.Compact(canonAnd) + if len(canonAnd) == 0 { + return "", false + } + return strings.Join(canonAnd, "&"), true +} + +// Parse parses a rule configuration string. +// It first tries the block syntax (if the string contains a top-level '{'), +// then falls back to YAML syntax. +func (rules *Rules) Parse(config string) error { + config = strings.TrimSpace(config) + if config == "" { + return nil + } + + // Prefer block syntax if it looks like block syntax. + if hasTopLevelLBrace(config) { + blockRules, err := parseBlockRules(config) + if err == nil { + *rules = blockRules + return nil + } + // Fall through to YAML (backward compatibility). + } + + // YAML fallback + var anySlice []any + yamlErr := yaml.Unmarshal([]byte(config), &anySlice) + if yamlErr == nil { + return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false) + } + + // If YAML fails and we didn't try block syntax yet, try it now. + blockRules, err := parseBlockRules(config) + if err == nil { + *rules = blockRules + return nil + } + return err +} + +// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments. +// Used to decide whether to prioritize the block syntax. +func hasTopLevelLBrace(s string) bool { + quote := rune(0) + inLine := false + inBlock := false + + for i := 0; i < len(s); i++ { + c := s[i] + + if inLine { + if c == '\n' { + inLine = false + } + continue + } + if inBlock { + if c == '*' && i+1 < len(s) && s[i+1] == '/' { + inBlock = false + i++ + } + continue + } + + if quote != 0 { + if quote != '`' && c == '\\' && i+1 < len(s) { + i++ + continue + } + if rune(c) == quote { + quote = 0 + } + continue + } + + switch c { + case '\'', '"', '`': + quote = rune(c) + continue + case '{': + return true + case '#': + inLine = true + continue + case '/': + if i+1 < len(s) && s[i+1] == '/' { + inLine = true + i++ + continue + } + if i+1 < len(s) && s[i+1] == '*' { + inBlock = true + i++ + continue + } + default: + if unicode.IsSpace(rune(c)) { + continue + } + } + } + return false +} + // BuildHandler returns a http.HandlerFunc that implements the rules. func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { if len(rules) == 0 { return up } - defaultRule := Rule{ - Name: "default", - Do: Command{ - raw: "pass", - exec: BypassCommand{}, - }, - } + var defaultRule *Rule var nonDefaultRules Rules - hasDefaultRule := false - for i, rule := range rules { - if rule.Name == "default" || rule.On.raw == OnDefault { - defaultRule = rule - hasDefaultRule = true + for _, rule := range rules { + if isDefaultRule(rule) { + r := rule + defaultRule = &r } else { - // set name to index if name is empty - if rule.Name == "" { - rule.Name = fmt.Sprintf("rule[%d]", i) - } nonDefaultRules = append(nonDefaultRules, rule) } } if len(nonDefaultRules) == 0 { - if defaultRule.Do.isBypass() { + if defaultRule == nil || defaultRule.Do.raw == CommandUpstream { return up } - if defaultRule.IsResponseRule() { - return func(w http.ResponseWriter, r *http.Request) { - rm := httputils.NewResponseModifier(w) - defer func() { - if _, err := rm.FlushRelease(); err != nil { - logError(err, r) - } - }() - w = rm - up(w, r) - err := defaultRule.Do.exec.Handle(w, r) - if err != nil && !errors.Is(err, errTerminated) { - appendRuleError(rm, &defaultRule, err) - } - } - } - return func(w http.ResponseWriter, r *http.Request) { - rm := httputils.NewResponseModifier(w) - defer func() { - if _, err := rm.FlushRelease(); err != nil { - logError(err, r) - } - }() - w = rm - err := defaultRule.Do.exec.Handle(w, r) - if err == nil { - up(w, r) - return - } - if !errors.Is(err, errTerminated) { - appendRuleError(rm, &defaultRule, err) - } - } } - preRules := make(Rules, 0, len(nonDefaultRules)) - postRules := make(Rules, 0, len(nonDefaultRules)) - for _, rule := range nonDefaultRules { - if rule.IsResponseRule() { - postRules = append(postRules, rule) - } else { - preRules = append(preRules, rule) - } + execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error { + return cmd.pre.ServeHTTP(w, r, up) } - isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule() - defaultTerminates := isTerminatingHandler(defaultRule.Do.exec) + execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error { + return cmd.post.ServeHTTP(w, r, up) + } return func(w http.ResponseWriter, r *http.Request) { rm := httputils.NewResponseModifier(w) @@ -159,104 +292,84 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { } }() - w = rm + var hasError bool - shouldCallUpstream := true - preMatched := false + preRules := make(Rules, 0, len(nonDefaultRules)+1) + if defaultRule != nil { + preRules = append(preRules, *defaultRule) + } + preRules = append(preRules, nonDefaultRules...) - if hasDefaultRule && !isDefaultRulePost && !defaultTerminates { - if defaultRule.Do.isBypass() { - // continue to upstream - } else { - err := defaultRule.Handle(w, r) - if err != nil { - if !errors.Is(err, errTerminated) { - appendRuleError(rm, &defaultRule, err) - } - shouldCallUpstream = false + executedPre := make([]bool, len(preRules)) + terminatedInPre := make([]bool, len(preRules)) + preTerminated := false + for i, rule := range preRules { + if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) { + continue + } + if preTerminated { + // Preserve post-only commands (e.g. logging) even after + // pre-phase termination. + if len(rule.Do.pre) == 0 { + executedPre[i] = true } + continue } - } - if shouldCallUpstream { - for _, rule := range preRules { - if rule.Check(w, r) { - preMatched = true - if rule.Do.isBypass() { - break // post rules should still execute - } - err := rule.Handle(w, r) - if err != nil { - if !errors.Is(err, errTerminated) { - appendRuleError(rm, &rule, err) - } - shouldCallUpstream = false - break - } + executedPre[i] = true + if err := execPreCommand(rule.Do, rm, r); err != nil { + if errors.Is(err, errTerminateRule) { + terminatedInPre[i] = true + preTerminated = true + continue } + logError(err, r) + hasError = true } } - if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched { - if defaultRule.Do.isBypass() { - // continue to upstream - } else { - err := defaultRule.Handle(w, r) - if err != nil { - if !errors.Is(err, errTerminated) { - appendRuleError(rm, &defaultRule, err) - return - } - shouldCallUpstream = false + if !rm.HasStatus() { + if hasError { + http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } else { // call upstream if no WriteHeader or Write was called and no error occurred + up(rm, r) + } + } + + // Run post commands for rules that actually executed in pre phase, + // unless that same rule terminated in pre phase. + for i, rule := range preRules { + if !executedPre[i] || terminatedInPre[i] { + continue + } + if err := execPostCommand(rule.Do, rm, r); err != nil { + if errors.Is(err, errTerminateRule) { + continue } + logError(err, r) } } - if shouldCallUpstream { - up(w, r) - } - - // if no post rules, we are done here - if len(postRules) == 0 && !isDefaultRulePost { - return - } - - for _, rule := range postRules { - if rule.Check(w, r) { - err := rule.Handle(w, r) - if err != nil { - if !errors.Is(err, errTerminated) { - appendRuleError(rm, &rule, err) - } - return + // Run true post-matcher rules after response is available. + for _, rule := range nonDefaultRules { + if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) { + continue + } + // Post-rule matchers are only evaluated after upstream, so commands parsed + // as "pre" for requirement purposes still need to run in this phase. + if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil { + if errors.Is(err, errTerminateRule) { + continue } + logError(err, r) + } + if err := execPostCommand(rule.Do, rm, r); err != nil { + if errors.Is(err, errTerminateRule) { + continue + } + logError(err, r) } } - - if isDefaultRulePost { - err := defaultRule.Handle(w, r) - if err != nil && !errors.Is(err, errTerminated) { - appendRuleError(rm, &defaultRule, err) - } - } - } -} - -func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) { - // rm.AppendError("rule: %s, error: %w", rule.Name, err) -} - -func isTerminatingHandler(handler CommandHandler) bool { - switch h := handler.(type) { - case TerminatingCommand: - return true - case Commands: - if len(h) == 0 { - return false - } - return isTerminatingHandler(h[len(h)-1]) - default: - return false } } @@ -264,34 +377,30 @@ func (rule *Rule) String() string { return rule.Name } -func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool { +func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool { if rule.On.checker == nil { return true } - v := rule.On.checker.Check(w, r) - return v -} - -func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error { - return rule.Do.exec.Handle(w, r) + return rule.On.Check(w, r) } //go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed var errStreamClosed error +//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected +var errClientDisconnected error + func logError(err error, r *http.Request) { - if errors.Is(err, errStreamClosed) { + if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) { return } - var h2Err http2.StreamError - if errors.As(err, &h2Err) { + if h2Err, ok := errors.AsType[http2.StreamError](err); ok { // ignore these errors if h2Err.Code == http2.ErrCodeStreamClosed { return } } - var h3Err *http3.Error - if errors.As(err, &h3Err) { + if h3Err, ok := errors.AsType[*http3.Error](err); ok { // ignore these errors switch h3Err.ErrorCode { case diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go index 7a4b2e81..fc590e0d 100644 --- a/internal/route/rules/rules_test.go +++ b/internal/route/rules/rules_test.go @@ -19,28 +19,71 @@ func TestRulesValidate(t *testing.T) { { name: "no default rule", rules: ` -- name: rule1 - on: header Host example.com - do: pass - `, +header Host example.com { + pass +}`, }, { name: "multiple default rules", rules: ` -- name: default - do: pass -- name: rule1 - on: default - do: pass - `, +default { + pass +} + +default { + pass +}`, want: ErrMultipleDefaultRules, }, + { + name: "multiple responses on same condition", + rules: ` +header Host example.com { + error 404 "not found" +} + +header Host example.com { + error 403 "forbidden" +} +`, + want: ErrDeadRule, + }, + { + name: "same condition different formatting error then proxy", + rules: ` +header Host example.com & method GET { + error 404 "not found" +} + +method GET +header Host example.com { + proxy http://127.0.0.1:8080 +} +`, + want: ErrDeadRule, + }, + { + name: "same condition with non terminating first rule", + rules: ` +header Host example.com { + set resp_header X-Test first +} + +header Host example.com { + error 403 "forbidden" +} +`, + want: nil, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var rules Rules convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules)) require.True(t, convertible) + require.NoError(t, err) + + err = rules.Validate() if tt.want == nil { assert.NoError(t, err) @@ -50,3 +93,38 @@ func TestRulesValidate(t *testing.T) { }) } } + +func TestHasTopLevelLBrace(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + { + name: "escaped quote inside double quoted string", + in: `"test\"more{"`, + want: false, + }, + { + name: "escaped quote inside single quoted string", + in: "'test\\'more{'", + want: false, + }, + { + name: "top-level brace outside quoted string", + in: `"test\"more" {`, + want: true, + }, + { + name: "backtick keeps existing behavior", + in: "`test\\`more{`", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, hasTopLevelLBrace(tt.in)) + }) + } +} diff --git a/internal/route/rules/scanner.go b/internal/route/rules/scanner.go new file mode 100644 index 00000000..4dc5e7a9 --- /dev/null +++ b/internal/route/rules/scanner.go @@ -0,0 +1,273 @@ +package rules + +import ( + "strings" + "unicode" + + gperr "github.com/yusing/goutils/errs" +) + +// Tokenizer provides utilities for parsing rule syntax with proper handling +// of quotes, comments, and env vars. +// +// This is intentionally reusable by both the top-level rule block parser and +// the nested do-block parser. +type Tokenizer struct { + src string + length int +} + +// newTokenizer creates a tokenizer for the given source. +func newTokenizer(src string) Tokenizer { + return Tokenizer{src: src, length: len(src)} +} + +// skipComments skips whitespace, line comments, and block comments. +// It returns the new position and an error if a block comment is unterminated. +func (t *Tokenizer) skipComments(pos int, atLineStart bool, prevIsSpace bool) (int, gperr.Error) { + for pos < t.length { + c := t.src[pos] + + // Skip whitespace + if unicode.IsSpace(rune(c)) { + pos++ + atLineStart = false + prevIsSpace = true + continue + } + + // Check for line comment: // or # + if c == '/' { + if pos+1 < t.length && t.src[pos+1] == '/' { + // Skip to end of line + for pos < t.length && t.src[pos] != '\n' { + pos++ + } + atLineStart = true + prevIsSpace = true + continue + } + } + if c == '#' && (atLineStart || prevIsSpace) { + // Skip to end of line + for pos < t.length && t.src[pos] != '\n' { + pos++ + } + atLineStart = true + prevIsSpace = true + continue + } + + // Check for block comment: /* + if c == '/' && pos+1 < t.length && t.src[pos+1] == '*' { + pos += 2 + closed := false + for pos+1 < t.length { + if t.src[pos] == '*' && t.src[pos+1] == '/' { + pos += 2 + closed = true + break + } + pos++ + } + if !closed { + return 0, ErrInvalidBlockSyntax.Withf("unterminated block comment") + } + atLineStart = false + prevIsSpace = true + continue + } + + break + } + + return pos, nil +} + +// scanToBrace scans from pos until it finds '{' outside quotes, or returns an error. +func (t *Tokenizer) scanToBrace(pos int) (int, gperr.Error) { + quote := rune(0) + for pos < t.length { + c := rune(t.src[pos]) + if quote != 0 { + if c == quote { + quote = 0 + } + pos++ + continue + } + if c == '"' || c == '\'' || c == '`' { + quote = c + pos++ + continue + } + if c == '{' { + return pos, nil + } + if c == '}' { + return 0, ErrInvalidBlockSyntax.Withf("unmatched '}' in block header") + } + pos++ + } + return 0, ErrInvalidBlockSyntax.Withf("expected '{' after block header") +} + +// findMatchingBrace finds the matching '}' for a '{' starting at startPos. +// It respects quotes/backticks and ${...} env vars. +func (t *Tokenizer) findMatchingBrace(startPos int) (int, gperr.Error) { + pos := startPos + braceDepth := 1 + quote := rune(0) + inLine := false + inBlock := false + atLineStart := true + prevIsSpace := true + + for pos < t.length { + c := rune(t.src[pos]) + + if inLine { + if c == '\n' { + inLine = false + atLineStart = true + prevIsSpace = true + } + pos++ + continue + } + if inBlock { + if c == '*' && pos+1 < t.length && t.src[pos+1] == '/' { + pos += 2 + inBlock = false + continue + } + if c == '\n' { + atLineStart = true + prevIsSpace = true + } + pos++ + continue + } + + if quote != 0 { + if c == quote { + quote = 0 + } + if c == '\n' { + atLineStart = true + prevIsSpace = true + } else { + atLineStart = false + prevIsSpace = unicode.IsSpace(c) + } + pos++ + continue + } + + if c == '"' || c == '\'' || c == '`' { + quote = c + atLineStart = false + prevIsSpace = false + pos++ + continue + } + + // Comments (only outside quotes) at token boundary + if c == '#' && (atLineStart || prevIsSpace) { + inLine = true + pos++ + continue + } + if c == '/' && pos+1 < t.length { + n := rune(t.src[pos+1]) + if (atLineStart || prevIsSpace) && n == '/' { + inLine = true + pos += 2 + continue + } + if (atLineStart || prevIsSpace) && n == '*' { + inBlock = true + pos += 2 + continue + } + } + + if c == '$' && pos+1 < t.length && t.src[pos+1] == '{' { + // Skip env var ${...} + pos += 2 + envBraceDepth := 1 + envQuote := rune(0) + for pos < t.length { + ec := rune(t.src[pos]) + if envQuote != 0 { + if ec == envQuote { + envQuote = 0 + } + pos++ + continue + } + if ec == '"' || ec == '\'' || ec == '`' { + envQuote = ec + pos++ + continue + } + if ec == '{' { + envBraceDepth++ + } else if ec == '}' { + envBraceDepth-- + if envBraceDepth == 0 { + pos++ // Move past the closing '}' + break + } + } + pos++ + } + continue + } + + switch c { + case '{': + braceDepth++ + case '}': + braceDepth-- + if braceDepth == 0 { + return pos, nil + } + } + + if c == '\n' { + atLineStart = true + prevIsSpace = true + } else { + atLineStart = false + prevIsSpace = unicode.IsSpace(c) + } + pos++ + } + + return 0, ErrInvalidBlockSyntax.Withf("unmatched '{' at position %d", startPos) +} + +// parseHeaderToBrace parses an expression/header starting at start and returns: +// - header: trimmed src[start:bracePos] +// - bracePos: position of '{' (outside quotes/backticks) +func parseHeaderToBrace(src string, start int) (header string, bracePos int, err gperr.Error) { + t := newTokenizer(src) + bracePos, err = t.scanToBrace(start) + if err != nil { + return "", 0, err + } + return strings.TrimSpace(src[start:bracePos]), bracePos, nil +} + +// findMatchingBrace finds the matching '}' for a '{' at position startPos. +// It respects quotes/backticks and ${...} env vars in do_body. +func findMatchingBrace(src string, pos *int, startPos int) (int, gperr.Error) { + t := newTokenizer(src) + endPos, err := t.findMatchingBrace(startPos) + if err != nil { + return 0, err + } + *pos = endPos + 1 + return endPos, nil +} diff --git a/internal/route/rules/scanner_test.go b/internal/route/rules/scanner_test.go new file mode 100644 index 00000000..044e1159 --- /dev/null +++ b/internal/route/rules/scanner_test.go @@ -0,0 +1,39 @@ +package rules + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTokenizer_skipComments_UnterminatedBlockComment(t *testing.T) { + tok := newTokenizer("/* unterminated") + pos, err := tok.skipComments(0, true, true) + require.Error(t, err) + require.Equal(t, 0, pos) +} + +func TestTokenizer_skipComments_SkipsLineAndBlockComments(t *testing.T) { + src := " // line\n /* block */\n # hash\n default" + tok := newTokenizer(src) + pos, err := tok.skipComments(0, true, true) + require.NoError(t, err) + require.Equal(t, strings.Index(src, "default"), pos) +} + +func TestTokenizer_scanToBrace_IgnoresQuotedBraces(t *testing.T) { + src := "cond \"{\" {" // the brace inside quotes must be ignored + tok := newTokenizer(src) + bracePos, err := tok.scanToBrace(0) + require.NoError(t, err) + require.Equal(t, strings.LastIndex(src, "{"), bracePos) +} + +func TestTokenizer_findMatchingBrace_IgnoresQuotedClosingBrace(t *testing.T) { + src := `{ "}" }` + tok := newTokenizer(src) + endPos, err := tok.findMatchingBrace(1) // body starts after the first '{' + require.NoError(t, err) + require.Equal(t, strings.LastIndex(src, "}"), endPos) +} diff --git a/internal/route/rules/template.go b/internal/route/rules/template.go index f9ee576a..a934bca4 100644 --- a/internal/route/rules/template.go +++ b/internal/route/rules/template.go @@ -4,13 +4,13 @@ import ( "io" "net/http" "strings" - "unsafe" httputils "github.com/yusing/goutils/http" ) type templateString struct { string + isTemplate bool } @@ -23,32 +23,28 @@ func (tmpl *keyValueTemplate) Unpack() (string, templateString) { return tmpl.key, tmpl.tmpl } -func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error { +func (tmpl *templateString) ExpandVars(w *httputils.ResponseModifier, req *http.Request, dst io.Writer) (phase PhaseFlag, err error) { if !tmpl.isTemplate { - _, err := dstW.Write(strtobNoCopy(tmpl.string)) - return err + _, err := asBytesBufferLike(dst).WriteString(tmpl.string) + return PhaseNone, err } - return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW) + return ExpandVars(w, req, tmpl.string, dst) } -func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) { +func (tmpl *templateString) ExpandVarsToString(w *httputils.ResponseModifier, r *http.Request) (string, PhaseFlag, error) { if !tmpl.isTemplate { - return tmpl.string, nil + return tmpl.string, PhaseNone, nil } var buf strings.Builder - err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf) + phase, err := tmpl.ExpandVars(w, r, &buf) if err != nil { - return "", err + return "", PhaseNone, err } - return buf.String(), nil + return buf.String(), phase, nil } func (tmpl *templateString) Len() int { return len(tmpl.string) } - -func strtobNoCopy(s string) []byte { - return unsafe.Slice(unsafe.StringData(s), len(s)) -} diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index dde408b9..12b72d02 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog" nettypes "github.com/yusing/godoxy/internal/net/types" gperr "github.com/yusing/goutils/errs" @@ -16,7 +17,7 @@ import ( ) type ( - ValidateFunc func(args []string) (any, error) + ValidateFunc func(args []string) (phase PhaseFlag, parsedArgs any, err error) Tuple[T1, T2 any] struct { First T1 Second T2 @@ -37,6 +38,8 @@ type ( MapValueMatcher = Tuple[string, Matcher] ) +var cidrCache = xsync.NewMap[string, *net.IPNet]() + func (t *Tuple[T1, T2]) Unpack() (T1, T2) { return t.First, t.Second } @@ -62,7 +65,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string { } // validateSingleMatcher returns Matcher with the matcher validated. -func validateSingleMatcher(args []string) (any, error) { +func validateSingleMatcher(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } @@ -70,7 +73,7 @@ func validateSingleMatcher(args []string) (any, error) { } // toKVOptionalVMatcher returns *MapValueMatcher that value is optional. -func toKVOptionalVMatcher(args []string) (any, error) { +func toKVOptionalVMatcher(args []string) (any, gperr.Error) { switch len(args) { case 1: return &MapValueMatcher{args[0], nil}, nil @@ -85,20 +88,8 @@ func toKVOptionalVMatcher(args []string) (any, error) { } } -func toKeyValueTemplate(args []string) (any, error) { - if len(args) != 2 { - return nil, ErrExpectTwoArgs - } - - isTemplate, err := validateTemplate(args[1], false) - if err != nil { - return nil, err - } - return &keyValueTemplate{args[0], isTemplate}, nil -} - // validateURL returns types.URL with the URL validated. -func validateURL(args []string) (any, error) { +func validateURL(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } @@ -116,22 +107,27 @@ func validateURL(args []string) (any, error) { } // validateCIDR returns types.CIDR with the CIDR validated. -func validateCIDR(args []string) (any, error) { +func validateCIDR(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } - if !strings.Contains(args[0], "/") { - args[0] += "/32" + cidr := args[0] + if !strings.Contains(cidr, "/") { + cidr += "/32" } - _, ipnet, err := net.ParseCIDR(args[0]) + if cached, ok := cidrCache.Load(cidr); ok { + return cached, nil + } + _, ipnet, err := net.ParseCIDR(cidr) if err != nil { return nil, ErrInvalidArguments.With(err) } + cidrCache.Store(cidr, ipnet) return ipnet, nil } // validateURLPath returns string with the path validated. -func validateURLPath(args []string) (any, error) { +func validateURLPath(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } @@ -148,7 +144,7 @@ func validateURLPath(args []string) (any, error) { return p, nil } -func validateURLPathMatcher(args []string) (any, error) { +func validateURLPathMatcher(args []string) (any, gperr.Error) { path, err := validateURLPath(args) if err != nil { return nil, err @@ -157,7 +153,7 @@ func validateURLPathMatcher(args []string) (any, error) { } // validateFSPath returns string with the path validated. -func validateFSPath(args []string) (any, error) { +func validateFSPath(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } @@ -169,7 +165,7 @@ func validateFSPath(args []string) (any, error) { } // validateMethod returns string with the method validated. -func validateMethod(args []string) (any, error) { +func validateMethod(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } @@ -200,7 +196,7 @@ func validateStatusCode(status string) (int, error) { // - 3xx // - 4xx // - 5xx -func validateStatusRange(args []string) (any, error) { +func validateStatusRange(args []string) (any, gperr.Error) { if len(args) != 1 { return nil, ErrExpectOneArg } @@ -232,7 +228,7 @@ func validateStatusRange(args []string) (any, error) { } // validateUserBCryptPassword returns *HashedCrendential with the password validated. -func validateUserBCryptPassword(args []string) (any, error) { +func validateUserBCryptPassword(args []string) (any, gperr.Error) { if len(args) != 2 { return nil, ErrExpectTwoArgs } @@ -240,64 +236,93 @@ func validateUserBCryptPassword(args []string) (any, error) { } // validateModField returns CommandHandler with the field validated. -func validateModField(mod FieldModifier, args []string) (CommandHandler, error) { +func validateModField(mod FieldModifier, args []string) (phase PhaseFlag, handler HandlerFunc, err error) { if len(args) == 0 { - return nil, ErrExpectTwoOrThreeArgs + return phase, nil, ErrExpectTwoOrThreeArgs } setField, ok := modFields[args[0]] if !ok { - return nil, ErrUnknownModField.Subject(args[0]) + return phase, nil, ErrUnknownModField.Subject(args[0]) } if mod == ModFieldRemove { if len(args) != 2 { - return nil, ErrExpectTwoArgs + return phase, nil, ErrExpectTwoArgs } // setField expect validateStrTuple args = append(args, "") } - validArgs, err := setField.validate(args[1:]) + phase, validArgs, err := setField.validate(args[1:]) if err != nil { - return nil, gperr.Wrap(err).With(setField.help.Error()) + return phase, nil, gperr.Wrap(err).With(setField.help.Error()) } + modder := setField.builder(validArgs) switch mod { case ModFieldAdd: add := modder.add if add == nil { - return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod) + return phase, nil, ErrInvalidArguments.Withf("add is not supported for field %s", args[0]) } - return add, nil + return phase, add, nil case ModFieldRemove: remove := modder.remove if remove == nil { - return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod) + return phase, nil, ErrInvalidArguments.Withf("remove is not supported for field %s", args[0]) } - return remove, nil + return phase, remove, nil } set := modder.set if set == nil { - return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod) + return phase, nil, ErrInvalidArguments.Withf("set is not supported for field %s", args[0]) } - return set, nil + return phase, set, nil } -func validateTemplate(tmplStr string, newline bool) (templateString, error) { +func validateTemplate(tmplStr string, newline bool) (phase PhaseFlag, tmpl templateString, err error) { if newline && !strings.HasSuffix(tmplStr, "\n") { tmplStr += "\n" } if !NeedExpandVars(tmplStr) { - return templateString{tmplStr, false}, nil + return phase, templateString{tmplStr, false}, nil } - err := ValidateVars(tmplStr) + phase, err = ValidateVars(tmplStr) if err != nil { - return templateString{}, err + return phase, templateString{}, gperr.Wrap(err) } - return templateString{tmplStr, true}, nil + return phase, templateString{tmplStr, true}, nil } -func validateLevel(level string) (zerolog.Level, error) { +func validatePreRequestKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) { + if len(args) != 2 { + return phase, nil, ErrExpectTwoArgs + } + + phase = PhasePre + tmplReq, tmpl, err := validateTemplate(args[1], false) + if err != nil { + return phase, nil, err + } + phase |= tmplReq + return phase, &keyValueTemplate{args[0], tmpl}, nil +} + +func validatePostResponseKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) { + if len(args) != 2 { + return phase, nil, ErrExpectTwoArgs + } + + phase = PhasePost + tmplReq, tmpl, err := validateTemplate(args[1], false) + if err != nil { + return phase, nil, err + } + phase |= tmplReq + return phase, &keyValueTemplate{args[0], tmpl}, nil +} + +func validateLevel(level string) (zerolog.Level, gperr.Error) { l, err := zerolog.ParseLevel(level) if err != nil { return zerolog.NoLevel, ErrInvalidArguments.With(err) diff --git a/internal/route/rules/var_bench_test.go b/internal/route/rules/var_bench_test.go index 328ad338..2acf5f11 100644 --- a/internal/route/rules/var_bench_test.go +++ b/internal/route/rules/var_bench_test.go @@ -23,7 +23,7 @@ func BenchmarkExpandVars(b *testing.B) { testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}} for b.Loop() { - err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard) + _, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard) if err != nil { b.Fatal(err) } diff --git a/internal/route/rules/vars.go b/internal/route/rules/vars.go index 08de2d61..20023d19 100644 --- a/internal/route/rules/vars.go +++ b/internal/route/rules/vars.go @@ -1,15 +1,16 @@ package rules import ( + "bytes" "io" "net/http" "net/http/httptest" "net/url" "regexp" "strings" + "unsafe" httputils "github.com/yusing/goutils/http" - ioutils "github.com/yusing/goutils/io" ) // TODO: remove middleware/vars.go and use this instead @@ -45,41 +46,84 @@ var ( } ) +type bytesBufferLike interface { + io.Writer + WriteByte(c byte) error + WriteString(s string) (int, error) +} + +type bytesBufferAdapter struct { + io.Writer +} + +func (b bytesBufferAdapter) WriteByte(c byte) error { + buf := [1]byte{c} + _, err := b.Write(buf[:]) + return err +} + +func (b bytesBufferAdapter) WriteString(s string) (int, error) { + return b.Write(unsafe.Slice(unsafe.StringData(s), len(s))) // avoid copy +} + +func asBytesBufferLike(w io.Writer) bytesBufferLike { + switch w := w.(type) { + case *bytes.Buffer: + return w + case bytesBufferLike: + return w + default: + return bytesBufferAdapter{w} + } +} + // ValidateVars validates the variables in the given string. -// It returns ErrUnexpectedVar if any invalid variable is found. -func ValidateVars(s string) error { +// It returns the phase that the variables require and an error if any error occurs. +// +// Possible errors: +// - ErrUnexpectedVar: if any invalid variable is found +// - ErrUnterminatedEnvVar: missing closing } +// - ErrUnterminatedQuotes: missing closing " or ' or ` +// - ErrUnterminatedParenthesis: missing closing ) +func ValidateVars(s string) (phase PhaseFlag, err error) { return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard) } -func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error { - dst := ioutils.NewBufferedWriter(dstW, 1024) - defer dst.Close() - +// ExpandVars expands the variables in the given string and writes the result to the given writer. +// It returns the phase that the variables require and an error if any error occurs. +// +// Possible errors: +// - ErrUnexpectedVar: if any invalid variable is found +// - ErrUnterminatedEnvVar: missing closing } +// - ErrUnterminatedQuotes: missing closing " or ' or ` +// - ErrUnterminatedParenthesis: missing closing ) +func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) (phase PhaseFlag, err error) { + dst := asBytesBufferLike(dstW) for i := 0; i < len(src); i++ { ch := src[i] if ch != '$' { - if err := dst.WriteByte(ch); err != nil { - return err + if err = dst.WriteByte(ch); err != nil { + return phase, err } continue } // Look ahead if i+1 >= len(src) { - return ErrUnterminatedEnvVar + return phase, ErrUnterminatedEnvVar } j := i + 1 switch src[j] { case '$': // $$ -> literal '$' if err := dst.WriteByte('$'); err != nil { - return err + return phase, err } i = j continue case '{': // ${...} pass through as-is if _, err := dst.WriteString("${"); err != nil { - return err + return phase, err } i = j // we've consumed the '{' too continue @@ -102,24 +146,26 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds if getter, ok := dynamicVarSubsMap[name]; ok { // Function-like variables isStatic = false + phase |= getter.phase args, nextIdx, err := extractArgs(src, j, name) if err != nil { - return err + return phase, err } i = nextIdx - actual, err = getter(args, w, req) + actual, err = getter.get(args, w, req) if err != nil { - return err + return phase, err } - } else if getter, ok := staticReqVarSubsMap[name]; ok { + } else if getter, ok := staticReqVarSubsMap[name]; ok { // always available actual = getter(req) - } else if getter, ok := staticRespVarSubsMap[name]; ok { + } else if getter, ok := staticRespVarSubsMap[name]; ok { // post response actual = getter(w) + phase |= PhasePost } else { - return ErrUnexpectedVar.Subject(name) + return phase, ErrUnexpectedVar.Subject(name) } if _, err := dst.WriteString(actual); err != nil { - return err + return phase, err } if isStatic { i = k - 1 @@ -128,10 +174,10 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds } // No valid construct after '$' - return ErrUnterminatedEnvVar.Withf("around $ at position %d", j) + return phase, ErrUnterminatedEnvVar.Withf("around $ at position %d", j) } - return nil + return phase, nil } func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) { diff --git a/internal/route/rules/vars_dynamic.go b/internal/route/rules/vars_dynamic.go index 717064bf..075d030a 100644 --- a/internal/route/rules/vars_dynamic.go +++ b/internal/route/rules/vars_dynamic.go @@ -11,58 +11,88 @@ import ( var ( VarHeader = "header" VarResponseHeader = "resp_header" + VarCookie = "cookie" VarQuery = "arg" VarForm = "form" VarPostForm = "postform" ) -type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) +type dynamicVarGetter struct { + phase PhaseFlag + get func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) +} var dynamicVarSubsMap = map[string]dynamicVarGetter{ - VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { - key, index, err := getKeyAndIndex(args) - if err != nil { - return "", err - } - return getValueByKeyAtIndex(req.Header, key, index) - }, - VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { - key, index, err := getKeyAndIndex(args) - if err != nil { - return "", err - } - return getValueByKeyAtIndex(w.Header(), key, index) - }, - VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { - key, index, err := getKeyAndIndex(args) - if err != nil { - return "", err - } - return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index) - }, - VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { - key, index, err := getKeyAndIndex(args) - if err != nil { - return "", err - } - if req.Form == nil { - if err := req.ParseForm(); err != nil { + VarHeader: { + phase: PhaseNone, + get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { return "", err } - } - return getValueByKeyAtIndex(req.Form, key, index) + return getValueByKeyAtIndex(req.Header, key, index) + }, }, - VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { - key, index, err := getKeyAndIndex(args) - if err != nil { - return "", err - } - if req.Form == nil { - if err := req.ParseForm(); err != nil { + VarResponseHeader: { + phase: PhasePost, + get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { return "", err } - } - return getValueByKeyAtIndex(req.PostForm, key, index) + return getValueByKeyAtIndex(w.Header(), key, index) + }, + }, + VarCookie: { + phase: PhaseNone, + get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + sharedData := httputils.GetSharedData(w) + return getValueByKeyAtIndex(sharedData.GetCookiesMap(req), key, index) + }, + }, + VarQuery: { + phase: PhaseNone, + get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index) + }, + }, + VarForm: { + phase: PhaseNone, + get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + if req.Form == nil { + if err := req.ParseForm(); err != nil { + return "", err + } + } + return getValueByKeyAtIndex(req.Form, key, index) + }, + }, + VarPostForm: { + phase: PhaseNone, + get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + if req.Form == nil { + if err := req.ParseForm(); err != nil { + return "", err + } + } + return getValueByKeyAtIndex(req.PostForm, key, index) + }, }, } diff --git a/internal/route/rules/vars_test.go b/internal/route/rules/vars_test.go index ca487bad..279a51ad 100644 --- a/internal/route/rules/vars_test.go +++ b/internal/route/rules/vars_test.go @@ -232,7 +232,7 @@ func TestExpandVars(t *testing.T) { { name: "req_method", input: "$req_method", - want: "POST", + want: http.MethodPost, }, { name: "req_path", @@ -484,7 +484,7 @@ func TestExpandVars(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, tt.input, &out) + _, err := ExpandVars(testResponseModifier, testRequest, tt.input, &out) if tt.wantErr { require.Error(t, err) @@ -506,7 +506,7 @@ func TestExpandVars_Integration(t *testing.T) { testResponseModifier.WriteHeader(http.StatusOK) var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, + _, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_url $status_code User-Agent=$header(User-Agent)", &out) @@ -520,7 +520,7 @@ func TestExpandVars_Integration(t *testing.T) { testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, + _, err := ExpandVars(testResponseModifier, testRequest, "Query: $arg(q), Page: $arg(page)", &out) @@ -537,7 +537,7 @@ func TestExpandVars_Integration(t *testing.T) { testResponseModifier.WriteHeader(http.StatusOK) var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, + _, err := ExpandVars(testResponseModifier, testRequest, "Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)", &out) @@ -560,7 +560,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) { { name: "https scheme", request: &http.Request{ - Method: "GET", + Method: http.MethodGet, URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"}, TLS: &tls.ConnectionState{}, // Simulate TLS connection }, @@ -572,7 +572,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder - err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out) + _, err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out) require.NoError(t, err) require.Equal(t, tt.expected, out.String()) }) @@ -598,7 +598,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) { for _, varExpr := range upstreamVars { t.Run(varExpr, func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, varExpr, &out) + _, err := ExpandVars(testResponseModifier, testRequest, varExpr, &out) // Should not error, may return empty string require.NoError(t, err) }) @@ -614,16 +614,16 @@ func TestExpandVars_NoHostPort(t *testing.T) { t.Run("req_host without port", func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out) + _, err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out) require.NoError(t, err) require.Equal(t, "example.com", out.String()) }) t.Run("req_port without port", func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out) + _, err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out) require.NoError(t, err) - require.Empty(t, out.String()) + require.Equal(t, "", out.String()) }) } @@ -636,16 +636,16 @@ func TestExpandVars_NoRemotePort(t *testing.T) { t.Run("remote_host without port", func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out) + _, err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out) require.NoError(t, err) - require.Empty(t, out.String()) + require.Equal(t, "", out.String()) }) t.Run("remote_port without port", func(t *testing.T) { var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out) + _, err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out) require.NoError(t, err) - require.Empty(t, out.String()) + require.Equal(t, "", out.String()) }) } @@ -654,7 +654,7 @@ func TestExpandVars_WhitespaceHandling(t *testing.T) { testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder()) var out strings.Builder - err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out) + _, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out) require.NoError(t, err) require.Equal(t, "GET /test", out.String()) } @@ -699,7 +699,7 @@ func TestValidateVars(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateVars(tt.input) + _, err := ValidateVars(tt.input) if tt.wantErr { require.Error(t, err) } else {