feat(rules): glob and regex support, env var substitution

- optimized `remote` rule for ip matching
- updated descriptions
This commit is contained in:
yusing
2025-10-10 14:43:48 +08:00
parent 60cfff3435
commit c2c9f42fb3
7 changed files with 469 additions and 97 deletions

View File

@@ -52,11 +52,21 @@ var commands = map[string]struct {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return validateURLPaths(args)
path1, err1 := validateURLPath(args[:1])
path2, err2 := validateURLPath(args[1:])
if err1 != nil {
err1 = gperr.Errorf("from: %w", err1)
}
if err2 != nil {
err2 = gperr.Errorf("to: %w", err2)
}
if err1 != nil || err2 != nil {
return nil, gperr.Join(err1, err2)
}
return &StrTuple{path1.(string), path2.(string)}, nil
},
build: func(args any) CommandHandler {
a := args.([]string)
orig, repl := a[0], a[1]
orig, repl := args.(*StrTuple).Unpack()
return StaticCommand(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if len(path) > 0 && path[0] != '/' {

View File

@@ -6,8 +6,9 @@ import (
var (
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
ErrUnsupportedEscapeChar = gperr.New("unsupported escape char")
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
ErrUnknownDirective = gperr.New("unknown directive")
ErrEnvVarNotFound = gperr.New("env variable not found")
ErrInvalidArguments = gperr.New("invalid arguments")
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
ErrInvalidCommandSequence = gperr.New("invalid command sequence")

View File

@@ -1,12 +1,11 @@
package rules
import (
"net"
"net/http"
"slices"
"strings"
"github.com/gobwas/glob"
nettypes "github.com/yusing/godoxy/internal/net/types"
"github.com/yusing/godoxy/internal/route/routes"
gperr "github.com/yusing/goutils/errs"
strutils "github.com/yusing/goutils/strings"
@@ -28,6 +27,7 @@ const (
OnForm = "form"
OnPostForm = "postform"
OnMethod = "method"
OnHost = "host"
OnPath = "path"
OnRemote = "remote"
OnBasicAuth = "basic_auth"
@@ -42,58 +42,69 @@ var checkers = map[string]struct {
OnHeader: {
help: Help{
command: OnHeader,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
header username "user"
header username glob("user*")
header username regex("user.*")`,
args: map[string]string{
"key": "the header key",
"[value]": "the header value",
},
},
validate: toKVOptionalV,
validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack()
if v == "" {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return len(r.Header[k]) > 0
}
}
return func(cached Cache, r *http.Request) bool {
return slices.Contains(r.Header[k], v)
return slices.ContainsFunc(r.Header[k], matcher)
}
},
},
OnQuery: {
help: Help{
command: OnQuery,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
query username "user"
query username glob("user*")
query username regex("user.*")`,
args: map[string]string{
"key": "the query key",
"[value]": "the query value",
},
},
validate: toKVOptionalV,
validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack()
if v == "" {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return len(cached.GetQueries(r)[k]) > 0
}
}
return func(cached Cache, r *http.Request) bool {
queries := cached.GetQueries(r)[k]
return slices.Contains(queries, v)
return slices.ContainsFunc(cached.GetQueries(r)[k], matcher)
}
},
},
OnCookie: {
help: Help{
command: OnCookie,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
cookie username "user"
cookie username glob("user*")
cookie username regex("user.*")`,
args: map[string]string{
"key": "the cookie key",
"[value]": "the cookie value",
},
},
validate: toKVOptionalV,
validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack()
if v == "" {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
cookies := cached.GetCookies(r)
for _, cookie := range cookies {
@@ -107,9 +118,10 @@ var checkers = map[string]struct {
return func(cached Cache, r *http.Request) bool {
cookies := cached.GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k &&
cookie.Value == v {
return true
if cookie.Name == k {
if matcher(cookie.Value) {
return true
}
}
}
return false
@@ -119,42 +131,50 @@ var checkers = map[string]struct {
OnForm: {
help: Help{
command: OnForm,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
form username "user"
form username glob("user*")
form username regex("user.*")`,
args: map[string]string{
"key": "the form key",
"[value]": "the form value",
},
},
validate: toKVOptionalV,
validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack()
if v == "" {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return r.FormValue(k) != ""
}
}
return func(cached Cache, r *http.Request) bool {
return r.FormValue(k) == v
return matcher(r.FormValue(k))
}
},
},
OnPostForm: {
help: Help{
command: OnPostForm,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
postform username "user"
postform username glob("user*")
postform username regex("user.*")`,
args: map[string]string{
"key": "the form key",
"[value]": "the form value",
},
},
validate: toKVOptionalV,
validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack()
if v == "" {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return r.PostFormValue(k) != ""
}
}
return func(cached Cache, r *http.Request) bool {
return r.PostFormValue(k) == v
return matcher(r.PostFormValue(k))
}
},
},
@@ -173,25 +193,47 @@ var checkers = map[string]struct {
}
},
},
OnHost: {
help: Help{
command: OnHost,
description: `Supports string, glob pattern, or regex pattern, e.g.:
host example.com
host glob(example*.com)
host regex(example\w+\.com)
host regex(example\.com$)`,
args: map[string]string{
"host": "the host name",
},
},
validate: validateSingleMatcher,
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(cached Cache, r *http.Request) bool {
return matcher(r.Host)
}
},
},
OnPath: {
help: Help{
command: OnPath,
description: `The path can be a glob pattern, e.g.:
/path/to
/path/to/*`,
description: `Supports string, glob pattern, or regex pattern, e.g.:
path /path/to
path glob(/path/to/*)
path regex(^/path/to/.*$)
path regex(/path/[A-Z]+/)`,
args: map[string]string{
"path": "the request path",
},
},
validate: validateURLPathGlob,
validate: validateURLPathMatcher,
builder: func(args any) CheckFunc {
pat := args.(glob.Glob)
matcher := args.(Matcher)
return func(cached Cache, r *http.Request) bool {
reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
}
return pat.Match(reqPath)
return matcher(reqPath)
}
},
},
@@ -204,13 +246,24 @@ var checkers = map[string]struct {
},
validate: validateCIDR,
builder: func(args any) CheckFunc {
cidr := args.(nettypes.CIDR)
ipnet := args.(*net.IPNet)
// for /32 (IPv4) or /128 (IPv6), just compare the IP
if ones, bits := ipnet.Mask.Size(); ones == bits {
wantIP := ipnet.IP
return func(cached Cache, r *http.Request) bool {
ip := cached.GetRemoteIP(r)
if ip == nil {
return false
}
return ip.Equal(wantIP)
}
}
return func(cached Cache, r *http.Request) bool {
ip := cached.GetRemoteIP(r)
if ip == nil {
return false
}
return cidr.Contains(ip)
return ipnet.Contains(ip)
}
},
},
@@ -233,15 +286,19 @@ var checkers = map[string]struct {
OnRoute: {
help: Help{
command: OnRoute,
description: `Supports string, glob pattern, or regex pattern, e.g.:
route example
route glob(example*)
route regex(example\w+)`,
args: map[string]string{
"route": "the route name",
},
},
validate: validateSingleArg,
validate: validateSingleMatcher,
builder: func(args any) CheckFunc {
route := args.(string)
matcher := args.(Matcher)
return func(_ Cache, r *http.Request) bool {
return routes.TryGetUpstreamName(r) == route
return matcher(routes.TryGetUpstreamName(r))
}
},
},
@@ -253,8 +310,8 @@ var (
)
func indexAnd(s string) int {
for i, c := range s {
if andSeps[c] != 0 {
for i := range s {
if andSeps[s[i]] != 0 {
return i
}
}
@@ -263,8 +320,8 @@ func indexAnd(s string) int {
func countAnd(s string) int {
n := 0
for _, c := range s {
if andSeps[c] != 0 {
for i := range s {
if andSeps[s[i]] != 0 {
n++
}
}

View File

@@ -2,6 +2,8 @@ package rules
import (
"bytes"
"fmt"
"os"
"unicode"
gperr "github.com/yusing/goutils/errs"
@@ -14,20 +16,27 @@ var escapedChars = map[rune]rune{
'\'': '\'',
'"': '"',
'\\': '\\',
'$': '$',
' ': ' ',
}
// parse expression to subject and args
// with support for quotes and escaped chars, e.g.
// with support for quotes, escaped chars, and env substitution, e.g.
//
// error 403 "Forbidden 'foo' 'bar'"
// error 403 Forbidden\ \"foo\"\ \"bar\".
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
func parse(v string) (subject string, args []string, err gperr.Error) {
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
escaped := false
quote := rune(0)
brackets := 0
var envVar bytes.Buffer
var missingEnvVars bytes.Buffer
inEnvVar := false
expectingBrace := false
flush := func(quoted bool) {
part := buf.String()
if !quoted {
@@ -56,8 +65,7 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
if ch, ok := escapedChars[r]; ok {
buf.WriteRune(ch)
} else {
err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r)
return subject, args, err
fmt.Fprintf(buf, `\%c`, r)
}
escaped = false
continue
@@ -65,10 +73,36 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
switch r {
case '\\':
escaped = true
continue
case '"', '\'':
case '$':
if expectingBrace { // $$ => $ and continue
buf.WriteRune('$')
expectingBrace = false
} else {
expectingBrace = true
}
case '{':
if expectingBrace {
inEnvVar = true
expectingBrace = false
envVar.Reset()
} else {
buf.WriteRune(r)
}
case '}':
if inEnvVar {
envValue, ok := os.LookupEnv(envVar.String())
if !ok {
fmt.Fprintf(&missingEnvVars, "%q, ", envVar.String())
} else {
buf.WriteString(envValue)
}
inEnvVar = false
} else {
buf.WriteRune(r)
}
case '"', '\'', '`':
switch {
case quote == 0:
case quote == 0 && brackets == 0:
quote = r
flush(false)
case r == quote:
@@ -77,21 +111,40 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
default:
buf.WriteRune(r)
}
case '(':
brackets++
buf.WriteRune(r)
case ')':
if brackets == 0 {
err = ErrUnterminatedBrackets
return subject, args, err
}
brackets--
buf.WriteRune(r)
case ' ':
if quote == 0 {
flush(false)
continue
} else {
buf.WriteRune(r)
}
fallthrough
default:
buf.WriteRune(r)
if inEnvVar {
envVar.WriteRune(r)
} else {
buf.WriteRune(r)
}
}
}
if quote != 0 {
err = ErrUnterminatedQuotes
} else if brackets != 0 {
err = ErrUnterminatedBrackets
} else {
flush(false)
}
if missingEnvVars.Len() > 0 {
err = gperr.Join(err, ErrEnvVarNotFound.Subject(missingEnvVars.String()))
}
return subject, args, err
}

View File

@@ -1,6 +1,7 @@
package rules
import (
"os"
"strconv"
"testing"
@@ -47,9 +48,28 @@ func TestParser(t *testing.T) {
args: []string{"", ""},
},
{
name: "invalid_escape",
input: `foo \bar`,
wantErr: ErrUnsupportedEscapeChar,
name: "regex_escaped",
input: `foo regex(\b\B\s\S\w\W\d\D\$\.)`,
subject: "foo",
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.)`},
},
{
name: "quote inside argument",
input: `foo "abc 'def'"`,
subject: "foo",
args: []string{"abc 'def'"},
},
{
name: "quote inside function",
input: `foo glob("'/**/to/path'")`,
subject: "foo",
args: []string{"glob(\"'/**/to/path'\")"},
},
{
name: "quote inside quoted function",
input: "foo 'glob(\"`/**/to/path`\")'",
subject: "foo",
args: []string{"glob(\"`/**/to/path`\")"},
},
{
name: "chaos",
@@ -74,12 +94,69 @@ func TestParser(t *testing.T) {
// t.Log(subject, args, err)
expect.NoError(t, err)
expect.Equal(t, subject, tt.subject)
expect.Equal(t, len(args), len(tt.args))
for i, arg := range args {
expect.Equal(t, arg, tt.args[i])
}
expect.Equal(t, args, tt.args)
})
}
t.Run("env substitution", func(t *testing.T) {
// Set up test environment variables
os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
os.Setenv("DOMAIN", "example.com")
defer func() {
os.Unsetenv("CLOUDFLARE_API_KEY")
os.Unsetenv("DOMAIN")
}()
tests := []struct {
name string
input string
subject string
args []string
wantErr string
}{
{
name: "simple env var",
input: `error 403 "Forbidden: ${CLOUDFLARE_API_KEY}"`,
subject: "error",
args: []string{"403", "Forbidden: test-api-key-123"},
},
{
name: "multiple env vars",
input: `forward https://${DOMAIN}/api`,
subject: "forward",
args: []string{"https://example.com/api"},
},
{
name: "env var with other text",
input: `auth "user-${DOMAIN}-admin" "password"`,
subject: "auth",
args: []string{"user-example.com-admin", "password"},
},
{
name: "non-existent env var",
input: `error 404 "${NON_EXISTENT}"`,
wantErr: ErrEnvVarNotFound.Error(),
},
{
name: "escaped",
input: `error 404 "$${NON_EXISTENT}"`,
subject: "error",
args: []string{"404", "${NON_EXISTENT}"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subject, args, err := parse(tt.input)
if tt.wantErr != "" {
expect.ErrorContains(t, err, tt.wantErr)
return
}
expect.NoError(t, err)
expect.Equal(t, subject, tt.subject)
expect.Equal(t, args, tt.args)
})
}
})
t.Run("unterminated quotes", func(t *testing.T) {
tests := []string{
`error 403 "Forbidden 'foo' 'bar'`,
@@ -97,7 +174,7 @@ func TestParser(t *testing.T) {
func BenchmarkParser(b *testing.B) {
const input = `error 403 "Forbidden "foo" "bar""\ baz`
for range b.N {
for b.Loop() {
_, _, err := parse(input)
if err != nil {
b.Fatal(err)

View File

@@ -2,9 +2,12 @@ package rules
import (
"fmt"
"net"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"github.com/gobwas/glob"
@@ -19,7 +22,9 @@ type (
First T1
Second T2
}
StrTuple = Tuple[string, string]
StrTuple = Tuple[string, string]
IntTuple = Tuple[int, int]
MapValueMatcher = Tuple[string, Matcher]
)
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
@@ -30,11 +35,101 @@ func (t *Tuple[T1, T2]) String() string {
return fmt.Sprintf("%v:%v", t.First, t.Second)
}
func validateSingleArg(args []string) (any, gperr.Error) {
type (
Matcher func(string) bool
MatcherType string
)
const (
MatcherTypeString MatcherType = "string"
MatcherTypeGlob MatcherType = "glob"
MatcherTypeRegex MatcherType = "regex"
)
func unquoteExpr(s string) (string, gperr.Error) {
if s == "" {
return "", nil
}
switch s[0] {
case '"', '\'', '`':
if s[0] != s[len(s)-1] {
return "", ErrUnterminatedQuotes
}
return s[1 : len(s)-1], nil
default:
return s, nil
}
}
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
idx := strings.IndexByte(s, '(')
if idx == -1 {
return MatcherTypeString, s, nil
}
idxEnd := strings.LastIndexByte(s, ')')
if idxEnd == -1 {
return "", "", ErrUnterminatedBrackets
}
expr, err = unquoteExpr(s[idx+1 : idxEnd])
if err != nil {
return "", "", err
}
matcherType = MatcherType(strings.ToLower(s[:idx]))
switch matcherType {
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
return
default:
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
}
}
func ParseMatcher(expr string) (Matcher, gperr.Error) {
t, expr, err := ExtractExpr(expr)
if err != nil {
return nil, err
}
switch t {
case MatcherTypeString:
return StringMatcher(expr)
case MatcherTypeGlob:
return GlobMatcher(expr)
case MatcherTypeRegex:
return RegexMatcher(expr)
}
// won't reach here
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
}
func StringMatcher(s string) (Matcher, gperr.Error) {
return func(s2 string) bool {
return s == s2
}, nil
}
func GlobMatcher(expr string) (Matcher, gperr.Error) {
g, err := glob.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return g.Match, nil
}
func RegexMatcher(expr string) (Matcher, gperr.Error) {
re, err := regexp.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return re.MatchString, nil
}
// validateSingleMatcher returns Matcher with the matcher validated.
func validateSingleMatcher(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
return args[0], nil
return ParseMatcher(args[0])
}
// toStrTuple returns *StrTuple.
@@ -45,13 +140,17 @@ func toStrTuple(args []string) (any, gperr.Error) {
return &StrTuple{args[0], args[1]}, nil
}
// toKVOptionalV returns *StrTuple that value is optional.
func toKVOptionalV(args []string) (any, gperr.Error) {
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
switch len(args) {
case 1:
return &StrTuple{args[0], ""}, nil
return &MapValueMatcher{args[0], nil}, nil
case 2:
return &StrTuple{args[0], args[1]}, nil
m, err := ParseMatcher(args[1])
if err != nil {
return nil, err
}
return &MapValueMatcher{args[0], m}, nil
default:
return nil, ErrExpectKVOptionalV
}
@@ -95,11 +194,11 @@ func validateCIDR(args []string) (any, gperr.Error) {
if !strings.Contains(args[0], "/") {
args[0] += "/32"
}
cidr, err := nettypes.ParseCIDR(args[0])
_, ipnet, err := net.ParseCIDR(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return cidr, nil
return ipnet, nil
}
// validateURLPath returns string with the path validated.
@@ -120,35 +219,12 @@ func validateURLPath(args []string) (any, gperr.Error) {
return p, nil
}
// validateURLPathGlob returns []string with each element validated.
func validateURLPathGlob(args []string) (any, gperr.Error) {
p, err := validateURLPath(args)
func validateURLPathMatcher(args []string) (any, gperr.Error) {
path, err := validateURLPath(args)
if err != nil {
return nil, err
}
g, gErr := glob.Compile(p.(string))
if gErr != nil {
return nil, ErrInvalidArguments.With(gErr)
}
return g, nil
}
// validateURLPaths returns []string with each element validated.
func validateURLPaths(paths []string) (any, gperr.Error) {
errs := gperr.NewBuilder("invalid url paths")
for i, p := range paths {
val, err := validateURLPath([]string{p})
if err != nil {
errs.Add(err.Subject(p))
continue
}
paths[i] = val.(string)
}
if err := errs.Error(); err != nil {
return nil, err
}
return paths, nil
return ParseMatcher(path.(string))
}
// validateFSPath returns string with the path validated.

View File

@@ -0,0 +1,98 @@
package rules
import (
"testing"
expect "github.com/yusing/goutils/testing"
)
func TestExtractExpr(t *testing.T) {
tests := []struct {
name string
in string
wantT MatcherType
wantExpr string
}{
{
name: "string implicit",
in: "foo",
wantT: MatcherTypeString,
wantExpr: "foo",
},
{
name: "string explicit",
in: "string(`foo`)",
wantT: MatcherTypeString,
wantExpr: "foo",
},
{
name: "glob",
in: "glob(foo)",
wantT: MatcherTypeGlob,
wantExpr: "foo",
},
{
name: "glob quoted",
in: "glob(`foo`)",
wantT: MatcherTypeGlob,
wantExpr: "foo",
},
{
name: "regex",
in: "regex(^[A-Z]+$)",
wantT: MatcherTypeRegex,
wantExpr: "^[A-Z]+$",
},
{
name: "regex quoted",
in: "regex(`^[A-Z]+$`)",
wantT: MatcherTypeRegex,
wantExpr: "^[A-Z]+$",
},
{
name: "quoted expr",
in: "glob(`'foo'`)",
wantT: MatcherTypeGlob,
wantExpr: "'foo'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
typ, expr, err := ExtractExpr(tt.in)
expect.NoError(t, err)
expect.Equal(t, tt.wantT, typ)
expect.Equal(t, tt.wantExpr, expr)
})
}
}
func TestExtractExprInvalid(t *testing.T) {
tests := []struct {
name string
in string
wantErr string
}{
{
name: "missing closing quote",
in: "glob(`foo)",
wantErr: "unterminated quotes",
},
{
name: "missing closing bracket",
in: "glob(`foo",
wantErr: "unterminated brackets",
},
{
name: "invalid matcher type",
in: "invalid(`foo`)",
wantErr: "invalid matcher type: invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, err := ExtractExpr(tt.in)
expect.HasError(t, err)
expect.ErrorContains(t, err, tt.wantErr)
})
}
}