mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-24 01:51:10 +01:00
Allow $func(...) expressions inside function arguments by extracting nested calls and expanding them before evaluation.
344 lines
8.0 KiB
Go
344 lines
8.0 KiB
Go
package rules
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"unsafe"
|
|
|
|
httputils "github.com/yusing/goutils/http"
|
|
)
|
|
|
|
// TODO: remove middleware/vars.go and use this instead
|
|
|
|
type (
|
|
reqVarGetter func(*http.Request) string
|
|
respVarGetter func(*httputils.ResponseModifier) string
|
|
)
|
|
|
|
// reVar matches a '$' immediately followed by one or more word characters
// (ASCII letters, digits or underscore). It is compiled once at package
// initialization.
var reVar = regexp.MustCompile(`\$[\w_]+`)
|
|
|
|
// validVarNameCharset is a lookup table indexed by byte value. An entry is
// true when that byte may appear in a variable name: ASCII letters of either
// case and the underscore. Digits are not part of the set.
var validVarNameCharset = func() (ret [256]bool) {
	const allowed = "abcdefghijklmnopqrstuvwxyz" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"_"
	for idx := 0; idx < len(allowed); idx++ {
		ret[allowed[idx]] = true
	}
	return ret
}()
|
|
|
|
func NeedExpandVars(s string) bool {
|
|
return reVar.MatchString(s)
|
|
}
|
|
|
|
var (
|
|
voidResponseModifier = httputils.NewResponseModifier(httptest.NewRecorder())
|
|
dummyRequest = http.Request{
|
|
Method: http.MethodGet,
|
|
URL: &url.URL{Path: "/"},
|
|
Header: http.Header{},
|
|
}
|
|
)
|
|
|
|
// bytesBufferLike is the set of write methods the expander relies on:
// Write, WriteByte and WriteString.
type bytesBufferLike interface {
	io.Writer
	io.ByteWriter   // WriteByte(c byte) error
	io.StringWriter // WriteString(s string) (int, error)
}
|
|
|
|
// bytesBufferAdapter adds WriteByte and WriteString on top of a plain
// io.Writer by forwarding both to the embedded writer's Write method.
type bytesBufferAdapter struct {
	io.Writer
}

// WriteByte writes the single byte c through the wrapped writer.
func (b bytesBufferAdapter) WriteByte(c byte) error {
	one := [1]byte{c}
	if _, err := b.Writer.Write(one[:]); err != nil {
		return err
	}
	return nil
}

