feat(rules): glob and regex support, env var substitution

- optimized `remote` rule for ip matching
- updated descriptions
This commit is contained in:
yusing
2025-10-10 14:43:48 +08:00
parent 60cfff3435
commit c2c9f42fb3
7 changed files with 469 additions and 97 deletions

View File

@@ -52,11 +52,21 @@ var commands = map[string]struct {
if len(args) != 2 { if len(args) != 2 {
return nil, ErrExpectTwoArgs return nil, ErrExpectTwoArgs
} }
return validateURLPaths(args) path1, err1 := validateURLPath(args[:1])
path2, err2 := validateURLPath(args[1:])
if err1 != nil {
err1 = gperr.Errorf("from: %w", err1)
}
if err2 != nil {
err2 = gperr.Errorf("to: %w", err2)
}
if err1 != nil || err2 != nil {
return nil, gperr.Join(err1, err2)
}
return &StrTuple{path1.(string), path2.(string)}, nil
}, },
build: func(args any) CommandHandler { build: func(args any) CommandHandler {
a := args.([]string) orig, repl := args.(*StrTuple).Unpack()
orig, repl := a[0], a[1]
return StaticCommand(func(w http.ResponseWriter, r *http.Request) { return StaticCommand(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path path := r.URL.Path
if len(path) > 0 && path[0] != '/' { if len(path) > 0 && path[0] != '/' {

View File

@@ -6,8 +6,9 @@ import (
var ( var (
ErrUnterminatedQuotes = gperr.New("unterminated quotes") ErrUnterminatedQuotes = gperr.New("unterminated quotes")
ErrUnsupportedEscapeChar = gperr.New("unsupported escape char") ErrUnterminatedBrackets = gperr.New("unterminated brackets")
ErrUnknownDirective = gperr.New("unknown directive") ErrUnknownDirective = gperr.New("unknown directive")
ErrEnvVarNotFound = gperr.New("env variable not found")
ErrInvalidArguments = gperr.New("invalid arguments") ErrInvalidArguments = gperr.New("invalid arguments")
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target") ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
ErrInvalidCommandSequence = gperr.New("invalid command sequence") ErrInvalidCommandSequence = gperr.New("invalid command sequence")

View File

@@ -1,12 +1,11 @@
package rules package rules
import ( import (
"net"
"net/http" "net/http"
"slices" "slices"
"strings" "strings"
"github.com/gobwas/glob"
nettypes "github.com/yusing/godoxy/internal/net/types"
"github.com/yusing/godoxy/internal/route/routes" "github.com/yusing/godoxy/internal/route/routes"
gperr "github.com/yusing/goutils/errs" gperr "github.com/yusing/goutils/errs"
strutils "github.com/yusing/goutils/strings" strutils "github.com/yusing/goutils/strings"
@@ -28,6 +27,7 @@ const (
OnForm = "form" OnForm = "form"
OnPostForm = "postform" OnPostForm = "postform"
OnMethod = "method" OnMethod = "method"
OnHost = "host"
OnPath = "path" OnPath = "path"
OnRemote = "remote" OnRemote = "remote"
OnBasicAuth = "basic_auth" OnBasicAuth = "basic_auth"
@@ -42,58 +42,69 @@ var checkers = map[string]struct {
OnHeader: { OnHeader: {
help: Help{ help: Help{
command: OnHeader, command: OnHeader,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
header username "user"
header username glob("user*")
header username regex("user.*")`,
args: map[string]string{ args: map[string]string{
"key": "the header key", "key": "the header key",
"[value]": "the header value", "[value]": "the header value",
}, },
}, },
validate: toKVOptionalV, validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack() k, matcher := args.(*MapValueMatcher).Unpack()
if v == "" { if matcher == nil {
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return len(r.Header[k]) > 0 return len(r.Header[k]) > 0
} }
} }
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return slices.Contains(r.Header[k], v) return slices.ContainsFunc(r.Header[k], matcher)
} }
}, },
}, },
OnQuery: { OnQuery: {
help: Help{ help: Help{
command: OnQuery, command: OnQuery,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
query username "user"
query username glob("user*")
query username regex("user.*")`,
args: map[string]string{ args: map[string]string{
"key": "the query key", "key": "the query key",
"[value]": "the query value", "[value]": "the query value",
}, },
}, },
validate: toKVOptionalV, validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack() k, matcher := args.(*MapValueMatcher).Unpack()
if v == "" { if matcher == nil {
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return len(cached.GetQueries(r)[k]) > 0 return len(cached.GetQueries(r)[k]) > 0
} }
} }
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
queries := cached.GetQueries(r)[k] return slices.ContainsFunc(cached.GetQueries(r)[k], matcher)
return slices.Contains(queries, v)
} }
}, },
}, },
OnCookie: { OnCookie: {
help: Help{ help: Help{
command: OnCookie, command: OnCookie,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
cookie username "user"
cookie username glob("user*")
cookie username regex("user.*")`,
args: map[string]string{ args: map[string]string{
"key": "the cookie key", "key": "the cookie key",
"[value]": "the cookie value", "[value]": "the cookie value",
}, },
}, },
validate: toKVOptionalV, validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack() k, matcher := args.(*MapValueMatcher).Unpack()
if v == "" { if matcher == nil {
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
cookies := cached.GetCookies(r) cookies := cached.GetCookies(r)
for _, cookie := range cookies { for _, cookie := range cookies {
@@ -107,9 +118,10 @@ var checkers = map[string]struct {
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
cookies := cached.GetCookies(r) cookies := cached.GetCookies(r)
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == k && if cookie.Name == k {
cookie.Value == v { if matcher(cookie.Value) {
return true return true
}
} }
} }
return false return false
@@ -119,42 +131,50 @@ var checkers = map[string]struct {
OnForm: { OnForm: {
help: Help{ help: Help{
command: OnForm, command: OnForm,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
form username "user"
form username glob("user*")
form username regex("user.*")`,
args: map[string]string{ args: map[string]string{
"key": "the form key", "key": "the form key",
"[value]": "the form value", "[value]": "the form value",
}, },
}, },
validate: toKVOptionalV, validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack() k, matcher := args.(*MapValueMatcher).Unpack()
if v == "" { if matcher == nil {
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return r.FormValue(k) != "" return r.FormValue(k) != ""
} }
} }
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return r.FormValue(k) == v return matcher(r.FormValue(k))
} }
}, },
}, },
OnPostForm: { OnPostForm: {
help: Help{ help: Help{
command: OnPostForm, command: OnPostForm,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
postform username "user"
postform username glob("user*")
postform username regex("user.*")`,
args: map[string]string{ args: map[string]string{
"key": "the form key", "key": "the form key",
"[value]": "the form value", "[value]": "the form value",
}, },
}, },
validate: toKVOptionalV, validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
k, v := args.(*StrTuple).Unpack() k, matcher := args.(*MapValueMatcher).Unpack()
if v == "" { if matcher == nil {
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return r.PostFormValue(k) != "" return r.PostFormValue(k) != ""
} }
} }
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
return r.PostFormValue(k) == v return matcher(r.PostFormValue(k))
} }
}, },
}, },
@@ -173,25 +193,47 @@ var checkers = map[string]struct {
} }
}, },
}, },
OnHost: {
help: Help{
command: OnHost,
description: `Supports string, glob pattern, or regex pattern, e.g.:
host example.com
host glob(example*.com)
host regex(example\w+\.com)
host regex(example\.com$)`,
args: map[string]string{
"host": "the host name",
},
},
validate: validateSingleMatcher,
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(cached Cache, r *http.Request) bool {
return matcher(r.Host)
}
},
},
OnPath: { OnPath: {
help: Help{ help: Help{
command: OnPath, command: OnPath,
description: `The path can be a glob pattern, e.g.: description: `Supports string, glob pattern, or regex pattern, e.g.:
/path/to path /path/to
/path/to/*`, path glob(/path/to/*)
path regex(^/path/to/.*$)
path regex(/path/[A-Z]+/)`,
args: map[string]string{ args: map[string]string{
"path": "the request path", "path": "the request path",
}, },
}, },
validate: validateURLPathGlob, validate: validateURLPathMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
pat := args.(glob.Glob) matcher := args.(Matcher)
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
reqPath := r.URL.Path reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' { if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath reqPath = "/" + reqPath
} }
return pat.Match(reqPath) return matcher(reqPath)
} }
}, },
}, },
@@ -204,13 +246,24 @@ var checkers = map[string]struct {
}, },
validate: validateCIDR, validate: validateCIDR,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
cidr := args.(nettypes.CIDR) ipnet := args.(*net.IPNet)
// for /32 (IPv4) or /128 (IPv6), just compare the IP
if ones, bits := ipnet.Mask.Size(); ones == bits {
wantIP := ipnet.IP
return func(cached Cache, r *http.Request) bool {
ip := cached.GetRemoteIP(r)
if ip == nil {
return false
}
return ip.Equal(wantIP)
}
}
return func(cached Cache, r *http.Request) bool { return func(cached Cache, r *http.Request) bool {
ip := cached.GetRemoteIP(r) ip := cached.GetRemoteIP(r)
if ip == nil { if ip == nil {
return false return false
} }
return cidr.Contains(ip) return ipnet.Contains(ip)
} }
}, },
}, },
@@ -233,15 +286,19 @@ var checkers = map[string]struct {
OnRoute: { OnRoute: {
help: Help{ help: Help{
command: OnRoute, command: OnRoute,
description: `Supports string, glob pattern, or regex pattern, e.g.:
route example
route glob(example*)
route regex(example\w+)`,
args: map[string]string{ args: map[string]string{
"route": "the route name", "route": "the route name",
}, },
}, },
validate: validateSingleArg, validate: validateSingleMatcher,
builder: func(args any) CheckFunc { builder: func(args any) CheckFunc {
route := args.(string) matcher := args.(Matcher)
return func(_ Cache, r *http.Request) bool { return func(_ Cache, r *http.Request) bool {
return routes.TryGetUpstreamName(r) == route return matcher(routes.TryGetUpstreamName(r))
} }
}, },
}, },
@@ -253,8 +310,8 @@ var (
) )
func indexAnd(s string) int { func indexAnd(s string) int {
for i, c := range s { for i := range s {
if andSeps[c] != 0 { if andSeps[s[i]] != 0 {
return i return i
} }
} }
@@ -263,8 +320,8 @@ func indexAnd(s string) int {
func countAnd(s string) int { func countAnd(s string) int {
n := 0 n := 0
for _, c := range s { for i := range s {
if andSeps[c] != 0 { if andSeps[s[i]] != 0 {
n++ n++
} }
} }

View File

@@ -2,6 +2,8 @@ package rules
import ( import (
"bytes" "bytes"
"fmt"
"os"
"unicode" "unicode"
gperr "github.com/yusing/goutils/errs" gperr "github.com/yusing/goutils/errs"
@@ -14,20 +16,27 @@ var escapedChars = map[rune]rune{
'\'': '\'', '\'': '\'',
'"': '"', '"': '"',
'\\': '\\', '\\': '\\',
'$': '$',
' ': ' ', ' ': ' ',
} }
// parse expression to subject and args // parse expression to subject and args
// with support for quotes and escaped chars, e.g. // with support for quotes, escaped chars, and env substitution, e.g.
// //
// error 403 "Forbidden 'foo' 'bar'" // error 403 "Forbidden 'foo' 'bar'"
// error 403 Forbidden\ \"foo\"\ \"bar\". // error 403 Forbidden\ \"foo\"\ \"bar\".
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
func parse(v string) (subject string, args []string, err gperr.Error) { func parse(v string) (subject string, args []string, err gperr.Error) {
buf := bytes.NewBuffer(make([]byte, 0, len(v))) buf := bytes.NewBuffer(make([]byte, 0, len(v)))
escaped := false escaped := false
quote := rune(0) quote := rune(0)
brackets := 0
var envVar bytes.Buffer
var missingEnvVars bytes.Buffer
inEnvVar := false
expectingBrace := false
flush := func(quoted bool) { flush := func(quoted bool) {
part := buf.String() part := buf.String()
if !quoted { if !quoted {
@@ -56,8 +65,7 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
if ch, ok := escapedChars[r]; ok { if ch, ok := escapedChars[r]; ok {
buf.WriteRune(ch) buf.WriteRune(ch)
} else { } else {
err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r) fmt.Fprintf(buf, `\%c`, r)
return subject, args, err
} }
escaped = false escaped = false
continue continue
@@ -65,10 +73,36 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
switch r { switch r {
case '\\': case '\\':
escaped = true escaped = true
continue case '$':
case '"', '\'': if expectingBrace { // $$ => $ and continue
buf.WriteRune('$')
expectingBrace = false
} else {
expectingBrace = true
}
case '{':
if expectingBrace {
inEnvVar = true
expectingBrace = false
envVar.Reset()
} else {
buf.WriteRune(r)
}
case '}':
if inEnvVar {
envValue, ok := os.LookupEnv(envVar.String())
if !ok {
fmt.Fprintf(&missingEnvVars, "%q, ", envVar.String())
} else {
buf.WriteString(envValue)
}
inEnvVar = false
} else {
buf.WriteRune(r)
}
case '"', '\'', '`':
switch { switch {
case quote == 0: case quote == 0 && brackets == 0:
quote = r quote = r
flush(false) flush(false)
case r == quote: case r == quote:
@@ -77,21 +111,40 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
default: default:
buf.WriteRune(r) buf.WriteRune(r)
} }
case '(':
brackets++
buf.WriteRune(r)
case ')':
if brackets == 0 {
err = ErrUnterminatedBrackets
return subject, args, err
}
brackets--
buf.WriteRune(r)
case ' ': case ' ':
if quote == 0 { if quote == 0 {
flush(false) flush(false)
continue } else {
buf.WriteRune(r)
} }
fallthrough
default: default:
buf.WriteRune(r) if inEnvVar {
envVar.WriteRune(r)
} else {
buf.WriteRune(r)
}
} }
} }
if quote != 0 { if quote != 0 {
err = ErrUnterminatedQuotes err = ErrUnterminatedQuotes
} else if brackets != 0 {
err = ErrUnterminatedBrackets
} else { } else {
flush(false) flush(false)
} }
if missingEnvVars.Len() > 0 {
err = gperr.Join(err, ErrEnvVarNotFound.Subject(missingEnvVars.String()))
}
return subject, args, err return subject, args, err
} }

View File

@@ -1,6 +1,7 @@
package rules package rules
import ( import (
"os"
"strconv" "strconv"
"testing" "testing"
@@ -47,9 +48,28 @@ func TestParser(t *testing.T) {
args: []string{"", ""}, args: []string{"", ""},
}, },
{ {
name: "invalid_escape", name: "regex_escaped",
input: `foo \bar`, input: `foo regex(\b\B\s\S\w\W\d\D\$\.)`,
wantErr: ErrUnsupportedEscapeChar, subject: "foo",
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.)`},
},
{
name: "quote inside argument",
input: `foo "abc 'def'"`,
subject: "foo",
args: []string{"abc 'def'"},
},
{
name: "quote inside function",
input: `foo glob("'/**/to/path'")`,
subject: "foo",
args: []string{"glob(\"'/**/to/path'\")"},
},
{
name: "quote inside quoted function",
input: "foo 'glob(\"`/**/to/path`\")'",
subject: "foo",
args: []string{"glob(\"`/**/to/path`\")"},
}, },
{ {
name: "chaos", name: "chaos",
@@ -74,12 +94,69 @@ func TestParser(t *testing.T) {
// t.Log(subject, args, err) // t.Log(subject, args, err)
expect.NoError(t, err) expect.NoError(t, err)
expect.Equal(t, subject, tt.subject) expect.Equal(t, subject, tt.subject)
expect.Equal(t, len(args), len(tt.args)) expect.Equal(t, args, tt.args)
for i, arg := range args {
expect.Equal(t, arg, tt.args[i])
}
}) })
} }
t.Run("env substitution", func(t *testing.T) {
// Set up test environment variables
os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
os.Setenv("DOMAIN", "example.com")
defer func() {
os.Unsetenv("CLOUDFLARE_API_KEY")
os.Unsetenv("DOMAIN")
}()
tests := []struct {
name string
input string
subject string
args []string
wantErr string
}{
{
name: "simple env var",
input: `error 403 "Forbidden: ${CLOUDFLARE_API_KEY}"`,
subject: "error",
args: []string{"403", "Forbidden: test-api-key-123"},
},
{
name: "multiple env vars",
input: `forward https://${DOMAIN}/api`,
subject: "forward",
args: []string{"https://example.com/api"},
},
{
name: "env var with other text",
input: `auth "user-${DOMAIN}-admin" "password"`,
subject: "auth",
args: []string{"user-example.com-admin", "password"},
},
{
name: "non-existent env var",
input: `error 404 "${NON_EXISTENT}"`,
wantErr: ErrEnvVarNotFound.Error(),
},
{
name: "escaped",
input: `error 404 "$${NON_EXISTENT}"`,
subject: "error",
args: []string{"404", "${NON_EXISTENT}"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subject, args, err := parse(tt.input)
if tt.wantErr != "" {
expect.ErrorContains(t, err, tt.wantErr)
return
}
expect.NoError(t, err)
expect.Equal(t, subject, tt.subject)
expect.Equal(t, args, tt.args)
})
}
})
t.Run("unterminated quotes", func(t *testing.T) { t.Run("unterminated quotes", func(t *testing.T) {
tests := []string{ tests := []string{
`error 403 "Forbidden 'foo' 'bar'`, `error 403 "Forbidden 'foo' 'bar'`,
@@ -97,7 +174,7 @@ func TestParser(t *testing.T) {
func BenchmarkParser(b *testing.B) { func BenchmarkParser(b *testing.B) {
const input = `error 403 "Forbidden "foo" "bar""\ baz` const input = `error 403 "Forbidden "foo" "bar""\ baz`
for range b.N { for b.Loop() {
_, _, err := parse(input) _, _, err := parse(input)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)

View File

@@ -2,9 +2,12 @@ package rules
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"regexp"
"strconv"
"strings" "strings"
"github.com/gobwas/glob" "github.com/gobwas/glob"
@@ -19,7 +22,9 @@ type (
First T1 First T1
Second T2 Second T2
} }
StrTuple = Tuple[string, string] StrTuple = Tuple[string, string]
IntTuple = Tuple[int, int]
MapValueMatcher = Tuple[string, Matcher]
) )
func (t *Tuple[T1, T2]) Unpack() (T1, T2) { func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
@@ -30,11 +35,101 @@ func (t *Tuple[T1, T2]) String() string {
return fmt.Sprintf("%v:%v", t.First, t.Second) return fmt.Sprintf("%v:%v", t.First, t.Second)
} }
func validateSingleArg(args []string) (any, gperr.Error) { type (
Matcher func(string) bool
MatcherType string
)
const (
MatcherTypeString MatcherType = "string"
MatcherTypeGlob MatcherType = "glob"
MatcherTypeRegex MatcherType = "regex"
)
func unquoteExpr(s string) (string, gperr.Error) {
if s == "" {
return "", nil
}
switch s[0] {
case '"', '\'', '`':
if s[0] != s[len(s)-1] {
return "", ErrUnterminatedQuotes
}
return s[1 : len(s)-1], nil
default:
return s, nil
}
}
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
idx := strings.IndexByte(s, '(')
if idx == -1 {
return MatcherTypeString, s, nil
}
idxEnd := strings.LastIndexByte(s, ')')
if idxEnd == -1 {
return "", "", ErrUnterminatedBrackets
}
expr, err = unquoteExpr(s[idx+1 : idxEnd])
if err != nil {
return "", "", err
}
matcherType = MatcherType(strings.ToLower(s[:idx]))
switch matcherType {
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
return
default:
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
}
}
func ParseMatcher(expr string) (Matcher, gperr.Error) {
t, expr, err := ExtractExpr(expr)
if err != nil {
return nil, err
}
switch t {
case MatcherTypeString:
return StringMatcher(expr)
case MatcherTypeGlob:
return GlobMatcher(expr)
case MatcherTypeRegex:
return RegexMatcher(expr)
}
// won't reach here
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
}
func StringMatcher(s string) (Matcher, gperr.Error) {
return func(s2 string) bool {
return s == s2
}, nil
}
func GlobMatcher(expr string) (Matcher, gperr.Error) {
g, err := glob.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return g.Match, nil
}
func RegexMatcher(expr string) (Matcher, gperr.Error) {
re, err := regexp.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return re.MatchString, nil
}
// validateSingleMatcher returns Matcher with the matcher validated.
func validateSingleMatcher(args []string) (any, gperr.Error) {
if len(args) != 1 { if len(args) != 1 {
return nil, ErrExpectOneArg return nil, ErrExpectOneArg
} }
return args[0], nil return ParseMatcher(args[0])
} }
// toStrTuple returns *StrTuple. // toStrTuple returns *StrTuple.
@@ -45,13 +140,17 @@ func toStrTuple(args []string) (any, gperr.Error) {
return &StrTuple{args[0], args[1]}, nil return &StrTuple{args[0], args[1]}, nil
} }
// toKVOptionalV returns *StrTuple that value is optional. // toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
func toKVOptionalV(args []string) (any, gperr.Error) { func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
switch len(args) { switch len(args) {
case 1: case 1:
return &StrTuple{args[0], ""}, nil return &MapValueMatcher{args[0], nil}, nil
case 2: case 2:
return &StrTuple{args[0], args[1]}, nil m, err := ParseMatcher(args[1])
if err != nil {
return nil, err
}
return &MapValueMatcher{args[0], m}, nil
default: default:
return nil, ErrExpectKVOptionalV return nil, ErrExpectKVOptionalV
} }
@@ -95,11 +194,11 @@ func validateCIDR(args []string) (any, gperr.Error) {
if !strings.Contains(args[0], "/") { if !strings.Contains(args[0], "/") {
args[0] += "/32" args[0] += "/32"
} }
cidr, err := nettypes.ParseCIDR(args[0]) _, ipnet, err := net.ParseCIDR(args[0])
if err != nil { if err != nil {
return nil, ErrInvalidArguments.With(err) return nil, ErrInvalidArguments.With(err)
} }
return cidr, nil return ipnet, nil
} }
// validateURLPath returns string with the path validated. // validateURLPath returns string with the path validated.
@@ -120,35 +219,12 @@ func validateURLPath(args []string) (any, gperr.Error) {
return p, nil return p, nil
} }
// validateURLPathGlob returns []string with each element validated. func validateURLPathMatcher(args []string) (any, gperr.Error) {
func validateURLPathGlob(args []string) (any, gperr.Error) { path, err := validateURLPath(args)
p, err := validateURLPath(args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ParseMatcher(path.(string))
g, gErr := glob.Compile(p.(string))
if gErr != nil {
return nil, ErrInvalidArguments.With(gErr)
}
return g, nil
}
// validateURLPaths returns []string with each element validated.
func validateURLPaths(paths []string) (any, gperr.Error) {
errs := gperr.NewBuilder("invalid url paths")
for i, p := range paths {
val, err := validateURLPath([]string{p})
if err != nil {
errs.Add(err.Subject(p))
continue
}
paths[i] = val.(string)
}
if err := errs.Error(); err != nil {
return nil, err
}
return paths, nil
} }
// validateFSPath returns string with the path validated. // validateFSPath returns string with the path validated.

View File

@@ -0,0 +1,98 @@
package rules
import (
"testing"
expect "github.com/yusing/goutils/testing"
)
func TestExtractExpr(t *testing.T) {
tests := []struct {
name string
in string
wantT MatcherType
wantExpr string
}{
{
name: "string implicit",
in: "foo",
wantT: MatcherTypeString,
wantExpr: "foo",
},
{
name: "string explicit",
in: "string(`foo`)",
wantT: MatcherTypeString,
wantExpr: "foo",
},
{
name: "glob",
in: "glob(foo)",
wantT: MatcherTypeGlob,
wantExpr: "foo",
},
{
name: "glob quoted",
in: "glob(`foo`)",
wantT: MatcherTypeGlob,
wantExpr: "foo",
},
{
name: "regex",
in: "regex(^[A-Z]+$)",
wantT: MatcherTypeRegex,
wantExpr: "^[A-Z]+$",
},
{
name: "regex quoted",
in: "regex(`^[A-Z]+$`)",
wantT: MatcherTypeRegex,
wantExpr: "^[A-Z]+$",
},
{
name: "quoted expr",
in: "glob(`'foo'`)",
wantT: MatcherTypeGlob,
wantExpr: "'foo'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
typ, expr, err := ExtractExpr(tt.in)
expect.NoError(t, err)
expect.Equal(t, tt.wantT, typ)
expect.Equal(t, tt.wantExpr, expr)
})
}
}
func TestExtractExprInvalid(t *testing.T) {
tests := []struct {
name string
in string
wantErr string
}{
{
name: "missing closing quote",
in: "glob(`foo)",
wantErr: "unterminated quotes",
},
{
name: "missing closing bracket",
in: "glob(`foo",
wantErr: "unterminated brackets",
},
{
name: "invalid matcher type",
in: "invalid(`foo`)",
wantErr: "invalid matcher type: invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, err := ExtractExpr(tt.in)
expect.HasError(t, err)
expect.ErrorContains(t, err, tt.wantErr)
})
}
}