mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-26 11:01:07 +01:00
feat(rules): glob and regex support, env var substitution
- optimized `remote` rule for ip matching - updated descriptions
This commit is contained in:
@@ -52,11 +52,21 @@ var commands = map[string]struct {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
return validateURLPaths(args)
|
||||
path1, err1 := validateURLPath(args[:1])
|
||||
path2, err2 := validateURLPath(args[1:])
|
||||
if err1 != nil {
|
||||
err1 = gperr.Errorf("from: %w", err1)
|
||||
}
|
||||
if err2 != nil {
|
||||
err2 = gperr.Errorf("to: %w", err2)
|
||||
}
|
||||
if err1 != nil || err2 != nil {
|
||||
return nil, gperr.Join(err1, err2)
|
||||
}
|
||||
return &StrTuple{path1.(string), path2.(string)}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
a := args.([]string)
|
||||
orig, repl := a[0], a[1]
|
||||
orig, repl := args.(*StrTuple).Unpack()
|
||||
return StaticCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
if len(path) > 0 && path[0] != '/' {
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
|
||||
var (
|
||||
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
|
||||
ErrUnsupportedEscapeChar = gperr.New("unsupported escape char")
|
||||
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
|
||||
ErrUnknownDirective = gperr.New("unknown directive")
|
||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
@@ -28,6 +27,7 @@ const (
|
||||
OnForm = "form"
|
||||
OnPostForm = "postform"
|
||||
OnMethod = "method"
|
||||
OnHost = "host"
|
||||
OnPath = "path"
|
||||
OnRemote = "remote"
|
||||
OnBasicAuth = "basic_auth"
|
||||
@@ -42,58 +42,69 @@ var checkers = map[string]struct {
|
||||
OnHeader: {
|
||||
help: Help{
|
||||
command: OnHeader,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
header username "user"
|
||||
header username glob("user*")
|
||||
header username regex("user.*")`,
|
||||
args: map[string]string{
|
||||
"key": "the header key",
|
||||
"[value]": "the header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalV,
|
||||
validate: toKVOptionalVMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
if v == "" {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return len(r.Header[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return slices.Contains(r.Header[k], v)
|
||||
return slices.ContainsFunc(r.Header[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnQuery: {
|
||||
help: Help{
|
||||
command: OnQuery,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
query username "user"
|
||||
query username glob("user*")
|
||||
query username regex("user.*")`,
|
||||
args: map[string]string{
|
||||
"key": "the query key",
|
||||
"[value]": "the query value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalV,
|
||||
validate: toKVOptionalVMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
if v == "" {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return len(cached.GetQueries(r)[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
queries := cached.GetQueries(r)[k]
|
||||
return slices.Contains(queries, v)
|
||||
return slices.ContainsFunc(cached.GetQueries(r)[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnCookie: {
|
||||
help: Help{
|
||||
command: OnCookie,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
cookie username "user"
|
||||
cookie username glob("user*")
|
||||
cookie username regex("user.*")`,
|
||||
args: map[string]string{
|
||||
"key": "the cookie key",
|
||||
"[value]": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalV,
|
||||
validate: toKVOptionalVMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
if v == "" {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
cookies := cached.GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
@@ -107,9 +118,10 @@ var checkers = map[string]struct {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
cookies := cached.GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k &&
|
||||
cookie.Value == v {
|
||||
return true
|
||||
if cookie.Name == k {
|
||||
if matcher(cookie.Value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
@@ -119,42 +131,50 @@ var checkers = map[string]struct {
|
||||
OnForm: {
|
||||
help: Help{
|
||||
command: OnForm,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
form username "user"
|
||||
form username glob("user*")
|
||||
form username regex("user.*")`,
|
||||
args: map[string]string{
|
||||
"key": "the form key",
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalV,
|
||||
validate: toKVOptionalVMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
if v == "" {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return r.FormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return r.FormValue(k) == v
|
||||
return matcher(r.FormValue(k))
|
||||
}
|
||||
},
|
||||
},
|
||||
OnPostForm: {
|
||||
help: Help{
|
||||
command: OnPostForm,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
postform username "user"
|
||||
postform username glob("user*")
|
||||
postform username regex("user.*")`,
|
||||
args: map[string]string{
|
||||
"key": "the form key",
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalV,
|
||||
validate: toKVOptionalVMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
if v == "" {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return r.PostFormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return r.PostFormValue(k) == v
|
||||
return matcher(r.PostFormValue(k))
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -173,25 +193,47 @@ var checkers = map[string]struct {
|
||||
}
|
||||
},
|
||||
},
|
||||
OnHost: {
|
||||
help: Help{
|
||||
command: OnHost,
|
||||
description: `Supports string, glob pattern, or regex pattern, e.g.:
|
||||
host example.com
|
||||
host glob(example*.com)
|
||||
host regex(example\w+\.com)
|
||||
host regex(example\.com$)`,
|
||||
args: map[string]string{
|
||||
"host": "the host name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return matcher(r.Host)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnPath: {
|
||||
help: Help{
|
||||
command: OnPath,
|
||||
description: `The path can be a glob pattern, e.g.:
|
||||
/path/to
|
||||
/path/to/*`,
|
||||
description: `Supports string, glob pattern, or regex pattern, e.g.:
|
||||
path /path/to
|
||||
path glob(/path/to/*)
|
||||
path regex(^/path/to/.*$)
|
||||
path regex(/path/[A-Z]+/)`,
|
||||
args: map[string]string{
|
||||
"path": "the request path",
|
||||
},
|
||||
},
|
||||
validate: validateURLPathGlob,
|
||||
validate: validateURLPathMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
pat := args.(glob.Glob)
|
||||
matcher := args.(Matcher)
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
reqPath := r.URL.Path
|
||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||
reqPath = "/" + reqPath
|
||||
}
|
||||
return pat.Match(reqPath)
|
||||
return matcher(reqPath)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -204,13 +246,24 @@ var checkers = map[string]struct {
|
||||
},
|
||||
validate: validateCIDR,
|
||||
builder: func(args any) CheckFunc {
|
||||
cidr := args.(nettypes.CIDR)
|
||||
ipnet := args.(*net.IPNet)
|
||||
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||
wantIP := ipnet.IP
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
ip := cached.GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(wantIP)
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
ip := cached.GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return cidr.Contains(ip)
|
||||
return ipnet.Contains(ip)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -233,15 +286,19 @@ var checkers = map[string]struct {
|
||||
OnRoute: {
|
||||
help: Help{
|
||||
command: OnRoute,
|
||||
description: `Supports string, glob pattern, or regex pattern, e.g.:
|
||||
route example
|
||||
route glob(example*)
|
||||
route regex(example\w+)`,
|
||||
args: map[string]string{
|
||||
"route": "the route name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleArg,
|
||||
validate: validateSingleMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
route := args.(string)
|
||||
matcher := args.(Matcher)
|
||||
return func(_ Cache, r *http.Request) bool {
|
||||
return routes.TryGetUpstreamName(r) == route
|
||||
return matcher(routes.TryGetUpstreamName(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -253,8 +310,8 @@ var (
|
||||
)
|
||||
|
||||
func indexAnd(s string) int {
|
||||
for i, c := range s {
|
||||
if andSeps[c] != 0 {
|
||||
for i := range s {
|
||||
if andSeps[s[i]] != 0 {
|
||||
return i
|
||||
}
|
||||
}
|
||||
@@ -263,8 +320,8 @@ func indexAnd(s string) int {
|
||||
|
||||
func countAnd(s string) int {
|
||||
n := 0
|
||||
for _, c := range s {
|
||||
if andSeps[c] != 0 {
|
||||
for i := range s {
|
||||
if andSeps[s[i]] != 0 {
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"unicode"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
@@ -14,20 +16,27 @@ var escapedChars = map[rune]rune{
|
||||
'\'': '\'',
|
||||
'"': '"',
|
||||
'\\': '\\',
|
||||
'$': '$',
|
||||
' ': ' ',
|
||||
}
|
||||
|
||||
// parse expression to subject and args
|
||||
// with support for quotes and escaped chars, e.g.
|
||||
// with support for quotes, escaped chars, and env substitution, e.g.
|
||||
//
|
||||
// error 403 "Forbidden 'foo' 'bar'"
|
||||
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
||||
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
||||
func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
|
||||
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
brackets := 0
|
||||
|
||||
var envVar bytes.Buffer
|
||||
var missingEnvVars bytes.Buffer
|
||||
inEnvVar := false
|
||||
expectingBrace := false
|
||||
|
||||
flush := func(quoted bool) {
|
||||
part := buf.String()
|
||||
if !quoted {
|
||||
@@ -56,8 +65,7 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
if ch, ok := escapedChars[r]; ok {
|
||||
buf.WriteRune(ch)
|
||||
} else {
|
||||
err = ErrUnsupportedEscapeChar.Subjectf("\\%c", r)
|
||||
return subject, args, err
|
||||
fmt.Fprintf(buf, `\%c`, r)
|
||||
}
|
||||
escaped = false
|
||||
continue
|
||||
@@ -65,10 +73,36 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
switch r {
|
||||
case '\\':
|
||||
escaped = true
|
||||
continue
|
||||
case '"', '\'':
|
||||
case '$':
|
||||
if expectingBrace { // $$ => $ and continue
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
} else {
|
||||
expectingBrace = true
|
||||
}
|
||||
case '{':
|
||||
if expectingBrace {
|
||||
inEnvVar = true
|
||||
expectingBrace = false
|
||||
envVar.Reset()
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '}':
|
||||
if inEnvVar {
|
||||
envValue, ok := os.LookupEnv(envVar.String())
|
||||
if !ok {
|
||||
fmt.Fprintf(&missingEnvVars, "%q, ", envVar.String())
|
||||
} else {
|
||||
buf.WriteString(envValue)
|
||||
}
|
||||
inEnvVar = false
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
switch {
|
||||
case quote == 0:
|
||||
case quote == 0 && brackets == 0:
|
||||
quote = r
|
||||
flush(false)
|
||||
case r == quote:
|
||||
@@ -77,21 +111,40 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '(':
|
||||
brackets++
|
||||
buf.WriteRune(r)
|
||||
case ')':
|
||||
if brackets == 0 {
|
||||
err = ErrUnterminatedBrackets
|
||||
return subject, args, err
|
||||
}
|
||||
brackets--
|
||||
buf.WriteRune(r)
|
||||
case ' ':
|
||||
if quote == 0 {
|
||||
flush(false)
|
||||
continue
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
buf.WriteRune(r)
|
||||
if inEnvVar {
|
||||
envVar.WriteRune(r)
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
err = ErrUnterminatedQuotes
|
||||
} else if brackets != 0 {
|
||||
err = ErrUnterminatedBrackets
|
||||
} else {
|
||||
flush(false)
|
||||
}
|
||||
if missingEnvVars.Len() > 0 {
|
||||
err = gperr.Join(err, ErrEnvVarNotFound.Subject(missingEnvVars.String()))
|
||||
}
|
||||
return subject, args, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
@@ -47,9 +48,28 @@ func TestParser(t *testing.T) {
|
||||
args: []string{"", ""},
|
||||
},
|
||||
{
|
||||
name: "invalid_escape",
|
||||
input: `foo \bar`,
|
||||
wantErr: ErrUnsupportedEscapeChar,
|
||||
name: "regex_escaped",
|
||||
input: `foo regex(\b\B\s\S\w\W\d\D\$\.)`,
|
||||
subject: "foo",
|
||||
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.)`},
|
||||
},
|
||||
{
|
||||
name: "quote inside argument",
|
||||
input: `foo "abc 'def'"`,
|
||||
subject: "foo",
|
||||
args: []string{"abc 'def'"},
|
||||
},
|
||||
{
|
||||
name: "quote inside function",
|
||||
input: `foo glob("'/**/to/path'")`,
|
||||
subject: "foo",
|
||||
args: []string{"glob(\"'/**/to/path'\")"},
|
||||
},
|
||||
{
|
||||
name: "quote inside quoted function",
|
||||
input: "foo 'glob(\"`/**/to/path`\")'",
|
||||
subject: "foo",
|
||||
args: []string{"glob(\"`/**/to/path`\")"},
|
||||
},
|
||||
{
|
||||
name: "chaos",
|
||||
@@ -74,12 +94,69 @@ func TestParser(t *testing.T) {
|
||||
// t.Log(subject, args, err)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, tt.subject)
|
||||
expect.Equal(t, len(args), len(tt.args))
|
||||
for i, arg := range args {
|
||||
expect.Equal(t, arg, tt.args[i])
|
||||
}
|
||||
expect.Equal(t, args, tt.args)
|
||||
})
|
||||
}
|
||||
t.Run("env substitution", func(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
os.Setenv("CLOUDFLARE_API_KEY", "test-api-key-123")
|
||||
os.Setenv("DOMAIN", "example.com")
|
||||
defer func() {
|
||||
os.Unsetenv("CLOUDFLARE_API_KEY")
|
||||
os.Unsetenv("DOMAIN")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
subject string
|
||||
args []string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "simple env var",
|
||||
input: `error 403 "Forbidden: ${CLOUDFLARE_API_KEY}"`,
|
||||
subject: "error",
|
||||
args: []string{"403", "Forbidden: test-api-key-123"},
|
||||
},
|
||||
{
|
||||
name: "multiple env vars",
|
||||
input: `forward https://${DOMAIN}/api`,
|
||||
subject: "forward",
|
||||
args: []string{"https://example.com/api"},
|
||||
},
|
||||
{
|
||||
name: "env var with other text",
|
||||
input: `auth "user-${DOMAIN}-admin" "password"`,
|
||||
subject: "auth",
|
||||
args: []string{"user-example.com-admin", "password"},
|
||||
},
|
||||
{
|
||||
name: "non-existent env var",
|
||||
input: `error 404 "${NON_EXISTENT}"`,
|
||||
wantErr: ErrEnvVarNotFound.Error(),
|
||||
},
|
||||
{
|
||||
name: "escaped",
|
||||
input: `error 404 "$${NON_EXISTENT}"`,
|
||||
subject: "error",
|
||||
args: []string{"404", "${NON_EXISTENT}"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subject, args, err := parse(tt.input)
|
||||
if tt.wantErr != "" {
|
||||
expect.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, tt.subject)
|
||||
expect.Equal(t, args, tt.args)
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("unterminated quotes", func(t *testing.T) {
|
||||
tests := []string{
|
||||
`error 403 "Forbidden 'foo' 'bar'`,
|
||||
@@ -97,7 +174,7 @@ func TestParser(t *testing.T) {
|
||||
|
||||
func BenchmarkParser(b *testing.B) {
|
||||
const input = `error 403 "Forbidden "foo" "bar""\ baz`
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
_, _, err := parse(input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
||||
@@ -2,9 +2,12 @@ package rules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
@@ -19,7 +22,9 @@ type (
|
||||
First T1
|
||||
Second T2
|
||||
}
|
||||
StrTuple = Tuple[string, string]
|
||||
StrTuple = Tuple[string, string]
|
||||
IntTuple = Tuple[int, int]
|
||||
MapValueMatcher = Tuple[string, Matcher]
|
||||
)
|
||||
|
||||
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
||||
@@ -30,11 +35,101 @@ func (t *Tuple[T1, T2]) String() string {
|
||||
return fmt.Sprintf("%v:%v", t.First, t.Second)
|
||||
}
|
||||
|
||||
func validateSingleArg(args []string) (any, gperr.Error) {
|
||||
type (
|
||||
Matcher func(string) bool
|
||||
MatcherType string
|
||||
)
|
||||
|
||||
const (
|
||||
MatcherTypeString MatcherType = "string"
|
||||
MatcherTypeGlob MatcherType = "glob"
|
||||
MatcherTypeRegex MatcherType = "regex"
|
||||
)
|
||||
|
||||
func unquoteExpr(s string) (string, gperr.Error) {
|
||||
if s == "" {
|
||||
return "", nil
|
||||
}
|
||||
switch s[0] {
|
||||
case '"', '\'', '`':
|
||||
if s[0] != s[len(s)-1] {
|
||||
return "", ErrUnterminatedQuotes
|
||||
}
|
||||
return s[1 : len(s)-1], nil
|
||||
default:
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
|
||||
idx := strings.IndexByte(s, '(')
|
||||
if idx == -1 {
|
||||
return MatcherTypeString, s, nil
|
||||
}
|
||||
idxEnd := strings.LastIndexByte(s, ')')
|
||||
if idxEnd == -1 {
|
||||
return "", "", ErrUnterminatedBrackets
|
||||
}
|
||||
|
||||
expr, err = unquoteExpr(s[idx+1 : idxEnd])
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
matcherType = MatcherType(strings.ToLower(s[:idx]))
|
||||
|
||||
switch matcherType {
|
||||
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
|
||||
return
|
||||
default:
|
||||
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
|
||||
}
|
||||
}
|
||||
|
||||
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
t, expr, err := ExtractExpr(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t {
|
||||
case MatcherTypeString:
|
||||
return StringMatcher(expr)
|
||||
case MatcherTypeGlob:
|
||||
return GlobMatcher(expr)
|
||||
case MatcherTypeRegex:
|
||||
return RegexMatcher(expr)
|
||||
}
|
||||
// won't reach here
|
||||
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
||||
}
|
||||
|
||||
func StringMatcher(s string) (Matcher, gperr.Error) {
|
||||
return func(s2 string) bool {
|
||||
return s == s2
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GlobMatcher(expr string) (Matcher, gperr.Error) {
|
||||
g, err := glob.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return g.Match, nil
|
||||
}
|
||||
|
||||
func RegexMatcher(expr string) (Matcher, gperr.Error) {
|
||||
re, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return re.MatchString, nil
|
||||
}
|
||||
|
||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||
func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
return args[0], nil
|
||||
return ParseMatcher(args[0])
|
||||
}
|
||||
|
||||
// toStrTuple returns *StrTuple.
|
||||
@@ -45,13 +140,17 @@ func toStrTuple(args []string) (any, gperr.Error) {
|
||||
return &StrTuple{args[0], args[1]}, nil
|
||||
}
|
||||
|
||||
// toKVOptionalV returns *StrTuple that value is optional.
|
||||
func toKVOptionalV(args []string) (any, gperr.Error) {
|
||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
switch len(args) {
|
||||
case 1:
|
||||
return &StrTuple{args[0], ""}, nil
|
||||
return &MapValueMatcher{args[0], nil}, nil
|
||||
case 2:
|
||||
return &StrTuple{args[0], args[1]}, nil
|
||||
m, err := ParseMatcher(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MapValueMatcher{args[0], m}, nil
|
||||
default:
|
||||
return nil, ErrExpectKVOptionalV
|
||||
}
|
||||
@@ -95,11 +194,11 @@ func validateCIDR(args []string) (any, gperr.Error) {
|
||||
if !strings.Contains(args[0], "/") {
|
||||
args[0] += "/32"
|
||||
}
|
||||
cidr, err := nettypes.ParseCIDR(args[0])
|
||||
_, ipnet, err := net.ParseCIDR(args[0])
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return cidr, nil
|
||||
return ipnet, nil
|
||||
}
|
||||
|
||||
// validateURLPath returns string with the path validated.
|
||||
@@ -120,35 +219,12 @@ func validateURLPath(args []string) (any, gperr.Error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// validateURLPathGlob returns []string with each element validated.
|
||||
func validateURLPathGlob(args []string) (any, gperr.Error) {
|
||||
p, err := validateURLPath(args)
|
||||
func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
path, err := validateURLPath(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g, gErr := glob.Compile(p.(string))
|
||||
if gErr != nil {
|
||||
return nil, ErrInvalidArguments.With(gErr)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// validateURLPaths returns []string with each element validated.
|
||||
func validateURLPaths(paths []string) (any, gperr.Error) {
|
||||
errs := gperr.NewBuilder("invalid url paths")
|
||||
for i, p := range paths {
|
||||
val, err := validateURLPath([]string{p})
|
||||
if err != nil {
|
||||
errs.Add(err.Subject(p))
|
||||
continue
|
||||
}
|
||||
paths[i] = val.(string)
|
||||
}
|
||||
if err := errs.Error(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return paths, nil
|
||||
return ParseMatcher(path.(string))
|
||||
}
|
||||
|
||||
// validateFSPath returns string with the path validated.
|
||||
|
||||
98
internal/route/rules/validate_test.go
Normal file
98
internal/route/rules/validate_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
func TestExtractExpr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
wantT MatcherType
|
||||
wantExpr string
|
||||
}{
|
||||
{
|
||||
name: "string implicit",
|
||||
in: "foo",
|
||||
wantT: MatcherTypeString,
|
||||
wantExpr: "foo",
|
||||
},
|
||||
{
|
||||
name: "string explicit",
|
||||
in: "string(`foo`)",
|
||||
wantT: MatcherTypeString,
|
||||
wantExpr: "foo",
|
||||
},
|
||||
{
|
||||
name: "glob",
|
||||
in: "glob(foo)",
|
||||
wantT: MatcherTypeGlob,
|
||||
wantExpr: "foo",
|
||||
},
|
||||
{
|
||||
name: "glob quoted",
|
||||
in: "glob(`foo`)",
|
||||
wantT: MatcherTypeGlob,
|
||||
wantExpr: "foo",
|
||||
},
|
||||
{
|
||||
name: "regex",
|
||||
in: "regex(^[A-Z]+$)",
|
||||
wantT: MatcherTypeRegex,
|
||||
wantExpr: "^[A-Z]+$",
|
||||
},
|
||||
{
|
||||
name: "regex quoted",
|
||||
in: "regex(`^[A-Z]+$`)",
|
||||
wantT: MatcherTypeRegex,
|
||||
wantExpr: "^[A-Z]+$",
|
||||
},
|
||||
{
|
||||
name: "quoted expr",
|
||||
in: "glob(`'foo'`)",
|
||||
wantT: MatcherTypeGlob,
|
||||
wantExpr: "'foo'",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
typ, expr, err := ExtractExpr(tt.in)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, tt.wantT, typ)
|
||||
expect.Equal(t, tt.wantExpr, expr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractExprInvalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "missing closing quote",
|
||||
in: "glob(`foo)",
|
||||
wantErr: "unterminated quotes",
|
||||
},
|
||||
{
|
||||
name: "missing closing bracket",
|
||||
in: "glob(`foo",
|
||||
wantErr: "unterminated brackets",
|
||||
},
|
||||
{
|
||||
name: "invalid matcher type",
|
||||
in: "invalid(`foo`)",
|
||||
wantErr: "invalid matcher type: invalid",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, err := ExtractExpr(tt.in)
|
||||
expect.HasError(t, err)
|
||||
expect.ErrorContains(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user