// WriteString writes s through the wrapped writer and returns what Write
// returned. The byte slice handed to Write aliases the string's memory, so
// no copy is made.
func (b bytesBufferAdapter) WriteString(s string) (int, error) {
	raw := unsafe.Slice(unsafe.StringData(s), len(s)) // avoid copy
	return b.Writer.Write(raw)
}
|
|
|
|
func asBytesBufferLike(w io.Writer) bytesBufferLike {
|
|
switch w := w.(type) {
|
|
case *bytes.Buffer:
|
|
return w
|
|
case bytesBufferLike:
|
|
return w
|
|
default:
|
|
return bytesBufferAdapter{w}
|
|
}
|
|
}
|
|
|
|
// ValidateVars validates the variables in the given string.
|
|
// It returns the phase that the variables require and an error if any error occurs.
|
|
//
|
|
// Possible errors:
|
|
// - ErrUnexpectedVar: if any invalid variable is found
|
|
// - ErrUnterminatedEnvVar: missing closing }
|
|
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
|
// - ErrUnterminatedParenthesis: missing closing )
|
|
func ValidateVars(s string) (phase PhaseFlag, err error) {
|
|
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
|
}
|
|
|
|
// ExpandVars expands the variables in the given string and writes the result to the given writer.
|
|
// It returns the phase that the variables require and an error if any error occurs.
|
|
//
|
|
// Possible errors:
|
|
// - ErrUnexpectedVar: if any invalid variable is found
|
|
// - ErrUnterminatedEnvVar: missing closing }
|
|
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
|
// - ErrUnterminatedParenthesis: missing closing )
|
|
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) (phase PhaseFlag, err error) {
|
|
dst := asBytesBufferLike(dstW)
|
|
for i := 0; i < len(src); i++ {
|
|
ch := src[i]
|
|
if ch != '$' {
|
|
if err = dst.WriteByte(ch); err != nil {
|
|
return phase, err
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Look ahead
|
|
if i+1 >= len(src) {
|
|
return phase, ErrUnterminatedEnvVar
|
|
}
|
|
j := i + 1
|
|
|
|
switch src[j] {
|
|
case '$': // $$ -> literal '$'
|
|
if err := dst.WriteByte('$'); err != nil {
|
|
return phase, err
|
|
}
|
|
i = j
|
|
continue
|
|
case '{': // ${...} pass through as-is
|
|
if _, err := dst.WriteString("${"); err != nil {
|
|
return phase, err
|
|
}
|
|
i = j // we've consumed the '{' too
|
|
continue
|
|
}
|
|
|
|
if validVarNameCharset[src[j]] {
|
|
k := j
|
|
for k < len(src) {
|
|
c := src[k]
|
|
if validVarNameCharset[c] {
|
|
k++
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
name := src[j:k]
|
|
isStatic := true
|
|
|
|
var actual string
|
|
if getter, ok := dynamicVarSubsMap[name]; ok {
|
|
// Function-like variables
|
|
isStatic = false
|
|
phase |= getter.phase
|
|
args, nextIdx, err := extractArgs(src, j, name)
|
|
if err != nil {
|
|
return phase, err
|
|
}
|
|
i = nextIdx
|
|
// Expand any nested $func(...) expressions in args
|
|
args, argPhase, err := expandArgs(args, w, req)
|
|
if err != nil {
|
|
return phase, err
|
|
}
|
|
phase |= argPhase
|
|
actual, err = getter.get(args, w, req)
|
|
if err != nil {
|
|
return phase, err
|
|
}
|
|
} else if getter, ok := staticReqVarSubsMap[name]; ok { // always available
|
|
actual = getter(req)
|
|
} else if getter, ok := staticRespVarSubsMap[name]; ok { // post response
|
|
actual = getter(w)
|
|
phase |= PhasePost
|
|
} else {
|
|
return phase, ErrUnexpectedVar.Subject(name)
|
|
}
|
|
if _, err := dst.WriteString(actual); err != nil {
|
|
return phase, err
|
|
}
|
|
if isStatic {
|
|
i = k - 1
|
|
}
|
|
continue
|
|
}
|
|
|
|
// No valid construct after '$'
|
|
return phase, ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
|
}
|
|
|
|
return phase, nil
|
|
}
|
|
|
|
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
|
|
// Find opening parenthesis
|
|
parenIdx := strings.IndexByte(src[i:], '(')
|
|
if parenIdx == -1 {
|
|
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q at position %d", funcName, i)
|
|
}
|
|
parenIdx += i
|
|
|
|
var (
|
|
quote byte // current quote character (0 if not in quotes)
|
|
arg strings.Builder
|
|
)
|
|
|
|
nextIdx = parenIdx + 1
|
|
for nextIdx < len(src) {
|
|
ch := src[nextIdx]
|
|
|
|
if quote != 0 {
|
|
// We're inside a quoted string
|
|
if ch == quote {
|
|
// Closing quote - the content between quotes is now complete, add it
|
|
args = append(args, arg.String())
|
|
arg.Reset()
|
|
quote = 0
|
|
nextIdx++
|
|
continue
|
|
}
|
|
// Inside quotes - add everything as-is
|
|
arg.WriteByte(ch)
|
|
nextIdx++
|
|
continue
|
|
}
|
|
|
|
// Not inside quotes
|
|
if quoteChars[ch] {
|
|
// Opening quote
|
|
quote = ch
|
|
nextIdx++
|
|
continue
|
|
}
|
|
|
|
// Nested function call: $func(...) as an argument
|
|
if ch == '$' && arg.Len() == 0 {
|
|
// Capture the entire $func(...) expression as a raw argument token
|
|
nestedEnd, nestedErr := extractNestedFuncExpr(src, nextIdx)
|
|
if nestedErr != nil {
|
|
return nil, 0, nestedErr
|
|
}
|
|
args = append(args, src[nextIdx:nestedEnd+1])
|
|
nextIdx = nestedEnd + 1
|
|
continue
|
|
}
|
|
|
|
if ch == ')' {
|
|
// End of arguments
|
|
if arg.Len() > 0 {
|
|
args = append(args, arg.String())
|
|
}
|
|
return args, nextIdx, nil
|
|
}
|
|
|
|
if ch == ',' {
|
|
// Argument separator
|
|
if arg.Len() > 0 {
|
|
args = append(args, arg.String())
|
|
arg.Reset()
|
|
}
|
|
nextIdx++
|
|
continue
|
|
}
|
|
|
|
if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
|
|
// Whitespace outside quotes - skip
|
|
nextIdx++
|
|
continue
|
|
}
|
|
|
|
// Regular character - accumulate until we hit a delimiter
|
|
arg.WriteByte(ch)
|
|
nextIdx++
|
|
}
|
|
|
|
// Reached end of string without closing parenthesis
|
|
if quote != 0 {
|
|
return nil, 0, ErrUnterminatedQuotes.Withf("func %q", funcName)
|
|
}
|
|
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q", funcName)
|
|
}
|
|
|
|
// extractNestedFuncExpr finds the end index (inclusive) of a $func(...) expression
|
|
// starting at position start in src. It handles nested parentheses.
|
|
func extractNestedFuncExpr(src string, start int) (endIdx int, err error) {
|
|
// src[start] must be '$'
|
|
i := start + 1
|
|
// skip the function name (valid var name chars)
|
|
for i < len(src) && validVarNameCharset[src[i]] {
|
|
i++
|
|
}
|
|
if i >= len(src) || src[i] != '(' {
|
|
return 0, ErrUnterminatedParenthesis.Withf("nested func at position %d", start)
|
|
}
|
|
// Now find the matching closing parenthesis, respecting quotes and nesting
|
|
depth := 0
|
|
var quote byte
|
|
for i < len(src) {
|
|
ch := src[i]
|
|
if quote != 0 {
|
|
if ch == quote {
|
|
quote = 0
|
|
}
|
|
i++
|
|
continue
|
|
}
|
|
if quoteChars[ch] {
|
|
quote = ch
|
|
i++
|
|
continue
|
|
}
|
|
switch ch {
|
|
case '(':
|
|
depth++
|
|
case ')':
|
|
depth--
|
|
if depth == 0 {
|
|
return i, nil
|
|
}
|
|
}
|
|
i++
|
|
}
|
|
if quote != 0 {
|
|
return 0, ErrUnterminatedQuotes.Withf("nested func at position %d", start)
|
|
}
|
|
return 0, ErrUnterminatedParenthesis.Withf("nested func at position %d", start)
|
|
}
|
|
|
|
// expandArgs expands any args that are nested dynamic var expressions (starting with '$').
|
|
// It returns the expanded args and the combined phase flags.
|
|
func expandArgs(args []string, w *httputils.ResponseModifier, req *http.Request) (expanded []string, phase PhaseFlag, err error) {
|
|
expanded = make([]string, len(args))
|
|
for i, arg := range args {
|
|
if len(arg) > 0 && arg[0] == '$' {
|
|
var buf strings.Builder
|
|
var argPhase PhaseFlag
|
|
argPhase, err = ExpandVars(w, req, arg, &buf)
|
|
if err != nil {
|
|
return nil, phase, err
|
|
}
|
|
phase |= argPhase
|
|
expanded[i] = buf.String()
|
|
} else {
|
|
expanded[i] = arg
|
|
}
|
|
}
|
|
return expanded, phase, nil
|
|
}
|