feat(rules): replace go templates with custom variable expansion

- Replace template syntax ({{ .Request.Method }}) with $-prefixed variables ($req_method)
- Implement custom variable parser with static ($req_method, $status_code) and dynamic ($header(), $arg(), $form()) variables
- Replace templateOrStr interface with templateString struct and ExpandVars methods
- Add parser improvements for reliable quote handling
- Add new error types: ErrUnterminatedParenthesis, ErrUnexpectedVar, ErrExpectOneOrTwoArgs
- Update all tests and help text to use new variable syntax
- Add comprehensive unit and benchmark tests for variable expansion
This commit is contained in:
yusing
2025-10-25 22:43:47 +08:00
parent 9c3346dd9d
commit 1ec2872f3d
16 changed files with 1253 additions and 161 deletions

View File

@@ -334,7 +334,7 @@ var commands = map[string]struct {
helpListItem("Response", "the response object"),
"",
"Example:",
helpExample(CommandLog, "info", "/dev/stdout", "{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"),
helpExample(CommandLog, "info", "/dev/stdout", "$req_method $req_url $status_code"),
),
args: map[string]string{
"level": "the log level",
@@ -372,7 +372,7 @@ var commands = map[string]struct {
logger = f
}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
err := executeReqRespTemplateTo(tmpl, logger, w, r)
err := tmpl.ExpandVars(w, r, logger)
if err != nil {
return err
}
@@ -390,7 +390,7 @@ var commands = map[string]struct {
helpListItem("Response", "the response object"),
"",
"Example:",
helpExample(CommandNotify, "info", "ntfy", "Received request to {{ .Request.URL }}", "{{ .Request.Method }} {{ .Response.StatusCode }}"),
helpExample(CommandNotify, "info", "ntfy", "Received request to $req_url", "$req_method $status_code"),
),
args: map[string]string{
"level": "the log level",
@@ -432,12 +432,12 @@ var commands = map[string]struct {
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r)
err := titleTmpl.ExpandVars(w, r, respBuf)
if err != nil {
return err
}
titleLen := respBuf.Len()
err = executeReqRespTemplateTo(bodyTmpl, respBuf, w, r)
err = bodyTmpl.ExpandVars(w, r, respBuf)
if err != nil {
return err
}
@@ -455,16 +455,8 @@ var commands = map[string]struct {
},
}
type reqResponseTemplateData struct {
Request *http.Request
Response struct {
StatusCode int
Header http.Header
}
}
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr]
type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr]
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error {

View File

@@ -54,7 +54,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
err = parseRules(fmt.Sprintf(`
- name: log-request-response
do: |
log info %q '{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }} {{ index (index .Response.Header "Content-Type") 0 }}'
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
`, tempFile.Name()), &rules)
require.NoError(t, err)
@@ -84,10 +84,10 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
err := parseRules(`
- name: log-stdout
do: |
log info /dev/stdout "stdout: {{ .Request.Method }} {{ .Response.StatusCode }}"
log info /dev/stdout "stdout: $req_method $status_code"
- name: log-stderr
do: |
log error /dev/stderr "stderr: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
log error /dev/stderr "stderr: $req_path $status_code"
`, &rules)
require.NoError(t, err)
@@ -126,13 +126,13 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
err = parseRules(fmt.Sprintf(`
- name: log-info
do: |
log info %s "INFO: {{ .Request.Method }} {{ .Response.StatusCode }}"
log info %s "INFO: $req_method $status_code"
- name: log-warn
do: |
log warn %s "WARN: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
log warn %s "WARN: $req_path $status_code"
- name: log-error
do: |
log error %s "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
log error %s "ERROR: $req_method $req_path $status_code"
`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules)
require.NoError(t, err)
@@ -177,7 +177,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
err = parseRules(fmt.Sprintf(`
- name: log-with-templates
do: |
log info %s 'Request: {{ .Request.Method }} {{ .Request.URL }} Host: {{ .Request.Host }} User-Agent: {{ index .Request.Header "User-Agent" 0 }} Response: {{ .Response.StatusCode }} Custom-Header: {{ index .Response.Header "X-Custom-Header" 0 }} Content-Length: {{ index .Response.Header "Content-Length" 0 }}'
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
`, tempFile.Name()), &rules)
require.NoError(t, err)
@@ -231,11 +231,11 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
- name: log-success
on: status 2xx
do: |
log info %q "SUCCESS: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
log info %q "SUCCESS: $req_method $req_path $status_code"
- name: log-error
on: status 4xx | status 5xx
do: |
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
log error %q "ERROR: $req_method $req_path $status_code"
`, successFile.Name(), errorFile.Name()), &rules)
require.NoError(t, err)
@@ -288,7 +288,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
err = parseRules(fmt.Sprintf(`
- name: log-multiple
do: |
log info %q "{{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"`, tempFile.Name()), &rules)
log info %q "$req_method $req_path $status_code"`, tempFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
@@ -339,7 +339,7 @@ func TestLogCommand_FilePermissions(t *testing.T) {
var rules Rules
err = parseRules(fmt.Sprintf(`
- on: status 2xx
do: log info %q "{{ .Request.Method }} {{ .Response.StatusCode }}"`, logFilePath), &rules)
do: log info %q "$req_method $status_code"`, logFilePath), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
@@ -374,27 +374,12 @@ func TestLogCommand_FilePermissions(t *testing.T) {
}
func TestLogCommand_InvalidTemplate(t *testing.T) {
upstream := mockUpstream(200, "success")
var rules Rules
// Test with invalid template syntax
err := parseRules(`
- name: log-invalid
do: |
log info /dev/stdout "{{ .Invalid.Field }}"`, &rules)
// Should not error during parsing, but template execution will fail gracefully
assert.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Should not panic
assert.NotPanics(t, func() {
handler.ServeHTTP(w, req)
})
assert.Equal(t, 200, w.Code)
log info /dev/stdout "$invalid_var"`, &rules)
assert.ErrorIs(t, err, ErrUnexpectedVar)
}

View File

@@ -54,7 +54,7 @@ var modFields = map[string]struct {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -62,7 +62,7 @@ var modFields = map[string]struct {
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -89,7 +89,7 @@ var modFields = map[string]struct {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -97,7 +97,7 @@ var modFields = map[string]struct {
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -124,7 +124,7 @@ var modFields = map[string]struct {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -134,7 +134,7 @@ var modFields = map[string]struct {
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -165,7 +165,7 @@ var modFields = map[string]struct {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -181,7 +181,7 @@ var modFields = map[string]struct {
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
v, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
@@ -221,7 +221,7 @@ var modFields = map[string]struct {
helpListItem("Request", "the request object"),
"",
"Example:",
helpExample(FieldBody, "HTTP STATUS: {{ .Request.Method }} {{ .Request.URL.Path }}"),
helpExample(FieldBody, "HTTP STATUS: $req_method $req_path"),
),
args: map[string]string{
"template": "the body template",
@@ -234,7 +234,7 @@ var modFields = map[string]struct {
return validateTemplate(args[0], true)
},
builder: func(args any) *FieldHandler {
tmpl := args.(templateOrStr)
tmpl := args.(templateString)
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
if r.Body != nil {
@@ -244,7 +244,7 @@ var modFields = map[string]struct {
bufPool := GetInitResponseModifier(w).BufPool()
b := bufPool.GetBuffer()
err := executeRequestTemplateTo(tmpl, b, r)
err := tmpl.ExpandVars(w, r, b)
if err != nil {
return err
}
@@ -266,7 +266,7 @@ var modFields = map[string]struct {
helpListItem("Response", "the response object"),
"",
"Example:",
helpExample(FieldResponseBody, "HTTP STATUS: {{ .Request.Method }} {{ .Response.StatusCode }}"),
helpExample(FieldResponseBody, "HTTP STATUS: $req_method $status_code"),
),
args: map[string]string{
"template": "the response body template",
@@ -279,12 +279,12 @@ var modFields = map[string]struct {
return validateTemplate(args[0], true)
},
builder: func(args any) *FieldHandler {
tmpl := args.(templateOrStr)
tmpl := args.(templateString)
return &FieldHandler{
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
rm := GetInitResponseModifier(w)
rm.ResetBody()
return executeReqRespTemplateTo(tmpl, rm, rm, r)
return tmpl.ExpandVars(w, r, rm)
}),
}
},

View File

@@ -367,7 +367,7 @@ func TestFieldHandler_Body(t *testing.T) {
}{
{
name: "set body with template",
template: "Hello {{ .Request.Method }} {{ .Request.URL.Path }}",
template: "Hello $req_method $req_path",
setup: func(r *http.Request) {
r.Method = "POST"
r.URL.Path = "/test"
@@ -424,7 +424,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
}{
{
name: "set response body with template",
template: "Response: {{ .Request.Method }} {{ .Request.URL.Path }}",
template: "Response: $req_method $req_path",
setup: func(r *http.Request) {
r.Method = "GET"
r.URL.Path = "/api/test"
@@ -552,19 +552,19 @@ func TestFieldValidation(t *testing.T) {
{
name: "body valid template",
field: FieldBody,
args: []string{"Hello {{ .Request.Method }}"},
args: []string{"Hello $req_method"},
wantError: false,
},
{
name: "body invalid template syntax",
field: FieldBody,
args: []string{"Hello {{ .InvalidField "},
args: []string{"Hello $invalid_field"},
wantError: true,
},
{
name: "response body valid template",
field: FieldResponseBody,
args: []string{"Response: {{ .Request.Method }}"},
args: []string{"Response: $req_method"},
wantError: false,
},
{

View File

@@ -5,18 +5,25 @@ import (
)
var (
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
ErrUnterminatedEnvVar = gperr.New("unterminated env var")
ErrUnknownDirective = gperr.New("unknown directive")
ErrUnknownModField = gperr.New("unknown field")
ErrEnvVarNotFound = gperr.New("env variable not found")
ErrInvalidArguments = gperr.New("invalid arguments")
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
ErrUnterminatedParenthesis = gperr.New("unterminated parenthesis")
ErrUnterminatedEnvVar = gperr.New("unterminated env var")
ErrUnknownDirective = gperr.New("unknown directive")
ErrUnknownModField = gperr.New("unknown field")
ErrEnvVarNotFound = gperr.New("env variable not found")
ErrInvalidArguments = gperr.New("invalid arguments")
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
// vars errors
ErrNoArgProvided = gperr.New("no argument provided")
ErrUnexpectedVar = gperr.New("unexpected variable")
ErrUnexpectedQuote = gperr.New("unexpected quote")
ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg")
ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg")
ErrExpectOneOrTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 1 or 2 args")
ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args")
ErrExpectTwoOrThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 or 3 args")
ErrExpectThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 3 args")

View File

@@ -28,11 +28,11 @@ func helpExample(cmd string, args ...string) string {
var out strings.Builder
pos := 0
for {
start := strings.Index(arg[pos:], "{{")
start := strings.IndexByte(arg[pos:], '$')
if start == -1 {
if pos < len(arg) {
// If no template at all (pos == 0), cyan highlight for whole-arg
// Otherwise, for mixed strings containing templates, leave non-template text unhighlighted
// If no variable at all (pos == 0), cyan highlight for whole-arg
// Otherwise, for mixed strings containing variables, leave non-variable text unhighlighted
if pos == 0 {
out.WriteString(ansi.WithANSI(arg[pos:], ansi.HighlightCyan))
} else {
@@ -43,20 +43,31 @@ func helpExample(cmd string, args ...string) string {
}
start += pos
if start > pos {
// Non-template text should not be highlighted
// Non-variable text should not be highlighted
out.WriteString(arg[pos:start])
}
end := strings.Index(arg[start+2:], "}}")
if end == -1 {
// Unmatched template start; write remainder without highlighting
out.WriteString(arg[start:])
break
// Parse variable name and optional function call
end := start + 1
for end < len(arg) && (arg[end] == '_' || (arg[end] >= 'a' && arg[end] <= 'z') || (arg[end] >= 'A' && arg[end] <= 'Z') || (arg[end] >= '0' && arg[end] <= '9')) {
end++
}
end += start + 2
inner := strings.TrimSpace(arg[start+2 : end])
parts := strings.Split(inner, ".")
out.WriteString(helpTemplateVar(parts...))
pos = end + 2
// Check for function call
if end < len(arg) && arg[end] == '(' {
parenCount := 1
end++
for end < len(arg) && parenCount > 0 {
switch arg[end] {
case '(':
parenCount++
case ')':
parenCount--
}
end++
}
}
varExpr := arg[start:end]
out.WriteString(helpVar(varExpr))
pos = end
}
fmt.Fprintf(&sb, ` "%s"`, out.String())
}
@@ -87,17 +98,29 @@ func helpFuncCall(fn string, args ...string) string {
return sb.String()
}
// helpTemplateVar generates a string like "{{ .Request.Method }} {{ .Request.URL.Path }}"
func helpTemplateVar(parts ...string) string {
var sb strings.Builder
sb.WriteString(ansi.WithANSI("{{ ", ansi.HighlightWhite))
for i, part := range parts {
sb.WriteString(ansi.WithANSI(part, ansi.HighlightCyan))
if i < len(parts)-1 {
sb.WriteString(".")
}
// helpVar generates a highlighted string for a variable like "$req_method" or "$header(X-Test)"
func helpVar(varExpr string) string {
if !strings.HasPrefix(varExpr, "$") {
return varExpr
}
sb.WriteString(ansi.WithANSI(" }}", ansi.HighlightWhite))
// Check if it's a function call
parenIdx := strings.IndexByte(varExpr, '(')
if parenIdx == -1 {
// Simple variable like "$req_method"
return ansi.WithANSI(varExpr, ansi.HighlightCyan)
}
// Function call like "$header(X-Test)"
var sb strings.Builder
sb.WriteString(ansi.WithANSI(varExpr[:parenIdx], ansi.HighlightCyan))
sb.WriteString(ansi.WithANSI("(", ansi.HighlightWhite))
// Extract and highlight the arguments
argsStr := varExpr[parenIdx+1 : len(varExpr)-1]
sb.WriteString(ansi.WithANSI(argsStr, ansi.HighlightYellow))
sb.WriteString(ansi.WithANSI(")", ansi.HighlightWhite))
return sb.String()
}

View File

@@ -218,7 +218,7 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
err = parseRules(fmt.Sprintf(`
- name: log-response
on: path /test
do: log info %s "{{ .Request.Method }} {{ .Response.StatusCode }}"
do: log info %s "$req_method $status_code"
`, tempFile.Name()), &rules)
require.NoError(t, err)
@@ -261,7 +261,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
err = parseRules(fmt.Sprintf(`
- name: log-errors
on: status 4xx
do: log error %s "{{ .Request.URL }} returned {{ .Response.StatusCode }}"
do: log error %s "$req_url returned $status_code"
`, tempFile.Name()), &rules)
require.NoError(t, err)
@@ -364,11 +364,11 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
do: require_basic_auth "Protected Area"
- name: log-all-requests
do: |
log info %q "{{ .Request.Method }} {{ .Request.URL }} -> {{ .Response.StatusCode }}"
log info %q "$req_method $req_url -> $status_code"
- name: log-errors
on: status 4xx
do: |
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"
log error %q "ERROR: $req_method $req_url $status_code"
`, logFile.Name(), errorLogFile.Name()), &rules)
require.NoError(t, err)
@@ -610,10 +610,10 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
err := parseRules(`
- name: process-form
on: form username
do: set resp_header X-Username "{{ index .Request.Form.username 0 }}"
do: set resp_header X-Username "$form(username)"
- name: process-postform
on: postform email
do: set resp_header X-Email "{{ index .Request.PostForm.email 0 }}"
do: set resp_header X-Email "$postform(email)"
`, &rules)
require.NoError(t, err)
@@ -923,7 +923,7 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) {
- name: modify-response
do: |
set resp_header X-Modified "true"
set resp_body "Modified: {{ .Request.Method }} {{ .Request.URL.Path }}"
set resp_body "Modified: $req_method $req_path"
`, &rules)
require.NoError(t, err)

View File

@@ -591,7 +591,7 @@ func parseOn(line string) (Checker, bool, gperr.Error) {
validArgs, err := checker.validate(args)
if err != nil {
return nil, false, err.Subject(subject).With(checker.help.Error())
return nil, false, err.With(checker.help.Error())
}
checkFunc := checker.builder(validArgs)

View File

@@ -19,6 +19,12 @@ var escapedChars = map[rune]rune{
' ': ' ',
}
var quoteChars = [256]bool{
'"': true,
'\'': true,
'`': true,
}
// parse expression to subject and args
// with support for quotes, escaped chars, and env substitution, e.g.
//
@@ -74,6 +80,19 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
buf.WriteRune('$')
expectingBrace = false
}
if quoteChars[r] {
switch {
case quote == 0 && brackets == 0:
quote = r
flush(false)
case r == quote:
quote = 0
flush(true)
default:
buf.WriteRune(r)
}
continue
}
switch r {
case '\\':
escaped = true
@@ -106,17 +125,6 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
} else {
buf.WriteRune(r)
}
case '"', '\'', '`':
switch {
case quote == 0 && brackets == 0:
quote = r
flush(false)
case r == quote:
quote = 0
flush(true)
default:
buf.WriteRune(r)
}
case '(':
brackets++
buf.WriteRune(r)

View File

@@ -1,48 +1,52 @@
package rules
import (
"bytes"
"io"
"net/http"
"strings"
"unsafe"
)
type templateOrStr interface {
Execute(w io.Writer, data any) error
Len() int
type templateString struct {
string
isTemplate bool
}
type strTemplate string
type keyValueTemplate struct {
key string
tmpl templateString
}
func (t strTemplate) Execute(w io.Writer, _ any) error {
n, err := w.Write([]byte(t))
if err != nil {
func (tmpl *keyValueTemplate) Unpack() (string, templateString) {
return tmpl.key, tmpl.tmpl
}
func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error {
if !tmpl.isTemplate {
_, err := dstW.Write(strtobNoCopy(tmpl.string))
return err
}
if n != len(t) {
return io.ErrShortWrite
return ExpandVars(GetInitResponseModifier(w), req, tmpl.string, dstW)
}
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
if !tmpl.isTemplate {
return tmpl.string, nil
}
return nil
}
func (t strTemplate) Len() int {
return len(t)
}
type keyValueTemplate = Tuple[string, templateOrStr]
func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) {
var buf bytes.Buffer
err := tmpl.Execute(&buf, reqResponseTemplateData{Request: r})
var buf strings.Builder
err := ExpandVars(GetInitResponseModifier(w), req, tmpl.string, &buf)
if err != nil {
return "", err
}
return buf.String(), nil
}
func executeRequestTemplateTo(tmpl templateOrStr, o io.Writer, r *http.Request) error {
return tmpl.Execute(o, reqResponseTemplateData{Request: r})
func (tmpl *templateString) Len() int {
return len(tmpl.string)
}
func executeReqRespTemplateTo(tmpl templateOrStr, o io.Writer, w http.ResponseWriter, r *http.Request) error {
return tmpl.Execute(o, reqResponseTemplateData{Request: r, Response: GetInitResponseModifier(w).Response()})
func strtobNoCopy(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}

View File

@@ -8,7 +8,6 @@ import (
"path/filepath"
"strconv"
"strings"
"text/template"
"github.com/rs/zerolog"
nettypes "github.com/yusing/godoxy/internal/net/types"
@@ -91,11 +90,11 @@ func toKeyValueTemplate(args []string) (any, gperr.Error) {
return nil, ErrExpectTwoArgs
}
tmpl, err := validateTemplate(args[1], false)
isTemplate, err := validateTemplate(args[1], false)
if err != nil {
return nil, err
}
return &keyValueTemplate{args[0], tmpl}, nil
return &keyValueTemplate{args[0], isTemplate}, nil
}
// validateURL returns types.URL with the URL validated.
@@ -300,33 +299,20 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E
return set, nil
}
func isTemplate(tmplStr string) bool {
return strings.Contains(tmplStr, "{{")
}
type templateWithLen struct {
*template.Template
len int
}
func (t *templateWithLen) Len() int {
return t.len
}
func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) {
func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error) {
if newline && !strings.HasSuffix(tmplStr, "\n") {
tmplStr += "\n"
}
if !isTemplate(tmplStr) {
return strTemplate(tmplStr), nil
if !NeedExpandVars(tmplStr) {
return templateString{tmplStr, false}, nil
}
tmpl, err := template.New("template").Parse(tmplStr)
err := ValidateVars(tmplStr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
return templateString{}, gperr.Wrap(err)
}
return &templateWithLen{tmpl, len(tmplStr)}, nil
return templateString{tmplStr, true}, nil
}
func validateLevel(level string) (zerolog.Level, gperr.Error) {

View File

@@ -0,0 +1,28 @@
package rules
import (
"io"
"net/http/httptest"
"net/url"
"testing"
)
func BenchmarkExpandVars(b *testing.B) {
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
testResponseModifier.WriteHeader(200)
testResponseModifier.Write([]byte("Hello, world!"))
testRequest := httptest.NewRequest("GET", "/", nil)
testRequest.Header.Set("User-Agent", "test-agent/1.0")
testRequest.Header.Set("X-Custom", "value1,value2")
testRequest.ContentLength = 12345
testRequest.RemoteAddr = "192.168.1.100:54321"
testRequest.Form = url.Values{"param1": {"value1"}, "param2": {"value2"}}
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
for b.Loop() {
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
if err != nil {
b.Fatal(err)
}
}
}

View File

@@ -0,0 +1,214 @@
package rules
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
ioutils "github.com/yusing/goutils/io"
)
// TODO: remove middleware/vars.go and use this instead
type (
reqVarGetter func(*http.Request) string
respVarGetter func(*ResponseModifier) string
)
var reVar = regexp.MustCompile(`\$[\w_]+`)
var validVarNameCharset = func() (ret [256]bool) {
for c := byte('a'); c <= 'z'; c++ {
ret[c] = true
}
for c := byte('A'); c <= 'Z'; c++ {
ret[c] = true
}
ret['_'] = true
return
}()
func NeedExpandVars(s string) bool {
return reVar.MatchString(s)
}
var (
voidResponseModifier = NewResponseModifier(httptest.NewRecorder())
dummyRequest = http.Request{
Method: "GET",
URL: &url.URL{Path: "/"},
Header: http.Header{},
}
)
// ValidateVars validates the variables in the given string.
// It returns ErrUnexpectedVar if any invalid variable is found.
func ValidateVars(s string) error {
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
}
func ExpandVars(w *ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
dst := ioutils.NewBufferedWriter(dstW, 1024)
defer func() {
dst.Flush()
dst.Release()
}()
for i := 0; i < len(src); i++ {
ch := src[i]
if ch != '$' {
if err := dst.WriteByte(ch); err != nil {
return err
}
continue
}
// Look ahead
if i+1 >= len(src) {
return ErrUnterminatedEnvVar
}
j := i + 1
switch src[j] {
case '$': // $$ -> literal '$'
if err := dst.WriteByte('$'); err != nil {
return err
}
i = j
continue
case '{': // ${...} pass through as-is
if _, err := dst.WriteString("${"); err != nil {
return err
}
i = j // we've consumed the '{' too
continue
}
if validVarNameCharset[src[j]] {
k := j
for k < len(src) {
c := src[k]
if validVarNameCharset[c] {
k++
continue
}
break
}
name := src[j:k]
isStatic := true
var actual string
if getter, ok := dynamicVarSubsMap[name]; ok {
// Function-like variables
isStatic = false
args, nextIdx, err := extractArgs(src, j, name)
if err != nil {
return err
}
i = nextIdx
actual, err = getter(args, w, req)
if err != nil {
return err
}
} else if getter, ok := staticReqVarSubsMap[name]; ok {
actual = getter(req)
} else if getter, ok := staticRespVarSubsMap[name]; ok {
actual = getter(w)
} else {
return ErrUnexpectedVar.Subject(name)
}
if _, err := dst.WriteString(actual); err != nil {
return err
}
if isStatic {
i = k - 1
}
continue
}
// No valid construct after '$'
return ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
}
return nil
}
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
// Find opening parenthesis
parenIdx := strings.IndexByte(src[i:], '(')
if parenIdx == -1 {
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q at position %d", funcName, i)
}
parenIdx += i
var (
quote byte // current quote character (0 if not in quotes)
arg strings.Builder
)
nextIdx = parenIdx + 1
for nextIdx < len(src) {
ch := src[nextIdx]
if quote != 0 {
// We're inside a quoted string
if ch == quote {
// Closing quote - the content between quotes is now complete, add it
args = append(args, arg.String())
arg.Reset()
quote = 0
nextIdx++
continue
}
// Inside quotes - add everything as-is
arg.WriteByte(ch)
nextIdx++
continue
}
// Not inside quotes
if quoteChars[ch] {
// Opening quote
quote = ch
nextIdx++
continue
}
if ch == ')' {
// End of arguments
if arg.Len() > 0 {
args = append(args, arg.String())
}
return args, nextIdx, nil
}
if ch == ',' {
// Argument separator
if arg.Len() > 0 {
args = append(args, arg.String())
arg.Reset()
}
nextIdx++
continue
}
if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
// Whitespace outside quotes - skip
nextIdx++
continue
}
// Regular character - accumulate until we hit a delimiter
arg.WriteByte(ch)
nextIdx++
}
// Reached end of string without closing parenthesis
if quote != 0 {
return nil, 0, ErrUnterminatedQuotes.Withf("func %q", funcName)
}
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q", funcName)
}

View File

@@ -0,0 +1,81 @@
package rules
import (
"net/http"
"net/url"
"strconv"
)
var (
VarHeader = "header"
VarResponseHeader = "resp_header"
VarQuery = "arg"
VarForm = "form"
VarPostForm = "postform"
)
type dynamicVarGetter func(args []string, w *ResponseModifier, req *http.Request) (string, error)
var dynamicVarSubsMap = map[string]dynamicVarGetter{
VarHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(req.Header, key, index)
},
VarResponseHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(w.Header(), key, index)
},
VarQuery: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(GetSharedData(w).GetQueries(req), key, index)
},
VarForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(req.Form, key, index)
},
VarPostForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(req.PostForm, key, index)
},
}
func getValueByKeyAtIndex[Values http.Header | url.Values](values Values, key string, index int) (string, error) {
// NOTE: do not use Header.Get or http.CanonicalHeaderKey here, respect to user input
if values, ok := values[key]; ok && index < len(values) {
return values[index], nil
}
// ignore unknown header or index out of range
return "", nil
}
func getKeyAndIndex(args []string) (key string, index int, err error) {
switch len(args) {
case 0:
return "", 0, ErrExpectNoArg
case 1:
return args[0], 0, nil
case 2:
index, err = strconv.Atoi(args[1])
if err != nil {
return "", 0, ErrInvalidArguments.Withf("invalid index %q", args[1])
}
return args[0], index, nil
default:
return "", 0, ErrExpectOneOrTwoArgs
}
}

View File

@@ -0,0 +1,92 @@
package rules
import (
"net"
"net/http"
"strconv"
"github.com/yusing/godoxy/internal/route/routes"
)
const (
VarRequestMethod = "req_method"
VarRequestScheme = "req_scheme"
VarRequestHost = "req_host"
VarRequestPort = "req_port"
VarRequestPath = "req_path"
VarRequestAddr = "req_addr"
VarRequestQuery = "req_query"
VarRequestURL = "req_url"
VarRequestURI = "req_uri"
VarRequestContentType = "req_content_type"
VarRequestContentLen = "req_content_length"
VarRemoteHost = "remote_host"
VarRemotePort = "remote_port"
VarRemoteAddr = "remote_addr"
VarUpstreamName = "upstream_name"
VarUpstreamScheme = "upstream_scheme"
VarUpstreamHost = "upstream_host"
VarUpstreamPort = "upstream_port"
VarUpstreamAddr = "upstream_addr"
VarUpstreamURL = "upstream_url"
VarRespContentType = "resp_content_type"
VarRespContentLen = "resp_content_length"
VarRespStatusCode = "status_code"
)
var staticReqVarSubsMap = map[string]reqVarGetter{
VarRequestMethod: func(req *http.Request) string { return req.Method },
VarRequestScheme: func(req *http.Request) string {
if req.TLS != nil {
return "https"
}
return "http"
},
VarRequestHost: func(req *http.Request) string {
reqHost, _, err := net.SplitHostPort(req.Host)
if err != nil {
return req.Host
}
return reqHost
},
VarRequestPort: func(req *http.Request) string {
_, reqPort, _ := net.SplitHostPort(req.Host)
return reqPort
},
VarRequestAddr: func(req *http.Request) string { return req.Host },
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
VarRemoteHost: func(req *http.Request) string {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientIP
}
return ""
},
VarRemotePort: func(req *http.Request) string {
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return clientPort
}
return ""
},
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
VarUpstreamName: routes.TryGetUpstreamName,
VarUpstreamScheme: routes.TryGetUpstreamScheme,
VarUpstreamHost: routes.TryGetUpstreamHost,
VarUpstreamPort: routes.TryGetUpstreamPort,
VarUpstreamAddr: routes.TryGetUpstreamAddr,
VarUpstreamURL: routes.TryGetUpstreamURL,
}
var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") },
VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) },
VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
}

View File

@@ -0,0 +1,672 @@
package rules
import (
"bytes"
"crypto/tls"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestExtractArgs(t *testing.T) {
tests := []struct {
name string
src string
startPos int
funcName string
wantArgs []string
wantNextIdx int
wantErr bool
}{
{
name: "unquoted single arg",
src: "header(X-Some-Header)",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header"},
wantNextIdx: 20,
},
{
name: "double quoted arg",
src: `header("X-Some-Header")`,
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header"},
wantNextIdx: 22,
},
{
name: "single quoted arg",
src: "header('X-Some-Header')",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header"},
wantNextIdx: 22,
},
{
name: "backtick quoted arg",
src: "header(`X-Some-Header`)",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header"},
wantNextIdx: 22,
},
{
name: "two args with double quotes and unquoted",
src: `header("X-Some-Header", 1)`,
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header", "1"},
wantNextIdx: 25,
},
{
name: "two args with single and double quotes",
src: "header('X-Some-Header', \"1\")",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header", "1"},
wantNextIdx: 27,
},
{
name: "two args with backtick and single quotes",
src: "header(`X-Some-Header`, '1')",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Some-Header", "1"},
wantNextIdx: 27,
},
{
name: "quoted string with nested different quotes",
src: `arg("'(value)'")`,
startPos: 0,
funcName: "arg",
wantArgs: []string{"'(value)'"},
wantNextIdx: 15,
},
{
name: "quoted string with backticks inside double quotes",
src: "header(\"value`with`backticks\")",
startPos: 0,
funcName: "header",
wantArgs: []string{"value`with`backticks"},
wantNextIdx: 29,
},
{
name: "multiple args with whitespace",
src: "header( \"X-Header\" , 2 )",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Header", "2"},
wantNextIdx: 27,
},
{
name: "empty quoted string",
src: `header("")`,
startPos: 0,
funcName: "header",
wantArgs: []string{""},
wantNextIdx: 9,
},
{
name: "multiple empty args",
src: `header("", "")`,
startPos: 0,
funcName: "header",
wantArgs: []string{"", ""},
wantNextIdx: 13,
},
{
name: "unquoted args separated by comma",
src: "header(key1,key2,key3)",
startPos: 0,
funcName: "header",
wantArgs: []string{"key1", "key2", "key3"},
wantNextIdx: 21,
},
{
name: "trailing whitespace before closing paren",
src: `header("value" )`,
startPos: 0,
funcName: "header",
wantArgs: []string{"value"},
wantNextIdx: 16,
},
{
name: "startPos not at beginning",
src: "prefix_header(X-Header)",
startPos: 7,
funcName: "header",
wantArgs: []string{"X-Header"},
wantNextIdx: 22,
},
{
name: "special chars in unquoted arg",
src: "header(X-Custom_Header.v1)",
startPos: 0,
funcName: "header",
wantArgs: []string{"X-Custom_Header.v1"},
wantNextIdx: 25,
},
{
name: "unterminated quote",
src: `header("X-Header`,
startPos: 0,
funcName: "header",
wantErr: true,
},
{
name: "missing closing parenthesis",
src: `header("X-Header"`,
startPos: 0,
funcName: "header",
wantErr: true,
},
{
name: "no opening parenthesis",
src: `header"X-Header"`,
startPos: 0,
funcName: "header",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
args, nextIdx, err := extractArgs(tt.src, tt.startPos, tt.funcName)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.wantArgs, args)
require.Equal(t, tt.wantNextIdx, nextIdx)
}
})
}
}
func TestExpandVars(t *testing.T) {
// Create a comprehensive test request
testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1&param2=value2#fragment", nil)
testRequest.Header.Set("Content-Type", "application/json")
testRequest.Header.Set("User-Agent", "test-agent/1.0")
testRequest.Header.Set("X-Custom", "value1,value2")
testRequest.ContentLength = 12345
testRequest.RemoteAddr = "192.168.1.100:54321"
// Create response modifier with headers
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
testResponseModifier.Header().Set("Content-Type", "text/html")
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
testResponseModifier.WriteHeader(200)
// set content length to 9876 by writing 9876 'a' bytes
testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876))
tests := []struct {
name string
input string
want string
wantErr bool
}{
// Basic request variables
{
name: "req_method",
input: "$req_method",
want: "POST",
},
{
name: "req_path",
input: "$req_path",
want: "/api/users",
},
{
name: "req_query",
input: "$req_query",
want: "param1=value1&param2=value2",
},
{
name: "req_url",
input: "$req_url",
want: "https://example.com:8080/api/users?param1=value1&param2=value2#fragment",
},
{
name: "req_uri",
input: "$req_uri",
want: "/api/users?param1=value1&param2=value2",
},
{
name: "req_host",
input: "$req_host",
want: "example.com",
},
{
name: "req_port",
input: "$req_port",
want: "8080",
},
{
name: "req_addr",
input: "$req_addr",
want: "example.com:8080",
},
{
name: "req_content_type",
input: "$req_content_type",
want: "application/json",
},
{
name: "req_content_length",
input: "$req_content_length",
want: "12345",
},
{
name: "remote_host",
input: "$remote_host",
want: "192.168.1.100",
},
{
name: "remote_port",
input: "$remote_port",
want: "54321",
},
{
name: "remote_addr",
input: "$remote_addr",
want: "192.168.1.100:54321",
},
// Response variables
{
name: "status_code",
input: "$status_code",
want: "200",
},
{
name: "resp_content_type",
input: "$resp_content_type",
want: "text/html",
},
{
name: "resp_content_length",
input: "$resp_content_length",
want: "9876",
},
// Function-like variables - header
{
name: "header single value",
input: "$header(User-Agent)",
want: "test-agent/1.0",
},
{
name: "header with index 0",
input: "$header(X-Custom, 0)",
want: "value1",
},
{
name: "header with index 1",
input: "$header(X-Custom, 1)",
want: "value2",
},
{
name: "header not found",
input: "$header(X-Not-Found)",
want: "",
},
{
name: "header index out of range",
input: "$header(X-Custom, 99)",
want: "",
},
// Function-like variables - resp_header
{
name: "resp_header single value",
input: "$resp_header(Content-Type)",
want: "text/html",
},
{
name: "resp_header custom header",
input: "$resp_header(X-Custom-Resp)",
want: "resp-value",
},
{
name: "resp_header not found",
input: "$resp_header(X-Not-Found)",
want: "",
},
// Function-like variables - arg (query parameters)
{
name: "arg single parameter",
input: "$arg(param1)",
want: "value1",
},
{
name: "arg second parameter",
input: "$arg(param2)",
want: "value2",
},
{
name: "arg not found",
input: "$arg(param3)",
want: "",
},
// Mixed variables
{
name: "mixed variables",
input: "$req_method $req_path $status_code",
want: "POST /api/users 200",
},
{
name: "variables with text",
input: "Method: $req_method, Path: $req_path",
want: "Method: POST, Path: /api/users",
},
{
name: "function variables with text",
input: "Header: $header(User-Agent), Status: $status_code",
want: "Header: test-agent/1.0, Status: 200",
},
// Escaped dollar signs
{
name: "escaped dollar",
input: "$$req_method",
want: "$req_method",
},
{
name: "mixed escaped and unescaped",
input: "$$req_method $req_path",
want: "$req_method /api/users",
},
// Environment variable syntax ${}
{
name: "env var syntax",
input: "${VAR}",
want: "${VAR}",
},
// Error cases
{
name: "unknown variable",
input: "$unknown_var",
wantErr: true,
},
{
name: "invalid function syntax",
input: "$arg(param1",
wantErr: true,
},
{
name: "incomplete dollar",
input: "test$",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.want, out.String())
}
})
}
}
func TestExpandVars_Integration(t *testing.T) {
t.Run("complex log format", func(t *testing.T) {
testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil)
testRequest.Header.Set("User-Agent", "curl/7.68.0")
testRequest.RemoteAddr = "10.0.0.1:54321"
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
testResponseModifier.WriteHeader(200)
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest,
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
&out)
require.NoError(t, err)
require.Equal(t, "GET https://api.example.com/users/123?sort=asc 200 User-Agent=curl/7.68.0", out.String())
})
t.Run("with query parameters", func(t *testing.T) {
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil)
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest,
"Query: $arg(q), Page: $arg(page)",
&out)
require.NoError(t, err)
require.Equal(t, "Query: test, Page: 1", out.String())
})
t.Run("response headers", func(t *testing.T) {
testRequest := httptest.NewRequest("GET", "/", nil)
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
testResponseModifier.Header().Set("Cache-Control", "no-cache")
testResponseModifier.Header().Set("X-Rate-Limit", "100")
testResponseModifier.WriteHeader(200)
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest,
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
&out)
require.NoError(t, err)
require.Equal(t, "Status: 200, Cache: no-cache, Limit: 100", out.String())
})
}
func TestExpandVars_RequestSchemes(t *testing.T) {
tests := []struct {
name string
request *http.Request
expected string
}{
{
name: "http scheme",
request: httptest.NewRequest("GET", "http://example.com/", nil),
expected: "http",
},
{
name: "https scheme",
request: &http.Request{
Method: "GET",
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
TLS: &tls.ConnectionState{}, // Simulate TLS connection
},
expected: "https",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
var out strings.Builder
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
require.NoError(t, err)
require.Equal(t, tt.expected, out.String())
})
}
}
func TestExpandVars_UpstreamVariables(t *testing.T) {
// Upstream variables require context from routes package
testRequest := httptest.NewRequest("GET", "/", nil)
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
// Test that upstream variables don't cause errors even when not set
upstreamVars := []string{
"$upstream_name",
"$upstream_scheme",
"$upstream_host",
"$upstream_port",
"$upstream_addr",
"$upstream_url",
}
for _, varExpr := range upstreamVars {
t.Run(varExpr, func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
// Should not error, may return empty string
require.NoError(t, err)
})
}
}
func TestExpandVars_NoHostPort(t *testing.T) {
// Test request without port in Host header
testRequest := httptest.NewRequest("GET", "/", nil)
testRequest.Host = "example.com" // No port
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
t.Run("req_host without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
require.NoError(t, err)
require.Equal(t, "example.com", out.String())
})
t.Run("req_port without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
require.NoError(t, err)
require.Equal(t, "", out.String())
})
}
func TestExpandVars_NoRemotePort(t *testing.T) {
// Test request without port in RemoteAddr
testRequest := httptest.NewRequest("GET", "/", nil)
testRequest.RemoteAddr = "192.168.1.1" // No port
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
t.Run("remote_host without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
require.NoError(t, err)
require.Equal(t, "", out.String())
})
t.Run("remote_port without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
require.NoError(t, err)
require.Equal(t, "", out.String())
})
}
func TestExpandVars_WhitespaceHandling(t *testing.T) {
testRequest := httptest.NewRequest("GET", "/test", nil)
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
require.NoError(t, err)
require.Equal(t, "GET /test", out.String())
}
func TestValidateVars(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "valid simple variable",
input: "$req_method",
},
{
name: "valid function variable",
input: "$header(User-Agent)",
},
{
name: "valid response variable",
input: "$status_code",
},
{
name: "invalid variable",
input: "$unknown_var",
wantErr: true,
},
{
name: "incomplete variable",
input: "test$",
wantErr: true,
},
{
name: "valid variables with text",
input: "Method: $req_method",
},
{
name: "valid escaped dollar",
input: "$$req_method",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateVars(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func TestNeedExpandVars(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{
name: "contains variable",
input: "$req_method",
want: true,
},
{
name: "contains function variable",
input: "$header(X-Test)",
want: true,
},
{
name: "no variable",
input: "plain text",
want: false,
},
{
name: "escaped dollar",
input: "$$req_method",
want: true,
},
{
name: "mixed content",
input: "Method: $req_method",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NeedExpandVars(tt.input)
require.Equal(t, tt.want, got)
})
}
}