From 1ec2872f3d4c2a89a210727c19abbd8a6cfdaf7b Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 25 Oct 2025 22:43:47 +0800 Subject: [PATCH] feat(rules): replace go templates with custom variable expansion - Replace template syntax ({{ .Request.Method }}) with $-prefixed variables ($req_method) - Implement custom variable parser with static ($req_method, $status_code) and dynamic ($header(), $arg(), $form()) variables - Replace templateOrStr interface with templateString struct and ExpandVars methods - Add parser improvements for reliable quote handling - Add new error types: ErrUnterminatedParenthesis, ErrUnexpectedVar, ErrExpectOneOrTwoArgs - Update all tests and help text to use new variable syntax - Add comprehensive unit and benchmark tests for variable expansion --- internal/route/rules/do.go | 22 +- internal/route/rules/do_log_test.go | 41 +- internal/route/rules/do_set.go | 28 +- internal/route/rules/do_set_test.go | 10 +- internal/route/rules/errors.go | 25 +- internal/route/rules/help.go | 71 ++- internal/route/rules/http_flow_test.go | 14 +- internal/route/rules/on.go | 2 +- internal/route/rules/parser.go | 30 +- internal/route/rules/template.go | 54 +- internal/route/rules/validate.go | 30 +- internal/route/rules/var_bench_test.go | 28 ++ internal/route/rules/vars.go | 214 ++++++++ internal/route/rules/vars_dynamic.go | 81 +++ internal/route/rules/vars_static.go | 92 ++++ internal/route/rules/vars_test.go | 672 +++++++++++++++++++++++++ 16 files changed, 1253 insertions(+), 161 deletions(-) create mode 100644 internal/route/rules/var_bench_test.go create mode 100644 internal/route/rules/vars.go create mode 100644 internal/route/rules/vars_dynamic.go create mode 100644 internal/route/rules/vars_static.go create mode 100644 internal/route/rules/vars_test.go diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 4ed7f19c..e29968c2 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -334,7 +334,7 @@ var commands = map[string]struct { helpListItem("Response", "the response object"), "", "Example:", - helpExample(CommandLog, "info", "/dev/stdout", "{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"), + helpExample(CommandLog, "info", "/dev/stdout", "$req_method $req_url $status_code"), ), args: map[string]string{ "level": "the log level", @@ -372,7 +372,7 @@ var commands = map[string]struct { logger = f } return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { - err := executeReqRespTemplateTo(tmpl, logger, w, r) + err := tmpl.ExpandVars(w, r, logger) if err != nil { return err } @@ -390,7 +390,7 @@ var commands = map[string]struct { helpListItem("Response", "the response object"), "", "Example:", - helpExample(CommandNotify, "info", "ntfy", "Received request to {{ .Request.URL }}", "{{ .Request.Method }} {{ .Response.StatusCode }}"), + helpExample(CommandNotify, "info", "ntfy", "Received request to $req_url", "$req_method $status_code"), ), args: map[string]string{ "level": "the log level", @@ -432,12 +432,12 @@ var commands = map[string]struct { return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len())) - err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r) + err := titleTmpl.ExpandVars(w, r, respBuf) if err != nil { return err } titleLen := respBuf.Len() - err = executeReqRespTemplateTo(bodyTmpl, respBuf, w, r) + err = bodyTmpl.ExpandVars(w, r, respBuf) if err != nil { return err } @@ -455,16 +455,8 @@ var commands = map[string]struct { }, } -type reqResponseTemplateData struct { - Request *http.Request - Response struct { - StatusCode int - Header http.Header - } -} - -type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr] -type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr] +type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString] +type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString] // Parse implements strutils.Parser. func (cmd *Command) Parse(v string) error { diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go index 272f6a31..94839441 100644 --- a/internal/route/rules/do_log_test.go +++ b/internal/route/rules/do_log_test.go @@ -54,7 +54,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) { err = parseRules(fmt.Sprintf(` - name: log-request-response do: | - log info %q '{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }} {{ index (index .Response.Header "Content-Type") 0 }}' + log info %q '$req_method $req_url $status_code $resp_header(Content-Type)' `, tempFile.Name()), &rules) require.NoError(t, err) @@ -84,10 +84,10 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) { err := parseRules(` - name: log-stdout do: | - log info /dev/stdout "stdout: {{ .Request.Method }} {{ .Response.StatusCode }}" + log info /dev/stdout "stdout: $req_method $status_code" - name: log-stderr do: | - log error /dev/stderr "stderr: {{ .Request.URL.Path }} {{ .Response.StatusCode }}" + log error /dev/stderr "stderr: $req_path $status_code" `, &rules) require.NoError(t, err) @@ -126,13 +126,13 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) { err = parseRules(fmt.Sprintf(` - name: log-info do: | - log info %s "INFO: {{ .Request.Method }} {{ .Response.StatusCode }}" + log info %s "INFO: $req_method $status_code" - name: log-warn do: | - log warn %s "WARN: {{ .Request.URL.Path }} {{ .Response.StatusCode }}" + log warn %s "WARN: $req_path $status_code" - name: log-error do: | - log error %s "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}" + log error %s "ERROR: $req_method $req_path $status_code" `, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules) require.NoError(t, err) @@ -177,7 +177,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) { err = parseRules(fmt.Sprintf(` - name: log-with-templates do: | - log info %s 'Request: {{ .Request.Method }} {{ .Request.URL }} Host: {{ .Request.Host }} User-Agent: {{ index .Request.Header "User-Agent" 0 }} Response: {{ .Response.StatusCode }} Custom-Header: {{ index .Response.Header "X-Custom-Header" 0 }} Content-Length: {{ index .Response.Header "Content-Length" 0 }}' + log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)' `, tempFile.Name()), &rules) require.NoError(t, err) @@ -231,11 +231,11 @@ func TestLogCommand_ConditionalLogging(t *testing.T) { - name: log-success on: status 2xx do: | - log info %q "SUCCESS: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}" + log info %q "SUCCESS: $req_method $req_path $status_code" - name: log-error on: status 4xx | status 5xx do: | - log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}" + log error %q "ERROR: $req_method $req_path $status_code" `, successFile.Name(), errorFile.Name()), &rules) require.NoError(t, err) @@ -288,7 +288,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) { err = parseRules(fmt.Sprintf(` - name: log-multiple do: | - log info %q "{{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"`, tempFile.Name()), &rules) + log info %q "$req_method $req_path $status_code"`, tempFile.Name()), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -339,7 +339,7 @@ func TestLogCommand_FilePermissions(t *testing.T) { var rules Rules err = parseRules(fmt.Sprintf(` - on: status 2xx - do: log info %q "{{ .Request.Method }} {{ .Response.StatusCode }}"`, logFilePath), &rules) + do: log info %q "$req_method $status_code"`, logFilePath), &rules) require.NoError(t, err) handler := rules.BuildHandler(upstream) @@ -374,27 +374,12 @@ func TestLogCommand_FilePermissions(t *testing.T) { } func TestLogCommand_InvalidTemplate(t *testing.T) { - upstream := mockUpstream(200, "success") - var rules Rules // Test with invalid template syntax err := parseRules(` - name: log-invalid do: | - log info /dev/stdout "{{ .Invalid.Field }}"`, &rules) - // Should not error during parsing, but template execution will fail gracefully - assert.NoError(t, err) - - handler := rules.BuildHandler(upstream) - - req := httptest.NewRequest("GET", "/test", nil) - w := httptest.NewRecorder() - - // Should not panic - assert.NotPanics(t, func() { - handler.ServeHTTP(w, req) - }) - - assert.Equal(t, 200, w.Code) + log info /dev/stdout "$invalid_var"`, &rules) + assert.ErrorIs(t, err, ErrUnexpectedVar) } diff --git a/internal/route/rules/do_set.go b/internal/route/rules/do_set.go index edab827f..6a18f60f 100644 --- a/internal/route/rules/do_set.go +++ b/internal/route/rules/do_set.go @@ -54,7 +54,7 @@ var modFields = map[string]struct { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -62,7 +62,7 @@ var modFields = map[string]struct { return nil }), add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -89,7 +89,7 @@ var modFields = map[string]struct { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -97,7 +97,7 @@ var modFields = map[string]struct { return nil }), add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -124,7 +124,7 @@ var modFields = map[string]struct { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -134,7 +134,7 @@ var modFields = map[string]struct { return nil }), add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -165,7 +165,7 @@ var modFields = map[string]struct { k, tmpl := args.(*keyValueTemplate).Unpack() return &FieldHandler{ set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -181,7 +181,7 @@ var modFields = map[string]struct { return nil }), add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { - v, err := executeRequestTemplateString(tmpl, r) + v, err := tmpl.ExpandVarsToString(w, r) if err != nil { return err } @@ -221,7 +221,7 @@ var modFields = map[string]struct { helpListItem("Request", "the request object"), "", "Example:", - helpExample(FieldBody, "HTTP STATUS: {{ .Request.Method }} {{ .Request.URL.Path }}"), + helpExample(FieldBody, "HTTP STATUS: $req_method $req_path"), ), args: map[string]string{ "template": "the body template", @@ -234,7 +234,7 @@ var modFields = map[string]struct { return validateTemplate(args[0], true) }, builder: func(args any) *FieldHandler { - tmpl := args.(templateOrStr) + tmpl := args.(templateString) return &FieldHandler{ set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { if r.Body != nil { @@ -244,7 +244,7 @@ var modFields = map[string]struct { bufPool := GetInitResponseModifier(w).BufPool() b := bufPool.GetBuffer() - err := executeRequestTemplateTo(tmpl, b, r) + err := tmpl.ExpandVars(w, r, b) if err != nil { return err } @@ -266,7 +266,7 @@ var modFields = map[string]struct { helpListItem("Response", "the response object"), "", "Example:", - helpExample(FieldResponseBody, "HTTP STATUS: {{ .Request.Method }} {{ .Response.StatusCode }}"), + helpExample(FieldResponseBody, "HTTP STATUS: $req_method $status_code"), ), args: map[string]string{ "template": "the response body template", @@ -279,12 +279,12 @@ var modFields = map[string]struct { return validateTemplate(args[0], true) }, builder: func(args any) *FieldHandler { - tmpl := args.(templateOrStr) + tmpl := args.(templateString) return &FieldHandler{ set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { rm := GetInitResponseModifier(w) rm.ResetBody() - return executeReqRespTemplateTo(tmpl, rm, rm, r) + return tmpl.ExpandVars(w, r, rm) }), } }, diff --git a/internal/route/rules/do_set_test.go b/internal/route/rules/do_set_test.go index 3f007d6e..aae2c0a6 100644 --- a/internal/route/rules/do_set_test.go +++ b/internal/route/rules/do_set_test.go @@ -367,7 +367,7 @@ func TestFieldHandler_Body(t *testing.T) { }{ { name: "set body with template", - template: "Hello {{ .Request.Method }} {{ .Request.URL.Path }}", + template: "Hello $req_method $req_path", setup: func(r *http.Request) { r.Method = "POST" r.URL.Path = "/test" @@ -424,7 +424,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) { }{ { name: "set response body with template", - template: "Response: {{ .Request.Method }} {{ .Request.URL.Path }}", + template: "Response: $req_method $req_path", setup: func(r *http.Request) { r.Method = "GET" r.URL.Path = "/api/test" @@ -552,19 +552,19 @@ func TestFieldValidation(t *testing.T) { { name: "body valid template", field: FieldBody, - args: []string{"Hello {{ .Request.Method }}"}, + args: []string{"Hello $req_method"}, wantError: false, }, { name: "body invalid template syntax", field: FieldBody, - args: []string{"Hello {{ .InvalidField "}, + args: []string{"Hello $invalid_field"}, wantError: true, }, { name: "response body valid template", field: FieldResponseBody, - args: []string{"Response: {{ .Request.Method }}"}, + args: []string{"Response: $req_method"}, wantError: false, }, { diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index b15efa7f..f99a7091 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -5,18 +5,25 @@ import ( ) var ( - ErrUnterminatedQuotes = gperr.New("unterminated quotes") - ErrUnterminatedBrackets = gperr.New("unterminated brackets") - ErrUnterminatedEnvVar = gperr.New("unterminated env var") - ErrUnknownDirective = gperr.New("unknown directive") - ErrUnknownModField = gperr.New("unknown field") - ErrEnvVarNotFound = gperr.New("env variable not found") - ErrInvalidArguments = gperr.New("invalid arguments") - ErrInvalidOnTarget = gperr.New("invalid `rule.on` target") - ErrInvalidCommandSequence = gperr.New("invalid command sequence") + ErrUnterminatedQuotes = gperr.New("unterminated quotes") + ErrUnterminatedBrackets = gperr.New("unterminated brackets") + ErrUnterminatedParenthesis = gperr.New("unterminated parenthesis") + ErrUnterminatedEnvVar = gperr.New("unterminated env var") + ErrUnknownDirective = gperr.New("unknown directive") + ErrUnknownModField = gperr.New("unknown field") + ErrEnvVarNotFound = gperr.New("env variable not found") + ErrInvalidArguments = gperr.New("invalid arguments") + ErrInvalidOnTarget = gperr.New("invalid `rule.on` target") + ErrInvalidCommandSequence = gperr.New("invalid command sequence") + + // vars errors + ErrNoArgProvided = gperr.New("no argument provided") + ErrUnexpectedVar = gperr.New("unexpected variable") + ErrUnexpectedQuote = gperr.New("unexpected quote") ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg") ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg") + ErrExpectOneOrTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 1 or 2 args") ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args") ErrExpectTwoOrThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 or 3 args") ErrExpectThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 3 args") diff --git a/internal/route/rules/help.go b/internal/route/rules/help.go index 1f92f4e4..de1bf00c 100644 --- a/internal/route/rules/help.go +++ b/internal/route/rules/help.go @@ -28,11 +28,11 @@ func helpExample(cmd string, args ...string) string { var out strings.Builder pos := 0 for { - start := strings.Index(arg[pos:], "{{") + start := strings.IndexByte(arg[pos:], '$') if start == -1 { if pos < len(arg) { - // If no template at all (pos == 0), cyan highlight for whole-arg - // Otherwise, for mixed strings containing templates, leave non-template text unhighlighted + // If no variable at all (pos == 0), cyan highlight for whole-arg + // Otherwise, for mixed strings containing variables, leave non-variable text unhighlighted if pos == 0 { out.WriteString(ansi.WithANSI(arg[pos:], ansi.HighlightCyan)) } else { @@ -43,20 +43,31 @@ func helpExample(cmd string, args ...string) string { } start += pos if start > pos { - // Non-template text should not be highlighted + // Non-variable text should not be highlighted out.WriteString(arg[pos:start]) } - end := strings.Index(arg[start+2:], "}}") - if end == -1 { - // Unmatched template start; write remainder without highlighting - out.WriteString(arg[start:]) - break + // Parse variable name and optional function call + end := start + 1 + for end < len(arg) && (arg[end] == '_' || (arg[end] >= 'a' && arg[end] <= 'z') || (arg[end] >= 'A' && arg[end] <= 'Z') || (arg[end] >= '0' && arg[end] <= '9')) { + end++ } - end += start + 2 - inner := strings.TrimSpace(arg[start+2 : end]) - parts := strings.Split(inner, ".") - out.WriteString(helpTemplateVar(parts...)) - pos = end + 2 + // Check for function call + if end < len(arg) && arg[end] == '(' { + parenCount := 1 + end++ + for end < len(arg) && parenCount > 0 { + switch arg[end] { + case '(': + parenCount++ + case ')': + parenCount-- + } + end++ + } + } + varExpr := arg[start:end] + out.WriteString(helpVar(varExpr)) + pos = end } fmt.Fprintf(&sb, ` "%s"`, out.String()) } @@ -87,17 +98,29 @@ func helpFuncCall(fn string, args ...string) string { return sb.String() } -// helpTemplateVar generates a string like "{{ .Request.Method }} {{ .Request.URL.Path }}" -func helpTemplateVar(parts ...string) string { - var sb strings.Builder - sb.WriteString(ansi.WithANSI("{{ ", ansi.HighlightWhite)) - for i, part := range parts { - sb.WriteString(ansi.WithANSI(part, ansi.HighlightCyan)) - if i < len(parts)-1 { - sb.WriteString(".") - } +// helpVar generates a highlighted string for a variable like "$req_method" or "$header(X-Test)" +func helpVar(varExpr string) string { + if !strings.HasPrefix(varExpr, "$") { + return varExpr } - sb.WriteString(ansi.WithANSI(" }}", ansi.HighlightWhite)) + + // Check if it's a function call + parenIdx := strings.IndexByte(varExpr, '(') + if parenIdx == -1 { + // Simple variable like "$req_method" + return ansi.WithANSI(varExpr, ansi.HighlightCyan) + } + + // Function call like "$header(X-Test)" + var sb strings.Builder + sb.WriteString(ansi.WithANSI(varExpr[:parenIdx], ansi.HighlightCyan)) + sb.WriteString(ansi.WithANSI("(", ansi.HighlightWhite)) + + // Extract and highlight the arguments + argsStr := varExpr[parenIdx+1 : len(varExpr)-1] + sb.WriteString(ansi.WithANSI(argsStr, ansi.HighlightYellow)) + + sb.WriteString(ansi.WithANSI(")", ansi.HighlightWhite)) return sb.String() } diff --git a/internal/route/rules/http_flow_test.go b/internal/route/rules/http_flow_test.go index 375a8992..f0bf741a 100644 --- a/internal/route/rules/http_flow_test.go +++ b/internal/route/rules/http_flow_test.go @@ -218,7 +218,7 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) { err = parseRules(fmt.Sprintf(` - name: log-response on: path /test - do: log info %s "{{ .Request.Method }} {{ .Response.StatusCode }}" + do: log info %s "$req_method $status_code" `, tempFile.Name()), &rules) require.NoError(t, err) @@ -261,7 +261,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { err = parseRules(fmt.Sprintf(` - name: log-errors on: status 4xx - do: log error %s "{{ .Request.URL }} returned {{ .Response.StatusCode }}" + do: log error %s "$req_url returned $status_code" `, tempFile.Name()), &rules) require.NoError(t, err) @@ -364,11 +364,11 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { do: require_basic_auth "Protected Area" - name: log-all-requests do: | - log info %q "{{ .Request.Method }} {{ .Request.URL }} -> {{ .Response.StatusCode }}" + log info %q "$req_method $req_url -> $status_code" - name: log-errors on: status 4xx do: | - log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}" + log error %q "ERROR: $req_method $req_url $status_code" `, logFile.Name(), errorLogFile.Name()), &rules) require.NoError(t, err) @@ -610,10 +610,10 @@ func TestHTTPFlow_FormConditions(t *testing.T) { err := parseRules(` - name: process-form on: form username - do: set resp_header X-Username "{{ index .Request.Form.username 0 }}" + do: set resp_header X-Username "$form(username)" - name: process-postform on: postform email - do: set resp_header X-Email "{{ index .Request.PostForm.email 0 }}" + do: set resp_header X-Email "$postform(email)" `, &rules) require.NoError(t, err) @@ -923,7 +923,7 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) { - name: modify-response do: | set resp_header X-Modified "true" - set resp_body "Modified: {{ .Request.Method }} {{ .Request.URL.Path }}" + set resp_body "Modified: $req_method $req_path" `, &rules) require.NoError(t, err) diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 5bee9436..1d0ef54f 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -591,7 +591,7 @@ func parseOn(line string) (Checker, bool, gperr.Error) { validArgs, err := checker.validate(args) if err != nil { - return nil, false, err.Subject(subject).With(checker.help.Error()) + return nil, false, err.With(checker.help.Error()) } checkFunc := checker.builder(validArgs) diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go index 87375ffa..149da13b 100644 --- a/internal/route/rules/parser.go +++ b/internal/route/rules/parser.go @@ -19,6 +19,12 @@ var escapedChars = map[rune]rune{ ' ': ' ', } +var quoteChars = [256]bool{ + '"': true, + '\'': true, + '`': true, +} + // parse expression to subject and args // with support for quotes, escaped chars, and env substitution, e.g. // @@ -74,6 +80,19 @@ func parse(v string) (subject string, args []string, err gperr.Error) { buf.WriteRune('$') expectingBrace = false } + if quoteChars[r] { + switch { + case quote == 0 && brackets == 0: + quote = r + flush(false) + case r == quote: + quote = 0 + flush(true) + default: + buf.WriteRune(r) + } + continue + } switch r { case '\\': escaped = true @@ -106,17 +125,6 @@ func parse(v string) (subject string, args []string, err gperr.Error) { } else { buf.WriteRune(r) } - case '"', '\'', '`': - switch { - case quote == 0 && brackets == 0: - quote = r - flush(false) - case r == quote: - quote = 0 - flush(true) - default: - buf.WriteRune(r) - } case '(': brackets++ buf.WriteRune(r) diff --git a/internal/route/rules/template.go b/internal/route/rules/template.go index 7415c863..42c6f2dd 100644 --- a/internal/route/rules/template.go +++ b/internal/route/rules/template.go @@ -1,48 +1,52 @@ package rules import ( - "bytes" "io" "net/http" + "strings" + "unsafe" ) -type templateOrStr interface { - Execute(w io.Writer, data any) error - Len() int +type templateString struct { + string + isTemplate bool } -type strTemplate string +type keyValueTemplate struct { + key string + tmpl templateString +} -func (t strTemplate) Execute(w io.Writer, _ any) error { - n, err := w.Write([]byte(t)) - if err != nil { +func (tmpl *keyValueTemplate) Unpack() (string, templateString) { + return tmpl.key, tmpl.tmpl +} + +func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error { + if !tmpl.isTemplate { + _, err := dstW.Write(strtobNoCopy(tmpl.string)) return err } - if n != len(t) { - return io.ErrShortWrite + + return ExpandVars(GetInitResponseModifier(w), req, tmpl.string, dstW) +} + +func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) { + if !tmpl.isTemplate { + return tmpl.string, nil } - return nil -} -func (t strTemplate) Len() int { - return len(t) -} - -type keyValueTemplate = Tuple[string, templateOrStr] - -func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) { - var buf bytes.Buffer - err := tmpl.Execute(&buf, reqResponseTemplateData{Request: r}) + var buf strings.Builder + err := ExpandVars(GetInitResponseModifier(w), req, tmpl.string, &buf) if err != nil { return "", err } return buf.String(), nil } -func executeRequestTemplateTo(tmpl templateOrStr, o io.Writer, r *http.Request) error { - return tmpl.Execute(o, reqResponseTemplateData{Request: r}) +func (tmpl *templateString) Len() int { + return len(tmpl.string) } -func executeReqRespTemplateTo(tmpl templateOrStr, o io.Writer, w http.ResponseWriter, r *http.Request) error { - return tmpl.Execute(o, reqResponseTemplateData{Request: r, Response: GetInitResponseModifier(w).Response()}) +func strtobNoCopy(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) } diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index bed8f935..6b2e2062 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -8,7 +8,6 @@ import ( "path/filepath" "strconv" "strings" - "text/template" "github.com/rs/zerolog" nettypes "github.com/yusing/godoxy/internal/net/types" @@ -91,11 +90,11 @@ func toKeyValueTemplate(args []string) (any, gperr.Error) { return nil, ErrExpectTwoArgs } - tmpl, err := validateTemplate(args[1], false) + isTemplate, err := validateTemplate(args[1], false) if err != nil { return nil, err } - return &keyValueTemplate{args[0], tmpl}, nil + return &keyValueTemplate{args[0], isTemplate}, nil } // validateURL returns types.URL with the URL validated. @@ -300,33 +299,20 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E return set, nil } -func isTemplate(tmplStr string) bool { - return strings.Contains(tmplStr, "{{") -} - -type templateWithLen struct { - *template.Template - len int -} - -func (t *templateWithLen) Len() int { - return t.len -} - -func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) { +func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error) { if newline && !strings.HasSuffix(tmplStr, "\n") { tmplStr += "\n" } - if !isTemplate(tmplStr) { - return strTemplate(tmplStr), nil + if !NeedExpandVars(tmplStr) { + return templateString{tmplStr, false}, nil } - tmpl, err := template.New("template").Parse(tmplStr) + err := ValidateVars(tmplStr) if err != nil { - return nil, ErrInvalidArguments.With(err) + return templateString{}, gperr.Wrap(err) } - return &templateWithLen{tmpl, len(tmplStr)}, nil + return templateString{tmplStr, true}, nil } func validateLevel(level string) (zerolog.Level, gperr.Error) { diff --git a/internal/route/rules/var_bench_test.go b/internal/route/rules/var_bench_test.go new file mode 100644 index 00000000..bde8cf52 --- /dev/null +++ b/internal/route/rules/var_bench_test.go @@ -0,0 +1,28 @@ +package rules + +import ( + "io" + "net/http/httptest" + "net/url" + "testing" +) + +func BenchmarkExpandVars(b *testing.B) { + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier.WriteHeader(200) + testResponseModifier.Write([]byte("Hello, world!")) + testRequest := httptest.NewRequest("GET", "/", nil) + testRequest.Header.Set("User-Agent", "test-agent/1.0") + testRequest.Header.Set("X-Custom", "value1,value2") + testRequest.ContentLength = 12345 + testRequest.RemoteAddr = "192.168.1.100:54321" + testRequest.Form = url.Values{"param1": {"value1"}, "param2": {"value2"}} + testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}} + + for b.Loop() { + err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/internal/route/rules/vars.go b/internal/route/rules/vars.go new file mode 100644 index 00000000..9443f4f0 --- /dev/null +++ b/internal/route/rules/vars.go @@ -0,0 +1,214 @@ +package rules + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + + ioutils "github.com/yusing/goutils/io" +) + +// TODO: remove middleware/vars.go and use this instead + +type ( + reqVarGetter func(*http.Request) string + respVarGetter func(*ResponseModifier) string +) + +var reVar = regexp.MustCompile(`\$[\w_]+`) + +var validVarNameCharset = func() (ret [256]bool) { + for c := byte('a'); c <= 'z'; c++ { + ret[c] = true + } + for c := byte('A'); c <= 'Z'; c++ { + ret[c] = true + } + ret['_'] = true + return +}() + +func NeedExpandVars(s string) bool { + return reVar.MatchString(s) +} + +var ( + voidResponseModifier = NewResponseModifier(httptest.NewRecorder()) + dummyRequest = http.Request{ + Method: "GET", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + } +) + +// ValidateVars validates the variables in the given string. +// It returns ErrUnexpectedVar if any invalid variable is found. +func ValidateVars(s string) error { + return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard) +} + +func ExpandVars(w *ResponseModifier, req *http.Request, src string, dstW io.Writer) error { + dst := ioutils.NewBufferedWriter(dstW, 1024) + defer func() { + dst.Flush() + dst.Release() + }() + + for i := 0; i < len(src); i++ { + ch := src[i] + if ch != '$' { + if err := dst.WriteByte(ch); err != nil { + return err + } + continue + } + + // Look ahead + if i+1 >= len(src) { + return ErrUnterminatedEnvVar + } + j := i + 1 + + switch src[j] { + case '$': // $$ -> literal '$' + if err := dst.WriteByte('$'); err != nil { + return err + } + i = j + continue + case '{': // ${...} pass through as-is + if _, err := dst.WriteString("${"); err != nil { + return err + } + i = j // we've consumed the '{' too + continue + } + + if validVarNameCharset[src[j]] { + k := j + for k < len(src) { + c := src[k] + if validVarNameCharset[c] { + k++ + continue + } + break + } + name := src[j:k] + isStatic := true + + var actual string + if getter, ok := dynamicVarSubsMap[name]; ok { + // Function-like variables + isStatic = false + args, nextIdx, err := extractArgs(src, j, name) + if err != nil { + return err + } + i = nextIdx + actual, err = getter(args, w, req) + if err != nil { + return err + } + } else if getter, ok := staticReqVarSubsMap[name]; ok { + actual = getter(req) + } else if getter, ok := staticRespVarSubsMap[name]; ok { + actual = getter(w) + } else { + return ErrUnexpectedVar.Subject(name) + } + if _, err := dst.WriteString(actual); err != nil { + return err + } + if isStatic { + i = k - 1 + } + continue + } + + // No valid construct after '$' + return ErrUnterminatedEnvVar.Withf("around $ at position %d", j) + } + + return nil +} + +func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) { + // Find opening parenthesis + parenIdx := strings.IndexByte(src[i:], '(') + if parenIdx == -1 { + return nil, 0, ErrUnterminatedParenthesis.Withf("func %q at position %d", funcName, i) + } + parenIdx += i + + var ( + quote byte // current quote character (0 if not in quotes) + arg strings.Builder + ) + + nextIdx = parenIdx + 1 + for nextIdx < len(src) { + ch := src[nextIdx] + + if quote != 0 { + // We're inside a quoted string + if ch == quote { + // Closing quote - the content between quotes is now complete, add it + args = append(args, arg.String()) + arg.Reset() + quote = 0 + nextIdx++ + continue + } + // Inside quotes - add everything as-is + arg.WriteByte(ch) + nextIdx++ + continue + } + + // Not inside quotes + if quoteChars[ch] { + // Opening quote + quote = ch + nextIdx++ + continue + } + + if ch == ')' { + // End of arguments + if arg.Len() > 0 { + args = append(args, arg.String()) + } + return args, nextIdx, nil + } + + if ch == ',' { + // Argument separator + if arg.Len() > 0 { + args = append(args, arg.String()) + arg.Reset() + } + nextIdx++ + continue + } + + if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' { + // Whitespace outside quotes - skip + nextIdx++ + continue + } + + // Regular character - accumulate until we hit a delimiter + arg.WriteByte(ch) + nextIdx++ + } + + // Reached end of string without closing parenthesis + if quote != 0 { + return nil, 0, ErrUnterminatedQuotes.Withf("func %q", funcName) + } + return nil, 0, ErrUnterminatedParenthesis.Withf("func %q", funcName) +} diff --git a/internal/route/rules/vars_dynamic.go b/internal/route/rules/vars_dynamic.go new file mode 100644 index 00000000..38b8152d --- /dev/null +++ b/internal/route/rules/vars_dynamic.go @@ -0,0 +1,81 @@ +package rules + +import ( + "net/http" + "net/url" + "strconv" +) + +var ( + VarHeader = "header" + VarResponseHeader = "resp_header" + VarQuery = "arg" + VarForm = "form" + VarPostForm = "postform" +) + +type dynamicVarGetter func(args []string, w *ResponseModifier, req *http.Request) (string, error) + +var dynamicVarSubsMap = map[string]dynamicVarGetter{ + VarHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + return getValueByKeyAtIndex(req.Header, key, index) + }, + VarResponseHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + return getValueByKeyAtIndex(w.Header(), key, index) + }, + VarQuery: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + return getValueByKeyAtIndex(GetSharedData(w).GetQueries(req), key, index) + }, + VarForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + return getValueByKeyAtIndex(req.Form, key, index) + }, + VarPostForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) { + key, index, err := getKeyAndIndex(args) + if err != nil { + return "", err + } + return getValueByKeyAtIndex(req.PostForm, key, index) + }, +} + +func getValueByKeyAtIndex[Values http.Header | url.Values](values Values, key string, index int) (string, error) { + // NOTE: do not use Header.Get or http.CanonicalHeaderKey here, respect to user input + if values, ok := values[key]; ok && index < len(values) { + return values[index], nil + } + // ignore unknown header or index out of range + return "", nil +} + +func getKeyAndIndex(args []string) (key string, index int, err error) { + switch len(args) { + case 0: + return "", 0, ErrExpectNoArg + case 1: + return args[0], 0, nil + case 2: + index, err = strconv.Atoi(args[1]) + if err != nil { + return "", 0, ErrInvalidArguments.Withf("invalid index %q", args[1]) + } + return args[0], index, nil + default: + return "", 0, ErrExpectOneOrTwoArgs + } +} diff --git a/internal/route/rules/vars_static.go b/internal/route/rules/vars_static.go new file mode 100644 index 00000000..38b8d22b --- /dev/null +++ b/internal/route/rules/vars_static.go @@ -0,0 +1,92 @@ +package rules + +import ( + "net" + "net/http" + "strconv" + + "github.com/yusing/godoxy/internal/route/routes" +) + +const ( + VarRequestMethod = "req_method" + VarRequestScheme = "req_scheme" + VarRequestHost = "req_host" + VarRequestPort = "req_port" + VarRequestPath = "req_path" + VarRequestAddr = "req_addr" + VarRequestQuery = "req_query" + VarRequestURL = "req_url" + VarRequestURI = "req_uri" + VarRequestContentType = "req_content_type" + VarRequestContentLen = "req_content_length" + VarRemoteHost = "remote_host" + VarRemotePort = "remote_port" + VarRemoteAddr = "remote_addr" + + VarUpstreamName = "upstream_name" + VarUpstreamScheme = "upstream_scheme" + VarUpstreamHost = "upstream_host" + VarUpstreamPort = "upstream_port" + VarUpstreamAddr = "upstream_addr" + VarUpstreamURL = "upstream_url" + + VarRespContentType = "resp_content_type" + VarRespContentLen = "resp_content_length" + VarRespStatusCode = "status_code" +) + +var staticReqVarSubsMap = map[string]reqVarGetter{ + VarRequestMethod: func(req *http.Request) string { return req.Method }, + VarRequestScheme: func(req *http.Request) string { + if req.TLS != nil { + return "https" + } + return "http" + }, + VarRequestHost: func(req *http.Request) string { + reqHost, _, err := net.SplitHostPort(req.Host) + if err != nil { + return req.Host + } + return reqHost + }, + VarRequestPort: func(req *http.Request) string { + _, reqPort, _ := net.SplitHostPort(req.Host) + return reqPort + }, + VarRequestAddr: func(req *http.Request) string { return req.Host }, + VarRequestPath: func(req *http.Request) string { return req.URL.Path }, + VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery }, + VarRequestURL: func(req *http.Request) string { return req.URL.String() }, + VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() }, + VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") }, + VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) }, + VarRemoteHost: func(req *http.Request) string { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + return clientIP + } + return "" + }, + VarRemotePort: func(req *http.Request) string { + _, clientPort, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + return clientPort + } + return "" + }, + VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr }, + VarUpstreamName: routes.TryGetUpstreamName, + VarUpstreamScheme: routes.TryGetUpstreamScheme, + VarUpstreamHost: routes.TryGetUpstreamHost, + VarUpstreamPort: routes.TryGetUpstreamPort, + VarUpstreamAddr: routes.TryGetUpstreamAddr, + VarUpstreamURL: routes.TryGetUpstreamURL, +} + +var staticRespVarSubsMap = map[string]respVarGetter{ + VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") }, + VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) }, + VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) }, +} diff --git a/internal/route/rules/vars_test.go b/internal/route/rules/vars_test.go new file mode 100644 index 00000000..b6355daa --- /dev/null +++ b/internal/route/rules/vars_test.go @@ -0,0 +1,672 @@ +package rules + +import ( + "bytes" + "crypto/tls" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExtractArgs(t *testing.T) { + tests := []struct { + name string + src string + startPos int + funcName string + wantArgs []string + wantNextIdx int + wantErr bool + }{ + { + name: "unquoted single arg", + src: "header(X-Some-Header)", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header"}, + wantNextIdx: 20, + }, + { + name: "double quoted arg", + src: `header("X-Some-Header")`, + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header"}, + wantNextIdx: 22, + }, + { + name: "single quoted arg", + src: "header('X-Some-Header')", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header"}, + wantNextIdx: 22, + }, + { + name: "backtick quoted arg", + src: "header(`X-Some-Header`)", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header"}, + wantNextIdx: 22, + }, + { + name: "two args with double quotes and unquoted", + src: `header("X-Some-Header", 1)`, + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header", "1"}, + wantNextIdx: 25, + }, + { + name: "two args with single and double quotes", + src: "header('X-Some-Header', \"1\")", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header", "1"}, + wantNextIdx: 27, + }, + { + name: "two args with backtick and single quotes", + src: "header(`X-Some-Header`, '1')", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Some-Header", "1"}, + wantNextIdx: 27, + }, + { + name: "quoted string with nested different quotes", + src: `arg("'(value)'")`, + startPos: 0, + funcName: "arg", + wantArgs: []string{"'(value)'"}, + wantNextIdx: 15, + }, + { + name: "quoted string with backticks inside double quotes", + src: "header(\"value`with`backticks\")", + startPos: 0, + funcName: "header", + wantArgs: []string{"value`with`backticks"}, + wantNextIdx: 29, + }, + { + name: "multiple args with whitespace", + src: "header( \"X-Header\" , 2 )", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Header", "2"}, + wantNextIdx: 27, + }, + { + name: "empty quoted string", + src: `header("")`, + startPos: 0, + funcName: "header", + wantArgs: []string{""}, + wantNextIdx: 9, + }, + { + name: "multiple empty args", + src: `header("", "")`, + startPos: 0, + funcName: "header", + wantArgs: []string{"", ""}, + wantNextIdx: 13, + }, + { + name: "unquoted args separated by comma", + src: "header(key1,key2,key3)", + startPos: 0, + funcName: "header", + wantArgs: []string{"key1", "key2", "key3"}, + wantNextIdx: 21, + }, + { + name: "trailing whitespace before closing paren", + src: `header("value" )`, + startPos: 0, + funcName: "header", + wantArgs: []string{"value"}, + wantNextIdx: 16, + }, + { + name: "startPos not at beginning", + src: "prefix_header(X-Header)", + startPos: 7, + funcName: "header", + wantArgs: []string{"X-Header"}, + wantNextIdx: 22, + }, + { + name: "special chars in unquoted arg", + src: "header(X-Custom_Header.v1)", + startPos: 0, + funcName: "header", + wantArgs: []string{"X-Custom_Header.v1"}, + wantNextIdx: 25, + }, + { + name: "unterminated quote", + src: `header("X-Header`, + startPos: 0, + funcName: "header", + wantErr: true, + }, + { + name: "missing closing parenthesis", + src: `header("X-Header"`, + startPos: 0, + funcName: "header", + wantErr: true, + }, + { + name: "no opening parenthesis", + src: `header"X-Header"`, + startPos: 0, + funcName: "header", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args, nextIdx, err := extractArgs(tt.src, tt.startPos, tt.funcName) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.wantArgs, args) + require.Equal(t, tt.wantNextIdx, nextIdx) + } + }) + } +} + +func TestExpandVars(t *testing.T) { + // Create a comprehensive test request + testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", nil) + testRequest.Header.Set("Content-Type", "application/json") + testRequest.Header.Set("User-Agent", "test-agent/1.0") + testRequest.Header.Set("X-Custom", "value1,value2") + testRequest.ContentLength = 12345 + testRequest.RemoteAddr = "192.168.1.100:54321" + + // Create response modifier with headers + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier.Header().Set("Content-Type", "text/html") + testResponseModifier.Header().Set("X-Custom-Resp", "resp-value") + testResponseModifier.WriteHeader(200) + // set content length to 9876 by writing 9876 'a' bytes + testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876)) + + tests := []struct { + name string + input string + want string + wantErr bool + }{ + // Basic request variables + { + name: "req_method", + input: "$req_method", + want: "POST", + }, + { + name: "req_path", + input: "$req_path", + want: "/api/users", + }, + { + name: "req_query", + input: "$req_query", + want: "param1=value1¶m2=value2", + }, + { + name: "req_url", + input: "$req_url", + want: "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", + }, + { + name: "req_uri", + input: "$req_uri", + want: "/api/users?param1=value1¶m2=value2", + }, + { + name: "req_host", + input: "$req_host", + want: "example.com", + }, + { + name: "req_port", + input: "$req_port", + want: "8080", + }, + { + name: "req_addr", + input: "$req_addr", + want: "example.com:8080", + }, + { + name: "req_content_type", + input: "$req_content_type", + want: "application/json", + }, + { + name: "req_content_length", + input: "$req_content_length", + want: "12345", + }, + { + name: "remote_host", + input: "$remote_host", + want: "192.168.1.100", + }, + { + name: "remote_port", + input: "$remote_port", + want: "54321", + }, + { + name: "remote_addr", + input: "$remote_addr", + want: "192.168.1.100:54321", + }, + // Response variables + { + name: "status_code", + input: "$status_code", + want: "200", + }, + { + name: "resp_content_type", + input: "$resp_content_type", + want: "text/html", + }, + { + name: "resp_content_length", + input: "$resp_content_length", + want: "9876", + }, + // Function-like variables - header + { + name: "header single value", + input: "$header(User-Agent)", + want: "test-agent/1.0", + }, + { + name: "header with index 0", + input: "$header(X-Custom, 0)", + want: "value1", + }, + { + name: "header with index 1", + input: "$header(X-Custom, 1)", + want: "value2", + }, + { + name: "header not found", + input: "$header(X-Not-Found)", + want: "", + }, + { + name: "header index out of range", + input: "$header(X-Custom, 99)", + want: "", + }, + // Function-like variables - resp_header + { + name: "resp_header single value", + input: "$resp_header(Content-Type)", + want: "text/html", + }, + { + name: "resp_header custom header", + input: "$resp_header(X-Custom-Resp)", + want: "resp-value", + }, + { + name: "resp_header not found", + input: "$resp_header(X-Not-Found)", + want: "", + }, + // Function-like variables - arg (query parameters) + { + name: "arg single parameter", + input: "$arg(param1)", + want: "value1", + }, + { + name: "arg second parameter", + input: "$arg(param2)", + want: "value2", + }, + { + name: "arg not found", + input: "$arg(param3)", + want: "", + }, + // Mixed variables + { + name: "mixed variables", + input: "$req_method $req_path $status_code", + want: "POST /api/users 200", + }, + { + name: "variables with text", + input: "Method: $req_method, Path: $req_path", + want: "Method: POST, Path: /api/users", + }, + { + name: "function variables with text", + input: "Header: $header(User-Agent), Status: $status_code", + want: "Header: test-agent/1.0, Status: 200", + }, + // Escaped dollar signs + { + name: "escaped dollar", + input: "$$req_method", + want: "$req_method", + }, + { + name: "mixed escaped and unescaped", + input: "$$req_method $req_path", + want: "$req_method /api/users", + }, + // Environment variable syntax ${} + { + name: "env var syntax", + input: "${VAR}", + want: "${VAR}", + }, + // Error cases + { + name: "unknown variable", + input: "$unknown_var", + wantErr: true, + }, + { + name: "invalid function syntax", + input: "$arg(param1", + wantErr: true, + }, + { + name: "incomplete dollar", + input: "test$", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, tt.input, &out) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want, out.String()) + } + }) + } +} + +func TestExpandVars_Integration(t *testing.T) { + t.Run("complex log format", func(t *testing.T) { + testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil) + testRequest.Header.Set("User-Agent", "curl/7.68.0") + testRequest.RemoteAddr = "10.0.0.1:54321" + + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier.WriteHeader(200) + + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, + "$req_method $req_url $status_code User-Agent=$header(User-Agent)", + &out) + + require.NoError(t, err) + require.Equal(t, "GET https://api.example.com/users/123?sort=asc 200 User-Agent=curl/7.68.0", out.String()) + }) + + t.Run("with query parameters", func(t *testing.T) { + testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil) + + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, + "Query: $arg(q), Page: $arg(page)", + &out) + + require.NoError(t, err) + require.Equal(t, "Query: test, Page: 1", out.String()) + }) + + t.Run("response headers", func(t *testing.T) { + testRequest := httptest.NewRequest("GET", "/", nil) + + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + testResponseModifier.Header().Set("Cache-Control", "no-cache") + testResponseModifier.Header().Set("X-Rate-Limit", "100") + testResponseModifier.WriteHeader(200) + + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, + "Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)", + &out) + + require.NoError(t, err) + require.Equal(t, "Status: 200, Cache: no-cache, Limit: 100", out.String()) + }) +} + +func TestExpandVars_RequestSchemes(t *testing.T) { + tests := []struct { + name string + request *http.Request + expected string + }{ + { + name: "http scheme", + request: httptest.NewRequest("GET", "http://example.com/", nil), + expected: "http", + }, + { + name: "https scheme", + request: &http.Request{ + Method: "GET", + URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"}, + TLS: &tls.ConnectionState{}, // Simulate TLS connection + }, + expected: "https", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + var out strings.Builder + err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out) + require.NoError(t, err) + require.Equal(t, tt.expected, out.String()) + }) + } +} + +func TestExpandVars_UpstreamVariables(t *testing.T) { + // Upstream variables require context from routes package + testRequest := httptest.NewRequest("GET", "/", nil) + + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + + // Test that upstream variables don't cause errors even when not set + upstreamVars := []string{ + "$upstream_name", + "$upstream_scheme", + "$upstream_host", + "$upstream_port", + "$upstream_addr", + "$upstream_url", + } + + for _, varExpr := range upstreamVars { + t.Run(varExpr, func(t *testing.T) { + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, varExpr, &out) + // Should not error, may return empty string + require.NoError(t, err) + }) + } +} + +func TestExpandVars_NoHostPort(t *testing.T) { + // Test request without port in Host header + testRequest := httptest.NewRequest("GET", "/", nil) + testRequest.Host = "example.com" // No port + + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + + t.Run("req_host without port", func(t *testing.T) { + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out) + require.NoError(t, err) + require.Equal(t, "example.com", out.String()) + }) + + t.Run("req_port without port", func(t *testing.T) { + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out) + require.NoError(t, err) + require.Equal(t, "", out.String()) + }) +} + +func TestExpandVars_NoRemotePort(t *testing.T) { + // Test request without port in RemoteAddr + testRequest := httptest.NewRequest("GET", "/", nil) + testRequest.RemoteAddr = "192.168.1.1" // No port + + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + + t.Run("remote_host without port", func(t *testing.T) { + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out) + require.NoError(t, err) + require.Equal(t, "", out.String()) + }) + + t.Run("remote_port without port", func(t *testing.T) { + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out) + require.NoError(t, err) + require.Equal(t, "", out.String()) + }) +} + +func TestExpandVars_WhitespaceHandling(t *testing.T) { + testRequest := httptest.NewRequest("GET", "/test", nil) + testResponseModifier := NewResponseModifier(httptest.NewRecorder()) + + var out strings.Builder + err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out) + require.NoError(t, err) + require.Equal(t, "GET /test", out.String()) +} + +func TestValidateVars(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "valid simple variable", + input: "$req_method", + }, + { + name: "valid function variable", + input: "$header(User-Agent)", + }, + { + name: "valid response variable", + input: "$status_code", + }, + { + name: "invalid variable", + input: "$unknown_var", + wantErr: true, + }, + { + name: "incomplete variable", + input: "test$", + wantErr: true, + }, + { + name: "valid variables with text", + input: "Method: $req_method", + }, + { + name: "valid escaped dollar", + input: "$$req_method", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateVars(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestNeedExpandVars(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "contains variable", + input: "$req_method", + want: true, + }, + { + name: "contains function variable", + input: "$header(X-Test)", + want: true, + }, + { + name: "no variable", + input: "plain text", + want: false, + }, + { + name: "escaped dollar", + input: "$$req_method", + want: true, + }, + { + name: "mixed content", + input: "Method: $req_method", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NeedExpandVars(tt.input) + require.Equal(t, tt.want, got) + }) + } +}