mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-18 16:17:07 +01:00
feat(rules): replace go templates with custom variable expansion
- Replace template syntax ({{ .Request.Method }}) with $-prefixed variables ($req_method)
- Implement custom variable parser with static ($req_method, $status_code) and dynamic ($header(), $arg(), $form()) variables
- Replace templateOrStr interface with templateString struct and ExpandVars methods
- Add parser improvements for reliable quote handling
- Add new error types: ErrUnterminatedParenthesis, ErrUnexpectedVar, ErrExpectOneOrTwoArgs
- Update all tests and help text to use new variable syntax
- Add comprehensive unit and benchmark tests for variable expansion
This commit is contained in:
@@ -334,7 +334,7 @@ var commands = map[string]struct {
|
||||
helpListItem("Response", "the response object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(CommandLog, "info", "/dev/stdout", "{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"),
|
||||
helpExample(CommandLog, "info", "/dev/stdout", "$req_method $req_url $status_code"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"level": "the log level",
|
||||
@@ -372,7 +372,7 @@ var commands = map[string]struct {
|
||||
logger = f
|
||||
}
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
err := executeReqRespTemplateTo(tmpl, logger, w, r)
|
||||
err := tmpl.ExpandVars(w, r, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -390,7 +390,7 @@ var commands = map[string]struct {
|
||||
helpListItem("Response", "the response object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(CommandNotify, "info", "ntfy", "Received request to {{ .Request.URL }}", "{{ .Request.Method }} {{ .Response.StatusCode }}"),
|
||||
helpExample(CommandNotify, "info", "ntfy", "Received request to $req_url", "$req_method $status_code"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"level": "the log level",
|
||||
@@ -432,12 +432,12 @@ var commands = map[string]struct {
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
|
||||
|
||||
err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r)
|
||||
err := titleTmpl.ExpandVars(w, r, respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
titleLen := respBuf.Len()
|
||||
err = executeReqRespTemplateTo(bodyTmpl, respBuf, w, r)
|
||||
err = bodyTmpl.ExpandVars(w, r, respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -455,16 +455,8 @@ var commands = map[string]struct {
|
||||
},
|
||||
}
|
||||
|
||||
type reqResponseTemplateData struct {
|
||||
Request *http.Request
|
||||
Response struct {
|
||||
StatusCode int
|
||||
Header http.Header
|
||||
}
|
||||
}
|
||||
|
||||
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr]
|
||||
type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr]
|
||||
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateString]
|
||||
type onNotifyArgs = Tuple4[zerolog.Level, string, templateString, templateString]
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (cmd *Command) Parse(v string) error {
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-request-response
|
||||
do: |
|
||||
log info %q '{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }} {{ index (index .Response.Header "Content-Type") 0 }}'
|
||||
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -84,10 +84,10 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
err := parseRules(`
|
||||
- name: log-stdout
|
||||
do: |
|
||||
log info /dev/stdout "stdout: {{ .Request.Method }} {{ .Response.StatusCode }}"
|
||||
log info /dev/stdout "stdout: $req_method $status_code"
|
||||
- name: log-stderr
|
||||
do: |
|
||||
log error /dev/stderr "stderr: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
log error /dev/stderr "stderr: $req_path $status_code"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -126,13 +126,13 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-info
|
||||
do: |
|
||||
log info %s "INFO: {{ .Request.Method }} {{ .Response.StatusCode }}"
|
||||
log info %s "INFO: $req_method $status_code"
|
||||
- name: log-warn
|
||||
do: |
|
||||
log warn %s "WARN: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
log warn %s "WARN: $req_path $status_code"
|
||||
- name: log-error
|
||||
do: |
|
||||
log error %s "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
log error %s "ERROR: $req_method $req_path $status_code"
|
||||
`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -177,7 +177,7 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-with-templates
|
||||
do: |
|
||||
log info %s 'Request: {{ .Request.Method }} {{ .Request.URL }} Host: {{ .Request.Host }} User-Agent: {{ index .Request.Header "User-Agent" 0 }} Response: {{ .Response.StatusCode }} Custom-Header: {{ index .Response.Header "X-Custom-Header" 0 }} Content-Length: {{ index .Response.Header "Content-Length" 0 }}'
|
||||
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -231,11 +231,11 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
- name: log-success
|
||||
on: status 2xx
|
||||
do: |
|
||||
log info %q "SUCCESS: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
log info %q "SUCCESS: $req_method $req_path $status_code"
|
||||
- name: log-error
|
||||
on: status 4xx | status 5xx
|
||||
do: |
|
||||
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
log error %q "ERROR: $req_method $req_path $status_code"
|
||||
`, successFile.Name(), errorFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -288,7 +288,7 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-multiple
|
||||
do: |
|
||||
log info %q "{{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"`, tempFile.Name()), &rules)
|
||||
log info %q "$req_method $req_path $status_code"`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -339,7 +339,7 @@ func TestLogCommand_FilePermissions(t *testing.T) {
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- on: status 2xx
|
||||
do: log info %q "{{ .Request.Method }} {{ .Response.StatusCode }}"`, logFilePath), &rules)
|
||||
do: log info %q "$req_method $status_code"`, logFilePath), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -374,27 +374,12 @@ func TestLogCommand_FilePermissions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_InvalidTemplate(t *testing.T) {
|
||||
upstream := mockUpstream(200, "success")
|
||||
|
||||
var rules Rules
|
||||
|
||||
// Test with invalid template syntax
|
||||
err := parseRules(`
|
||||
- name: log-invalid
|
||||
do: |
|
||||
log info /dev/stdout "{{ .Invalid.Field }}"`, &rules)
|
||||
// Should not error during parsing, but template execution will fail gracefully
|
||||
assert.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
handler.ServeHTTP(w, req)
|
||||
})
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
log info /dev/stdout "$invalid_var"`, &rules)
|
||||
assert.ErrorIs(t, err, ErrUnexpectedVar)
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ var modFields = map[string]struct {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -62,7 +62,7 @@ var modFields = map[string]struct {
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -89,7 +89,7 @@ var modFields = map[string]struct {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -97,7 +97,7 @@ var modFields = map[string]struct {
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -124,7 +124,7 @@ var modFields = map[string]struct {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -134,7 +134,7 @@ var modFields = map[string]struct {
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -165,7 +165,7 @@ var modFields = map[string]struct {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -181,7 +181,7 @@ var modFields = map[string]struct {
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -221,7 +221,7 @@ var modFields = map[string]struct {
|
||||
helpListItem("Request", "the request object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(FieldBody, "HTTP STATUS: {{ .Request.Method }} {{ .Request.URL.Path }}"),
|
||||
helpExample(FieldBody, "HTTP STATUS: $req_method $req_path"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"template": "the body template",
|
||||
@@ -234,7 +234,7 @@ var modFields = map[string]struct {
|
||||
return validateTemplate(args[0], true)
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateOrStr)
|
||||
tmpl := args.(templateString)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Body != nil {
|
||||
@@ -244,7 +244,7 @@ var modFields = map[string]struct {
|
||||
|
||||
bufPool := GetInitResponseModifier(w).BufPool()
|
||||
b := bufPool.GetBuffer()
|
||||
err := executeRequestTemplateTo(tmpl, b, r)
|
||||
err := tmpl.ExpandVars(w, r, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -266,7 +266,7 @@ var modFields = map[string]struct {
|
||||
helpListItem("Response", "the response object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(FieldResponseBody, "HTTP STATUS: {{ .Request.Method }} {{ .Response.StatusCode }}"),
|
||||
helpExample(FieldResponseBody, "HTTP STATUS: $req_method $status_code"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"template": "the response body template",
|
||||
@@ -279,12 +279,12 @@ var modFields = map[string]struct {
|
||||
return validateTemplate(args[0], true)
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateOrStr)
|
||||
tmpl := args.(templateString)
|
||||
return &FieldHandler{
|
||||
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rm := GetInitResponseModifier(w)
|
||||
rm.ResetBody()
|
||||
return executeReqRespTemplateTo(tmpl, rm, rm, r)
|
||||
return tmpl.ExpandVars(w, r, rm)
|
||||
}),
|
||||
}
|
||||
},
|
||||
|
||||
@@ -367,7 +367,7 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "set body with template",
|
||||
template: "Hello {{ .Request.Method }} {{ .Request.URL.Path }}",
|
||||
template: "Hello $req_method $req_path",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "POST"
|
||||
r.URL.Path = "/test"
|
||||
@@ -424,7 +424,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "set response body with template",
|
||||
template: "Response: {{ .Request.Method }} {{ .Request.URL.Path }}",
|
||||
template: "Response: $req_method $req_path",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "GET"
|
||||
r.URL.Path = "/api/test"
|
||||
@@ -552,19 +552,19 @@ func TestFieldValidation(t *testing.T) {
|
||||
{
|
||||
name: "body valid template",
|
||||
field: FieldBody,
|
||||
args: []string{"Hello {{ .Request.Method }}"},
|
||||
args: []string{"Hello $req_method"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "body invalid template syntax",
|
||||
field: FieldBody,
|
||||
args: []string{"Hello {{ .InvalidField "},
|
||||
args: []string{"Hello $invalid_field"},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "response body valid template",
|
||||
field: FieldResponseBody,
|
||||
args: []string{"Response: {{ .Request.Method }}"},
|
||||
args: []string{"Response: $req_method"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -5,18 +5,25 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
|
||||
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
|
||||
ErrUnterminatedEnvVar = gperr.New("unterminated env var")
|
||||
ErrUnknownDirective = gperr.New("unknown directive")
|
||||
ErrUnknownModField = gperr.New("unknown field")
|
||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
||||
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
|
||||
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
|
||||
ErrUnterminatedParenthesis = gperr.New("unterminated parenthesis")
|
||||
ErrUnterminatedEnvVar = gperr.New("unterminated env var")
|
||||
ErrUnknownDirective = gperr.New("unknown directive")
|
||||
ErrUnknownModField = gperr.New("unknown field")
|
||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
||||
|
||||
// vars errors
|
||||
ErrNoArgProvided = gperr.New("no argument provided")
|
||||
ErrUnexpectedVar = gperr.New("unexpected variable")
|
||||
ErrUnexpectedQuote = gperr.New("unexpected quote")
|
||||
|
||||
ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg")
|
||||
ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg")
|
||||
ErrExpectOneOrTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 1 or 2 args")
|
||||
ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args")
|
||||
ErrExpectTwoOrThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 or 3 args")
|
||||
ErrExpectThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 3 args")
|
||||
|
||||
@@ -28,11 +28,11 @@ func helpExample(cmd string, args ...string) string {
|
||||
var out strings.Builder
|
||||
pos := 0
|
||||
for {
|
||||
start := strings.Index(arg[pos:], "{{")
|
||||
start := strings.IndexByte(arg[pos:], '$')
|
||||
if start == -1 {
|
||||
if pos < len(arg) {
|
||||
// If no template at all (pos == 0), cyan highlight for whole-arg
|
||||
// Otherwise, for mixed strings containing templates, leave non-template text unhighlighted
|
||||
// If no variable at all (pos == 0), cyan highlight for whole-arg
|
||||
// Otherwise, for mixed strings containing variables, leave non-variable text unhighlighted
|
||||
if pos == 0 {
|
||||
out.WriteString(ansi.WithANSI(arg[pos:], ansi.HighlightCyan))
|
||||
} else {
|
||||
@@ -43,20 +43,31 @@ func helpExample(cmd string, args ...string) string {
|
||||
}
|
||||
start += pos
|
||||
if start > pos {
|
||||
// Non-template text should not be highlighted
|
||||
// Non-variable text should not be highlighted
|
||||
out.WriteString(arg[pos:start])
|
||||
}
|
||||
end := strings.Index(arg[start+2:], "}}")
|
||||
if end == -1 {
|
||||
// Unmatched template start; write remainder without highlighting
|
||||
out.WriteString(arg[start:])
|
||||
break
|
||||
// Parse variable name and optional function call
|
||||
end := start + 1
|
||||
for end < len(arg) && (arg[end] == '_' || (arg[end] >= 'a' && arg[end] <= 'z') || (arg[end] >= 'A' && arg[end] <= 'Z') || (arg[end] >= '0' && arg[end] <= '9')) {
|
||||
end++
|
||||
}
|
||||
end += start + 2
|
||||
inner := strings.TrimSpace(arg[start+2 : end])
|
||||
parts := strings.Split(inner, ".")
|
||||
out.WriteString(helpTemplateVar(parts...))
|
||||
pos = end + 2
|
||||
// Check for function call
|
||||
if end < len(arg) && arg[end] == '(' {
|
||||
parenCount := 1
|
||||
end++
|
||||
for end < len(arg) && parenCount > 0 {
|
||||
switch arg[end] {
|
||||
case '(':
|
||||
parenCount++
|
||||
case ')':
|
||||
parenCount--
|
||||
}
|
||||
end++
|
||||
}
|
||||
}
|
||||
varExpr := arg[start:end]
|
||||
out.WriteString(helpVar(varExpr))
|
||||
pos = end
|
||||
}
|
||||
fmt.Fprintf(&sb, ` "%s"`, out.String())
|
||||
}
|
||||
@@ -87,17 +98,29 @@ func helpFuncCall(fn string, args ...string) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// helpTemplateVar generates a string like "{{ .Request.Method }} {{ .Request.URL.Path }}"
|
||||
func helpTemplateVar(parts ...string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(ansi.WithANSI("{{ ", ansi.HighlightWhite))
|
||||
for i, part := range parts {
|
||||
sb.WriteString(ansi.WithANSI(part, ansi.HighlightCyan))
|
||||
if i < len(parts)-1 {
|
||||
sb.WriteString(".")
|
||||
}
|
||||
// helpVar generates a highlighted string for a variable like "$req_method" or "$header(X-Test)"
|
||||
func helpVar(varExpr string) string {
|
||||
if !strings.HasPrefix(varExpr, "$") {
|
||||
return varExpr
|
||||
}
|
||||
sb.WriteString(ansi.WithANSI(" }}", ansi.HighlightWhite))
|
||||
|
||||
// Check if it's a function call
|
||||
parenIdx := strings.IndexByte(varExpr, '(')
|
||||
if parenIdx == -1 {
|
||||
// Simple variable like "$req_method"
|
||||
return ansi.WithANSI(varExpr, ansi.HighlightCyan)
|
||||
}
|
||||
|
||||
// Function call like "$header(X-Test)"
|
||||
var sb strings.Builder
|
||||
sb.WriteString(ansi.WithANSI(varExpr[:parenIdx], ansi.HighlightCyan))
|
||||
sb.WriteString(ansi.WithANSI("(", ansi.HighlightWhite))
|
||||
|
||||
// Extract and highlight the arguments
|
||||
argsStr := varExpr[parenIdx+1 : len(varExpr)-1]
|
||||
sb.WriteString(ansi.WithANSI(argsStr, ansi.HighlightYellow))
|
||||
|
||||
sb.WriteString(ansi.WithANSI(")", ansi.HighlightWhite))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
|
||||
@@ -218,7 +218,7 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-response
|
||||
on: path /test
|
||||
do: log info %s "{{ .Request.Method }} {{ .Response.StatusCode }}"
|
||||
do: log info %s "$req_method $status_code"
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -261,7 +261,7 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-errors
|
||||
on: status 4xx
|
||||
do: log error %s "{{ .Request.URL }} returned {{ .Response.StatusCode }}"
|
||||
do: log error %s "$req_url returned $status_code"
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -364,11 +364,11 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
do: require_basic_auth "Protected Area"
|
||||
- name: log-all-requests
|
||||
do: |
|
||||
log info %q "{{ .Request.Method }} {{ .Request.URL }} -> {{ .Response.StatusCode }}"
|
||||
log info %q "$req_method $req_url -> $status_code"
|
||||
- name: log-errors
|
||||
on: status 4xx
|
||||
do: |
|
||||
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"
|
||||
log error %q "ERROR: $req_method $req_url $status_code"
|
||||
`, logFile.Name(), errorLogFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -610,10 +610,10 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
err := parseRules(`
|
||||
- name: process-form
|
||||
on: form username
|
||||
do: set resp_header X-Username "{{ index .Request.Form.username 0 }}"
|
||||
do: set resp_header X-Username "$form(username)"
|
||||
- name: process-postform
|
||||
on: postform email
|
||||
do: set resp_header X-Email "{{ index .Request.PostForm.email 0 }}"
|
||||
do: set resp_header X-Email "$postform(email)"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -923,7 +923,7 @@ func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
- name: modify-response
|
||||
do: |
|
||||
set resp_header X-Modified "true"
|
||||
set resp_body "Modified: {{ .Request.Method }} {{ .Request.URL.Path }}"
|
||||
set resp_body "Modified: $req_method $req_path"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -591,7 +591,7 @@ func parseOn(line string) (Checker, bool, gperr.Error) {
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, false, err.Subject(subject).With(checker.help.Error())
|
||||
return nil, false, err.With(checker.help.Error())
|
||||
}
|
||||
|
||||
checkFunc := checker.builder(validArgs)
|
||||
|
||||
@@ -19,6 +19,12 @@ var escapedChars = map[rune]rune{
|
||||
' ': ' ',
|
||||
}
|
||||
|
||||
var quoteChars = [256]bool{
|
||||
'"': true,
|
||||
'\'': true,
|
||||
'`': true,
|
||||
}
|
||||
|
||||
// parse expression to subject and args
|
||||
// with support for quotes, escaped chars, and env substitution, e.g.
|
||||
//
|
||||
@@ -74,6 +80,19 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
}
|
||||
if quoteChars[r] {
|
||||
switch {
|
||||
case quote == 0 && brackets == 0:
|
||||
quote = r
|
||||
flush(false)
|
||||
case r == quote:
|
||||
quote = 0
|
||||
flush(true)
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '\\':
|
||||
escaped = true
|
||||
@@ -106,17 +125,6 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
switch {
|
||||
case quote == 0 && brackets == 0:
|
||||
quote = r
|
||||
flush(false)
|
||||
case r == quote:
|
||||
quote = 0
|
||||
flush(true)
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '(':
|
||||
brackets++
|
||||
buf.WriteRune(r)
|
||||
|
||||
@@ -1,48 +1,52 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type templateOrStr interface {
|
||||
Execute(w io.Writer, data any) error
|
||||
Len() int
|
||||
type templateString struct {
|
||||
string
|
||||
isTemplate bool
|
||||
}
|
||||
|
||||
type strTemplate string
|
||||
type keyValueTemplate struct {
|
||||
key string
|
||||
tmpl templateString
|
||||
}
|
||||
|
||||
func (t strTemplate) Execute(w io.Writer, _ any) error {
|
||||
n, err := w.Write([]byte(t))
|
||||
if err != nil {
|
||||
func (tmpl *keyValueTemplate) Unpack() (string, templateString) {
|
||||
return tmpl.key, tmpl.tmpl
|
||||
}
|
||||
|
||||
func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error {
|
||||
if !tmpl.isTemplate {
|
||||
_, err := dstW.Write(strtobNoCopy(tmpl.string))
|
||||
return err
|
||||
}
|
||||
if n != len(t) {
|
||||
return io.ErrShortWrite
|
||||
|
||||
return ExpandVars(GetInitResponseModifier(w), req, tmpl.string, dstW)
|
||||
}
|
||||
|
||||
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
|
||||
if !tmpl.isTemplate {
|
||||
return tmpl.string, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t strTemplate) Len() int {
|
||||
return len(t)
|
||||
}
|
||||
|
||||
type keyValueTemplate = Tuple[string, templateOrStr]
|
||||
|
||||
func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, reqResponseTemplateData{Request: r})
|
||||
var buf strings.Builder
|
||||
err := ExpandVars(GetInitResponseModifier(w), req, tmpl.string, &buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func executeRequestTemplateTo(tmpl templateOrStr, o io.Writer, r *http.Request) error {
|
||||
return tmpl.Execute(o, reqResponseTemplateData{Request: r})
|
||||
func (tmpl *templateString) Len() int {
|
||||
return len(tmpl.string)
|
||||
}
|
||||
|
||||
func executeReqRespTemplateTo(tmpl templateOrStr, o io.Writer, w http.ResponseWriter, r *http.Request) error {
|
||||
return tmpl.Execute(o, reqResponseTemplateData{Request: r, Response: GetInitResponseModifier(w).Response()})
|
||||
func strtobNoCopy(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
@@ -91,11 +90,11 @@ func toKeyValueTemplate(args []string) (any, gperr.Error) {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
tmpl, err := validateTemplate(args[1], false)
|
||||
isTemplate, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &keyValueTemplate{args[0], tmpl}, nil
|
||||
return &keyValueTemplate{args[0], isTemplate}, nil
|
||||
}
|
||||
|
||||
// validateURL returns types.URL with the URL validated.
|
||||
@@ -300,33 +299,20 @@ func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.E
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func isTemplate(tmplStr string) bool {
|
||||
return strings.Contains(tmplStr, "{{")
|
||||
}
|
||||
|
||||
type templateWithLen struct {
|
||||
*template.Template
|
||||
len int
|
||||
}
|
||||
|
||||
func (t *templateWithLen) Len() int {
|
||||
return t.len
|
||||
}
|
||||
|
||||
func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) {
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, gperr.Error) {
|
||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||
tmplStr += "\n"
|
||||
}
|
||||
|
||||
if !isTemplate(tmplStr) {
|
||||
return strTemplate(tmplStr), nil
|
||||
if !NeedExpandVars(tmplStr) {
|
||||
return templateString{tmplStr, false}, nil
|
||||
}
|
||||
|
||||
tmpl, err := template.New("template").Parse(tmplStr)
|
||||
err := ValidateVars(tmplStr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return templateString{}, gperr.Wrap(err)
|
||||
}
|
||||
return &templateWithLen{tmpl, len(tmplStr)}, nil
|
||||
return templateString{tmplStr, true}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||
|
||||
28
internal/route/rules/var_bench_test.go
Normal file
28
internal/route/rules/var_bench_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkExpandVars(b *testing.B) {
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.WriteHeader(200)
|
||||
testResponseModifier.Write([]byte("Hello, world!"))
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest.Header.Set("User-Agent", "test-agent/1.0")
|
||||
testRequest.Header.Set("X-Custom", "value1,value2")
|
||||
testRequest.ContentLength = 12345
|
||||
testRequest.RemoteAddr = "192.168.1.100:54321"
|
||||
testRequest.Form = url.Values{"param1": {"value1"}, "param2": {"value2"}}
|
||||
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
|
||||
|
||||
for b.Loop() {
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
214
internal/route/rules/vars.go
Normal file
214
internal/route/rules/vars.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
|
||||
// TODO: remove middleware/vars.go and use this instead
|
||||
|
||||
type (
|
||||
reqVarGetter func(*http.Request) string
|
||||
respVarGetter func(*ResponseModifier) string
|
||||
)
|
||||
|
||||
var reVar = regexp.MustCompile(`\$[\w_]+`)
|
||||
|
||||
var validVarNameCharset = func() (ret [256]bool) {
|
||||
for c := byte('a'); c <= 'z'; c++ {
|
||||
ret[c] = true
|
||||
}
|
||||
for c := byte('A'); c <= 'Z'; c++ {
|
||||
ret[c] = true
|
||||
}
|
||||
ret['_'] = true
|
||||
return
|
||||
}()
|
||||
|
||||
func NeedExpandVars(s string) bool {
|
||||
return reVar.MatchString(s)
|
||||
}
|
||||
|
||||
var (
|
||||
voidResponseModifier = NewResponseModifier(httptest.NewRecorder())
|
||||
dummyRequest = http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/"},
|
||||
Header: http.Header{},
|
||||
}
|
||||
)
|
||||
|
||||
// ValidateVars validates the variables in the given string.
|
||||
// It returns ErrUnexpectedVar if any invalid variable is found.
|
||||
func ValidateVars(s string) error {
|
||||
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
||||
}
|
||||
|
||||
func ExpandVars(w *ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
|
||||
dst := ioutils.NewBufferedWriter(dstW, 1024)
|
||||
defer func() {
|
||||
dst.Flush()
|
||||
dst.Release()
|
||||
}()
|
||||
|
||||
for i := 0; i < len(src); i++ {
|
||||
ch := src[i]
|
||||
if ch != '$' {
|
||||
if err := dst.WriteByte(ch); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Look ahead
|
||||
if i+1 >= len(src) {
|
||||
return ErrUnterminatedEnvVar
|
||||
}
|
||||
j := i + 1
|
||||
|
||||
switch src[j] {
|
||||
case '$': // $$ -> literal '$'
|
||||
if err := dst.WriteByte('$'); err != nil {
|
||||
return err
|
||||
}
|
||||
i = j
|
||||
continue
|
||||
case '{': // ${...} pass through as-is
|
||||
if _, err := dst.WriteString("${"); err != nil {
|
||||
return err
|
||||
}
|
||||
i = j // we've consumed the '{' too
|
||||
continue
|
||||
}
|
||||
|
||||
if validVarNameCharset[src[j]] {
|
||||
k := j
|
||||
for k < len(src) {
|
||||
c := src[k]
|
||||
if validVarNameCharset[c] {
|
||||
k++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
name := src[j:k]
|
||||
isStatic := true
|
||||
|
||||
var actual string
|
||||
if getter, ok := dynamicVarSubsMap[name]; ok {
|
||||
// Function-like variables
|
||||
isStatic = false
|
||||
args, nextIdx, err := extractArgs(src, j, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i = nextIdx
|
||||
actual, err = getter(args, w, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if getter, ok := staticReqVarSubsMap[name]; ok {
|
||||
actual = getter(req)
|
||||
} else if getter, ok := staticRespVarSubsMap[name]; ok {
|
||||
actual = getter(w)
|
||||
} else {
|
||||
return ErrUnexpectedVar.Subject(name)
|
||||
}
|
||||
if _, err := dst.WriteString(actual); err != nil {
|
||||
return err
|
||||
}
|
||||
if isStatic {
|
||||
i = k - 1
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// No valid construct after '$'
|
||||
return ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
|
||||
// Find opening parenthesis
|
||||
parenIdx := strings.IndexByte(src[i:], '(')
|
||||
if parenIdx == -1 {
|
||||
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q at position %d", funcName, i)
|
||||
}
|
||||
parenIdx += i
|
||||
|
||||
var (
|
||||
quote byte // current quote character (0 if not in quotes)
|
||||
arg strings.Builder
|
||||
)
|
||||
|
||||
nextIdx = parenIdx + 1
|
||||
for nextIdx < len(src) {
|
||||
ch := src[nextIdx]
|
||||
|
||||
if quote != 0 {
|
||||
// We're inside a quoted string
|
||||
if ch == quote {
|
||||
// Closing quote - the content between quotes is now complete, add it
|
||||
args = append(args, arg.String())
|
||||
arg.Reset()
|
||||
quote = 0
|
||||
nextIdx++
|
||||
continue
|
||||
}
|
||||
// Inside quotes - add everything as-is
|
||||
arg.WriteByte(ch)
|
||||
nextIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
// Not inside quotes
|
||||
if quoteChars[ch] {
|
||||
// Opening quote
|
||||
quote = ch
|
||||
nextIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == ')' {
|
||||
// End of arguments
|
||||
if arg.Len() > 0 {
|
||||
args = append(args, arg.String())
|
||||
}
|
||||
return args, nextIdx, nil
|
||||
}
|
||||
|
||||
if ch == ',' {
|
||||
// Argument separator
|
||||
if arg.Len() > 0 {
|
||||
args = append(args, arg.String())
|
||||
arg.Reset()
|
||||
}
|
||||
nextIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
|
||||
// Whitespace outside quotes - skip
|
||||
nextIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
// Regular character - accumulate until we hit a delimiter
|
||||
arg.WriteByte(ch)
|
||||
nextIdx++
|
||||
}
|
||||
|
||||
// Reached end of string without closing parenthesis
|
||||
if quote != 0 {
|
||||
return nil, 0, ErrUnterminatedQuotes.Withf("func %q", funcName)
|
||||
}
|
||||
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q", funcName)
|
||||
}
|
||||
81
internal/route/rules/vars_dynamic.go
Normal file
81
internal/route/rules/vars_dynamic.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
VarHeader = "header"
|
||||
VarResponseHeader = "resp_header"
|
||||
VarQuery = "arg"
|
||||
VarForm = "form"
|
||||
VarPostForm = "postform"
|
||||
)
|
||||
|
||||
type dynamicVarGetter func(args []string, w *ResponseModifier, req *http.Request) (string, error)
|
||||
|
||||
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
||||
VarHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Header, key, index)
|
||||
},
|
||||
VarResponseHeader: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||
},
|
||||
VarQuery: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(GetSharedData(w).GetQueries(req), key, index)
|
||||
},
|
||||
VarForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Form, key, index)
|
||||
},
|
||||
VarPostForm: func(args []string, w *ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(req.PostForm, key, index)
|
||||
},
|
||||
}
|
||||
|
||||
func getValueByKeyAtIndex[Values http.Header | url.Values](values Values, key string, index int) (string, error) {
|
||||
// NOTE: do not use Header.Get or http.CanonicalHeaderKey here, respect to user input
|
||||
if values, ok := values[key]; ok && index < len(values) {
|
||||
return values[index], nil
|
||||
}
|
||||
// ignore unknown header or index out of range
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func getKeyAndIndex(args []string) (key string, index int, err error) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return "", 0, ErrExpectNoArg
|
||||
case 1:
|
||||
return args[0], 0, nil
|
||||
case 2:
|
||||
index, err = strconv.Atoi(args[1])
|
||||
if err != nil {
|
||||
return "", 0, ErrInvalidArguments.Withf("invalid index %q", args[1])
|
||||
}
|
||||
return args[0], index, nil
|
||||
default:
|
||||
return "", 0, ErrExpectOneOrTwoArgs
|
||||
}
|
||||
}
|
||||
92
internal/route/rules/vars_static.go
Normal file
92
internal/route/rules/vars_static.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
)
|
||||
|
||||
const (
|
||||
VarRequestMethod = "req_method"
|
||||
VarRequestScheme = "req_scheme"
|
||||
VarRequestHost = "req_host"
|
||||
VarRequestPort = "req_port"
|
||||
VarRequestPath = "req_path"
|
||||
VarRequestAddr = "req_addr"
|
||||
VarRequestQuery = "req_query"
|
||||
VarRequestURL = "req_url"
|
||||
VarRequestURI = "req_uri"
|
||||
VarRequestContentType = "req_content_type"
|
||||
VarRequestContentLen = "req_content_length"
|
||||
VarRemoteHost = "remote_host"
|
||||
VarRemotePort = "remote_port"
|
||||
VarRemoteAddr = "remote_addr"
|
||||
|
||||
VarUpstreamName = "upstream_name"
|
||||
VarUpstreamScheme = "upstream_scheme"
|
||||
VarUpstreamHost = "upstream_host"
|
||||
VarUpstreamPort = "upstream_port"
|
||||
VarUpstreamAddr = "upstream_addr"
|
||||
VarUpstreamURL = "upstream_url"
|
||||
|
||||
VarRespContentType = "resp_content_type"
|
||||
VarRespContentLen = "resp_content_length"
|
||||
VarRespStatusCode = "status_code"
|
||||
)
|
||||
|
||||
var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||
VarRequestMethod: func(req *http.Request) string { return req.Method },
|
||||
VarRequestScheme: func(req *http.Request) string {
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
},
|
||||
VarRequestHost: func(req *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
return req.Host
|
||||
}
|
||||
return reqHost
|
||||
},
|
||||
VarRequestPort: func(req *http.Request) string {
|
||||
_, reqPort, _ := net.SplitHostPort(req.Host)
|
||||
return reqPort
|
||||
},
|
||||
VarRequestAddr: func(req *http.Request) string { return req.Host },
|
||||
VarRequestPath: func(req *http.Request) string { return req.URL.Path },
|
||||
VarRequestQuery: func(req *http.Request) string { return req.URL.RawQuery },
|
||||
VarRequestURL: func(req *http.Request) string { return req.URL.String() },
|
||||
VarRequestURI: func(req *http.Request) string { return req.URL.RequestURI() },
|
||||
VarRequestContentType: func(req *http.Request) string { return req.Header.Get("Content-Type") },
|
||||
VarRequestContentLen: func(req *http.Request) string { return strconv.FormatInt(req.ContentLength, 10) },
|
||||
VarRemoteHost: func(req *http.Request) string {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientIP
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemotePort: func(req *http.Request) string {
|
||||
_, clientPort, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return clientPort
|
||||
}
|
||||
return ""
|
||||
},
|
||||
VarRemoteAddr: func(req *http.Request) string { return req.RemoteAddr },
|
||||
VarUpstreamName: routes.TryGetUpstreamName,
|
||||
VarUpstreamScheme: routes.TryGetUpstreamScheme,
|
||||
VarUpstreamHost: routes.TryGetUpstreamHost,
|
||||
VarUpstreamPort: routes.TryGetUpstreamPort,
|
||||
VarUpstreamAddr: routes.TryGetUpstreamAddr,
|
||||
VarUpstreamURL: routes.TryGetUpstreamURL,
|
||||
}
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) },
|
||||
VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
|
||||
}
|
||||
672
internal/route/rules/vars_test.go
Normal file
672
internal/route/rules/vars_test.go
Normal file
@@ -0,0 +1,672 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src string
|
||||
startPos int
|
||||
funcName string
|
||||
wantArgs []string
|
||||
wantNextIdx int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "unquoted single arg",
|
||||
src: "header(X-Some-Header)",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header"},
|
||||
wantNextIdx: 20,
|
||||
},
|
||||
{
|
||||
name: "double quoted arg",
|
||||
src: `header("X-Some-Header")`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header"},
|
||||
wantNextIdx: 22,
|
||||
},
|
||||
{
|
||||
name: "single quoted arg",
|
||||
src: "header('X-Some-Header')",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header"},
|
||||
wantNextIdx: 22,
|
||||
},
|
||||
{
|
||||
name: "backtick quoted arg",
|
||||
src: "header(`X-Some-Header`)",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header"},
|
||||
wantNextIdx: 22,
|
||||
},
|
||||
{
|
||||
name: "two args with double quotes and unquoted",
|
||||
src: `header("X-Some-Header", 1)`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header", "1"},
|
||||
wantNextIdx: 25,
|
||||
},
|
||||
{
|
||||
name: "two args with single and double quotes",
|
||||
src: "header('X-Some-Header', \"1\")",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header", "1"},
|
||||
wantNextIdx: 27,
|
||||
},
|
||||
{
|
||||
name: "two args with backtick and single quotes",
|
||||
src: "header(`X-Some-Header`, '1')",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Some-Header", "1"},
|
||||
wantNextIdx: 27,
|
||||
},
|
||||
{
|
||||
name: "quoted string with nested different quotes",
|
||||
src: `arg("'(value)'")`,
|
||||
startPos: 0,
|
||||
funcName: "arg",
|
||||
wantArgs: []string{"'(value)'"},
|
||||
wantNextIdx: 15,
|
||||
},
|
||||
{
|
||||
name: "quoted string with backticks inside double quotes",
|
||||
src: "header(\"value`with`backticks\")",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"value`with`backticks"},
|
||||
wantNextIdx: 29,
|
||||
},
|
||||
{
|
||||
name: "multiple args with whitespace",
|
||||
src: "header( \"X-Header\" , 2 )",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Header", "2"},
|
||||
wantNextIdx: 27,
|
||||
},
|
||||
{
|
||||
name: "empty quoted string",
|
||||
src: `header("")`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{""},
|
||||
wantNextIdx: 9,
|
||||
},
|
||||
{
|
||||
name: "multiple empty args",
|
||||
src: `header("", "")`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"", ""},
|
||||
wantNextIdx: 13,
|
||||
},
|
||||
{
|
||||
name: "unquoted args separated by comma",
|
||||
src: "header(key1,key2,key3)",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"key1", "key2", "key3"},
|
||||
wantNextIdx: 21,
|
||||
},
|
||||
{
|
||||
name: "trailing whitespace before closing paren",
|
||||
src: `header("value" )`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"value"},
|
||||
wantNextIdx: 16,
|
||||
},
|
||||
{
|
||||
name: "startPos not at beginning",
|
||||
src: "prefix_header(X-Header)",
|
||||
startPos: 7,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Header"},
|
||||
wantNextIdx: 22,
|
||||
},
|
||||
{
|
||||
name: "special chars in unquoted arg",
|
||||
src: "header(X-Custom_Header.v1)",
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantArgs: []string{"X-Custom_Header.v1"},
|
||||
wantNextIdx: 25,
|
||||
},
|
||||
{
|
||||
name: "unterminated quote",
|
||||
src: `header("X-Header`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing closing parenthesis",
|
||||
src: `header("X-Header"`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no opening parenthesis",
|
||||
src: `header"X-Header"`,
|
||||
startPos: 0,
|
||||
funcName: "header",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args, nextIdx, err := extractArgs(tt.src, tt.startPos, tt.funcName)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantArgs, args)
|
||||
require.Equal(t, tt.wantNextIdx, nextIdx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandVars(t *testing.T) {
|
||||
// Create a comprehensive test request
|
||||
testRequest := httptest.NewRequest("POST", "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment", nil)
|
||||
testRequest.Header.Set("Content-Type", "application/json")
|
||||
testRequest.Header.Set("User-Agent", "test-agent/1.0")
|
||||
testRequest.Header.Set("X-Custom", "value1,value2")
|
||||
testRequest.ContentLength = 12345
|
||||
testRequest.RemoteAddr = "192.168.1.100:54321"
|
||||
|
||||
// Create response modifier with headers
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.Header().Set("Content-Type", "text/html")
|
||||
testResponseModifier.Header().Set("X-Custom-Resp", "resp-value")
|
||||
testResponseModifier.WriteHeader(200)
|
||||
// set content length to 9876 by writing 9876 'a' bytes
|
||||
testResponseModifier.Write(bytes.Repeat([]byte("a"), 9876))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
// Basic request variables
|
||||
{
|
||||
name: "req_method",
|
||||
input: "$req_method",
|
||||
want: "POST",
|
||||
},
|
||||
{
|
||||
name: "req_path",
|
||||
input: "$req_path",
|
||||
want: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "req_query",
|
||||
input: "$req_query",
|
||||
want: "param1=value1¶m2=value2",
|
||||
},
|
||||
{
|
||||
name: "req_url",
|
||||
input: "$req_url",
|
||||
want: "https://example.com:8080/api/users?param1=value1¶m2=value2#fragment",
|
||||
},
|
||||
{
|
||||
name: "req_uri",
|
||||
input: "$req_uri",
|
||||
want: "/api/users?param1=value1¶m2=value2",
|
||||
},
|
||||
{
|
||||
name: "req_host",
|
||||
input: "$req_host",
|
||||
want: "example.com",
|
||||
},
|
||||
{
|
||||
name: "req_port",
|
||||
input: "$req_port",
|
||||
want: "8080",
|
||||
},
|
||||
{
|
||||
name: "req_addr",
|
||||
input: "$req_addr",
|
||||
want: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "req_content_type",
|
||||
input: "$req_content_type",
|
||||
want: "application/json",
|
||||
},
|
||||
{
|
||||
name: "req_content_length",
|
||||
input: "$req_content_length",
|
||||
want: "12345",
|
||||
},
|
||||
{
|
||||
name: "remote_host",
|
||||
input: "$remote_host",
|
||||
want: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "remote_port",
|
||||
input: "$remote_port",
|
||||
want: "54321",
|
||||
},
|
||||
{
|
||||
name: "remote_addr",
|
||||
input: "$remote_addr",
|
||||
want: "192.168.1.100:54321",
|
||||
},
|
||||
// Response variables
|
||||
{
|
||||
name: "status_code",
|
||||
input: "$status_code",
|
||||
want: "200",
|
||||
},
|
||||
{
|
||||
name: "resp_content_type",
|
||||
input: "$resp_content_type",
|
||||
want: "text/html",
|
||||
},
|
||||
{
|
||||
name: "resp_content_length",
|
||||
input: "$resp_content_length",
|
||||
want: "9876",
|
||||
},
|
||||
// Function-like variables - header
|
||||
{
|
||||
name: "header single value",
|
||||
input: "$header(User-Agent)",
|
||||
want: "test-agent/1.0",
|
||||
},
|
||||
{
|
||||
name: "header with index 0",
|
||||
input: "$header(X-Custom, 0)",
|
||||
want: "value1",
|
||||
},
|
||||
{
|
||||
name: "header with index 1",
|
||||
input: "$header(X-Custom, 1)",
|
||||
want: "value2",
|
||||
},
|
||||
{
|
||||
name: "header not found",
|
||||
input: "$header(X-Not-Found)",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "header index out of range",
|
||||
input: "$header(X-Custom, 99)",
|
||||
want: "",
|
||||
},
|
||||
// Function-like variables - resp_header
|
||||
{
|
||||
name: "resp_header single value",
|
||||
input: "$resp_header(Content-Type)",
|
||||
want: "text/html",
|
||||
},
|
||||
{
|
||||
name: "resp_header custom header",
|
||||
input: "$resp_header(X-Custom-Resp)",
|
||||
want: "resp-value",
|
||||
},
|
||||
{
|
||||
name: "resp_header not found",
|
||||
input: "$resp_header(X-Not-Found)",
|
||||
want: "",
|
||||
},
|
||||
// Function-like variables - arg (query parameters)
|
||||
{
|
||||
name: "arg single parameter",
|
||||
input: "$arg(param1)",
|
||||
want: "value1",
|
||||
},
|
||||
{
|
||||
name: "arg second parameter",
|
||||
input: "$arg(param2)",
|
||||
want: "value2",
|
||||
},
|
||||
{
|
||||
name: "arg not found",
|
||||
input: "$arg(param3)",
|
||||
want: "",
|
||||
},
|
||||
// Mixed variables
|
||||
{
|
||||
name: "mixed variables",
|
||||
input: "$req_method $req_path $status_code",
|
||||
want: "POST /api/users 200",
|
||||
},
|
||||
{
|
||||
name: "variables with text",
|
||||
input: "Method: $req_method, Path: $req_path",
|
||||
want: "Method: POST, Path: /api/users",
|
||||
},
|
||||
{
|
||||
name: "function variables with text",
|
||||
input: "Header: $header(User-Agent), Status: $status_code",
|
||||
want: "Header: test-agent/1.0, Status: 200",
|
||||
},
|
||||
// Escaped dollar signs
|
||||
{
|
||||
name: "escaped dollar",
|
||||
input: "$$req_method",
|
||||
want: "$req_method",
|
||||
},
|
||||
{
|
||||
name: "mixed escaped and unescaped",
|
||||
input: "$$req_method $req_path",
|
||||
want: "$req_method /api/users",
|
||||
},
|
||||
// Environment variable syntax ${}
|
||||
{
|
||||
name: "env var syntax",
|
||||
input: "${VAR}",
|
||||
want: "${VAR}",
|
||||
},
|
||||
// Error cases
|
||||
{
|
||||
name: "unknown variable",
|
||||
input: "$unknown_var",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid function syntax",
|
||||
input: "$arg(param1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete dollar",
|
||||
input: "test$",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, out.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandVars_Integration(t *testing.T) {
|
||||
t.Run("complex log format", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "https://api.example.com/users/123?sort=asc", nil)
|
||||
testRequest.Header.Set("User-Agent", "curl/7.68.0")
|
||||
testRequest.RemoteAddr = "10.0.0.1:54321"
|
||||
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.WriteHeader(200)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
||||
&out)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "GET https://api.example.com/users/123?sort=asc 200 User-Agent=curl/7.68.0", out.String())
|
||||
})
|
||||
|
||||
t.Run("with query parameters", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "http://example.com/search?q=test&page=1", nil)
|
||||
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
"Query: $arg(q), Page: $arg(page)",
|
||||
&out)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Query: test, Page: 1", out.String())
|
||||
})
|
||||
|
||||
t.Run("response headers", func(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
testResponseModifier.Header().Set("Cache-Control", "no-cache")
|
||||
testResponseModifier.Header().Set("X-Rate-Limit", "100")
|
||||
testResponseModifier.WriteHeader(200)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
||||
&out)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Status: 200, Cache: no-cache, Limit: 100", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request *http.Request
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "http scheme",
|
||||
request: httptest.NewRequest("GET", "http://example.com/", nil),
|
||||
expected: "http",
|
||||
},
|
||||
{
|
||||
name: "https scheme",
|
||||
request: &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
|
||||
TLS: &tls.ConnectionState{}, // Simulate TLS connection
|
||||
},
|
||||
expected: "https",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, out.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
// Upstream variables require context from routes package
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
// Test that upstream variables don't cause errors even when not set
|
||||
upstreamVars := []string{
|
||||
"$upstream_name",
|
||||
"$upstream_scheme",
|
||||
"$upstream_host",
|
||||
"$upstream_port",
|
||||
"$upstream_addr",
|
||||
"$upstream_url",
|
||||
}
|
||||
|
||||
for _, varExpr := range upstreamVars {
|
||||
t.Run(varExpr, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
||||
// Should not error, may return empty string
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
// Test request without port in Host header
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest.Host = "example.com" // No port
|
||||
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
t.Run("req_host without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "example.com", out.String())
|
||||
})
|
||||
|
||||
t.Run("req_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
// Test request without port in RemoteAddr
|
||||
testRequest := httptest.NewRequest("GET", "/", nil)
|
||||
testRequest.RemoteAddr = "192.168.1.1" // No port
|
||||
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
t.Run("remote_host without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
|
||||
t.Run("remote_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
||||
testRequest := httptest.NewRequest("GET", "/test", nil)
|
||||
testResponseModifier := NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "GET /test", out.String())
|
||||
}
|
||||
|
||||
func TestValidateVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid simple variable",
|
||||
input: "$req_method",
|
||||
},
|
||||
{
|
||||
name: "valid function variable",
|
||||
input: "$header(User-Agent)",
|
||||
},
|
||||
{
|
||||
name: "valid response variable",
|
||||
input: "$status_code",
|
||||
},
|
||||
{
|
||||
name: "invalid variable",
|
||||
input: "$unknown_var",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "incomplete variable",
|
||||
input: "test$",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid variables with text",
|
||||
input: "Method: $req_method",
|
||||
},
|
||||
{
|
||||
name: "valid escaped dollar",
|
||||
input: "$$req_method",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateVars(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedExpandVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "contains variable",
|
||||
input: "$req_method",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains function variable",
|
||||
input: "$header(X-Test)",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no variable",
|
||||
input: "plain text",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "escaped dollar",
|
||||
input: "$$req_method",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "mixed content",
|
||||
input: "Method: $req_method",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NeedExpandVars(tt.input)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user