mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 16:58:31 +02:00
refactor(rules): introduce block DSL, phase-based execution, and flow validation
- add block syntax parser/scanner with nested @blocks and elif/else support - restructure rule execution into explicit pre/post phases with phase flags - classify commands by phase and termination behavior - enforce flow semantics (default rule handling, dead-rule detection) - expand HTTP flow coverage with block + YAML parity tests and benches - refresh rules README/spec and update playground/docs integration
This commit is contained in:
@@ -12,19 +12,19 @@ import (
|
||||
)
|
||||
|
||||
type RuleOn struct {
|
||||
raw string
|
||||
checker Checker
|
||||
isResponseChecker bool
|
||||
}
|
||||
|
||||
func (on *RuleOn) IsResponseChecker() bool {
|
||||
return on.isResponseChecker
|
||||
raw string
|
||||
checker Checker
|
||||
phase PhaseFlag
|
||||
}
|
||||
|
||||
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
return on.checker.Check(w, r)
|
||||
if on.checker == nil {
|
||||
return true
|
||||
}
|
||||
return on.checker.Check(httputils.GetInitResponseModifier(w), r)
|
||||
}
|
||||
|
||||
// on request
|
||||
const (
|
||||
OnDefault = "default"
|
||||
OnHeader = "header"
|
||||
@@ -39,35 +39,36 @@ const (
|
||||
OnRemote = "remote"
|
||||
OnBasicAuth = "basic_auth"
|
||||
OnRoute = "route"
|
||||
)
|
||||
|
||||
// on response
|
||||
|
||||
// on response
|
||||
const (
|
||||
OnResponseHeader = "resp_header"
|
||||
OnStatus = "status"
|
||||
)
|
||||
|
||||
var checkers = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
isResponseChecker bool
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
}{
|
||||
OnDefault: {
|
||||
help: Help{
|
||||
command: OnDefault,
|
||||
description: makeLines(
|
||||
"The default rule is matched when no other rules are matched.",
|
||||
"Select the default (baseline) rule.",
|
||||
),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
return phase, nil, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
|
||||
builder: func(args any) CheckFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool { return false }
|
||||
}, // this should never be called
|
||||
},
|
||||
OnHeader: {
|
||||
help: Help{
|
||||
@@ -83,21 +84,23 @@ var checkers = map[string]struct {
|
||||
"[value]": "the header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(r.Header[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(r.Header[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnResponseHeader: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnResponseHeader,
|
||||
description: makeLines(
|
||||
@@ -111,16 +114,20 @@ var checkers = map[string]struct {
|
||||
"[value]": "the response header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePost
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(w.Header()[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(w.Header()[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -138,16 +145,19 @@ var checkers = map[string]struct {
|
||||
"[value]": "the query value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(w.SharedData().GetQueries(r)[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -165,12 +175,15 @@ var checkers = map[string]struct {
|
||||
"[value]": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
cookies := w.SharedData().GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
return true
|
||||
@@ -179,8 +192,8 @@ var checkers = map[string]struct {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
cookies := w.SharedData().GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
if matcher(cookie.Value) {
|
||||
@@ -192,6 +205,7 @@ var checkers = map[string]struct {
|
||||
}
|
||||
},
|
||||
},
|
||||
//nolint:dupl
|
||||
OnForm: {
|
||||
help: Help{
|
||||
command: OnForm,
|
||||
@@ -206,15 +220,18 @@ var checkers = map[string]struct {
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.FormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.FormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -233,15 +250,18 @@ var checkers = map[string]struct {
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.PostFormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.PostFormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -250,32 +270,46 @@ var checkers = map[string]struct {
|
||||
help: Help{
|
||||
command: OnProto,
|
||||
args: map[string]string{
|
||||
"proto": "the http protocol (http, https, h3)",
|
||||
"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
proto := args[0]
|
||||
if proto != "http" && proto != "https" && proto != "h3" {
|
||||
return nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||
switch proto {
|
||||
case "http", "https", "h1", "h2", "h2c", "h3":
|
||||
return phase, proto, nil
|
||||
default:
|
||||
return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||
}
|
||||
return proto, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
proto := args.(string)
|
||||
switch proto {
|
||||
case "http":
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil
|
||||
}
|
||||
case "https":
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil
|
||||
}
|
||||
case "h1":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil && r.ProtoMajor == 1
|
||||
}
|
||||
case "h2":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil && r.ProtoMajor == 2
|
||||
}
|
||||
case "h2c":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil && r.ProtoMajor == 2
|
||||
}
|
||||
default: // h3
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil && r.ProtoMajor == 3
|
||||
}
|
||||
}
|
||||
@@ -288,10 +322,13 @@ var checkers = map[string]struct {
|
||||
"method": "the http method",
|
||||
},
|
||||
},
|
||||
validate: validateMethod,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateMethod(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
method := args.(string)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.Method == method
|
||||
}
|
||||
},
|
||||
@@ -310,10 +347,13 @@ var checkers = map[string]struct {
|
||||
"host": "the host name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateSingleMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.Host)
|
||||
}
|
||||
},
|
||||
@@ -332,10 +372,13 @@ var checkers = map[string]struct {
|
||||
"path": "the request path",
|
||||
},
|
||||
},
|
||||
validate: validateURLPathMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateURLPathMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
reqPath := r.URL.Path
|
||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||
reqPath = "/" + reqPath
|
||||
@@ -351,22 +394,25 @@ var checkers = map[string]struct {
|
||||
"ip|cidr": "the remote ip or cidr",
|
||||
},
|
||||
},
|
||||
validate: validateCIDR,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateCIDR(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
ipnet := args.(*net.IPNet)
|
||||
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||
wantIP := ipnet.IP
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
ip := w.SharedData().GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(wantIP)
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
ip := w.SharedData().GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
@@ -382,11 +428,14 @@ var checkers = map[string]struct {
|
||||
"password": "the password encrypted with bcrypt",
|
||||
},
|
||||
},
|
||||
validate: validateUserBCryptPassword,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateUserBCryptPassword(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
cred := args.(*HashedCrendentials)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return cred.Match(w.SharedData().GetBasicAuth(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -403,16 +452,18 @@ var checkers = map[string]struct {
|
||||
"route": "the route name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateSingleMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(_ http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(routes.TryGetUpstreamName(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
OnStatus: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnStatus,
|
||||
description: makeLines(
|
||||
@@ -429,16 +480,20 @@ var checkers = map[string]struct {
|
||||
"status": "the status code range",
|
||||
},
|
||||
},
|
||||
validate: validateStatusRange,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePost
|
||||
parsedArgs, err = validateStatusRange(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
beg, end := args.(*IntTuple).Unpack()
|
||||
if beg == end {
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
return httputils.GetInitResponseModifier(w).StatusCode() == beg
|
||||
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
return w.StatusCode() == beg
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
|
||||
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
statusCode := w.StatusCode()
|
||||
return statusCode >= beg && statusCode <= end
|
||||
}
|
||||
},
|
||||
@@ -515,85 +570,119 @@ func splitPipe(s string) []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
result := make([]string, 0, 2)
|
||||
quote := byte(0)
|
||||
brackets := 0
|
||||
start := 0
|
||||
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '\\':
|
||||
escaped = true
|
||||
current.WriteRune(r)
|
||||
// Skip escaped character.
|
||||
if i+1 < len(s) {
|
||||
i++
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
if quote == 0 && brackets == 0 {
|
||||
quote = r
|
||||
} else if r == quote {
|
||||
quote = s[i]
|
||||
} else if s[i] == quote {
|
||||
quote = 0
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '(':
|
||||
brackets++
|
||||
current.WriteRune(r)
|
||||
case ')':
|
||||
if brackets > 0 {
|
||||
brackets--
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '|':
|
||||
if quote == 0 && brackets == 0 {
|
||||
// Found a pipe outside quotes/brackets, split here
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
result = append(result, strings.TrimSpace(s[start:i]))
|
||||
start = i + 1
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last part
|
||||
if current.Len() > 0 {
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
// drop trailing empty part.
|
||||
if start < len(s) {
|
||||
result = append(result, strings.TrimSpace(s[start:]))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func forEachAndPart(s string, fn func(part string)) {
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i < len(s) && andSeps[s[i]] == 0 {
|
||||
continue
|
||||
}
|
||||
part := strings.TrimSpace(s[start:i])
|
||||
if part != "" {
|
||||
fn(part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
func forEachPipePart(s string, fn func(part string)) {
|
||||
quote := byte(0)
|
||||
brackets := 0
|
||||
start := 0
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '\\':
|
||||
if i+1 < len(s) {
|
||||
i++
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
if quote == 0 && brackets == 0 {
|
||||
quote = s[i]
|
||||
} else if s[i] == quote {
|
||||
quote = 0
|
||||
}
|
||||
case '(':
|
||||
brackets++
|
||||
case ')':
|
||||
if brackets > 0 {
|
||||
brackets--
|
||||
}
|
||||
case '|':
|
||||
if quote == 0 && brackets == 0 {
|
||||
fn(strings.TrimSpace(s[start:i]))
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
if start < len(s) {
|
||||
fn(strings.TrimSpace(s[start:]))
|
||||
}
|
||||
}
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (on *RuleOn) Parse(v string) error {
|
||||
on.raw = v
|
||||
|
||||
rules := splitAnd(v)
|
||||
checkAnd := make(CheckMatchAll, 0, len(rules))
|
||||
ruleCount := 0
|
||||
forEachAndPart(v, func(_ string) {
|
||||
ruleCount++
|
||||
})
|
||||
checkAnd := make(CheckMatchAll, 0, ruleCount)
|
||||
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
isResponseChecker := false
|
||||
for i, rule := range rules {
|
||||
if rule == "" {
|
||||
continue
|
||||
}
|
||||
parsed, isResp, err := parseOn(rule)
|
||||
i := 0
|
||||
forEachAndPart(v, func(rule string) {
|
||||
i++
|
||||
parsed, phase, err := parseOn(rule)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "line %d", i+1)
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
errs.AddSubjectf(err, "line %d", i)
|
||||
return
|
||||
}
|
||||
on.phase |= phase
|
||||
checkAnd = append(checkAnd, parsed)
|
||||
}
|
||||
})
|
||||
|
||||
on.checker = checkAnd
|
||||
on.isResponseChecker = isResponseChecker
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
@@ -605,33 +694,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||
return []byte(on.String()), nil
|
||||
}
|
||||
|
||||
func parseOn(line string) (Checker, bool, error) {
|
||||
ors := splitPipe(line)
|
||||
|
||||
if len(ors) > 1 {
|
||||
func parseOn(line string) (Checker, PhaseFlag, error) {
|
||||
orCount := 0
|
||||
forEachPipePart(line, func(_ string) {
|
||||
orCount++
|
||||
})
|
||||
if orCount > 1 {
|
||||
var phase PhaseFlag
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
checkOr := make(CheckMatchSingle, len(ors))
|
||||
isResponseChecker := false
|
||||
for i, or := range ors {
|
||||
curCheckers, isResp, err := parseOn(or)
|
||||
checkOr := make(CheckMatchSingle, orCount)
|
||||
i := 0
|
||||
forEachPipePart(line, func(or string) {
|
||||
i++
|
||||
checkFunc, req, err := parseOnAtom(or)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
errs.AddSubjectf(err, "or[%d]", i)
|
||||
return
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
}
|
||||
checkOr[i] = curCheckers.(CheckFunc)
|
||||
}
|
||||
checkOr[i-1] = checkFunc
|
||||
phase |= req
|
||||
})
|
||||
if err := errs.Error(); err != nil {
|
||||
return nil, false, err
|
||||
return nil, phase, err
|
||||
}
|
||||
return checkOr, isResponseChecker, nil
|
||||
return checkOr, phase, nil
|
||||
}
|
||||
|
||||
return parseOnAtom(line)
|
||||
}
|
||||
|
||||
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
|
||||
var phase PhaseFlag
|
||||
subject, args, err := parse(line)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, phase, err
|
||||
}
|
||||
|
||||
negate := false
|
||||
@@ -642,20 +738,21 @@ func parseOn(line string) (Checker, bool, error) {
|
||||
|
||||
checker, ok := checkers[subject]
|
||||
if !ok {
|
||||
return nil, false, ErrInvalidOnTarget.Subject(subject)
|
||||
return nil, phase, ErrInvalidOnTarget.Subject(subject)
|
||||
}
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
req, validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, false, gperr.Wrap(err).With(checker.help.Error())
|
||||
return nil, phase, gperr.Wrap(err).With(checker.help.Error())
|
||||
}
|
||||
phase |= req
|
||||
|
||||
checkFunc := checker.builder(validArgs)
|
||||
if negate {
|
||||
origCheckFunc := checkFunc
|
||||
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
|
||||
checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return !origCheckFunc(w, r)
|
||||
}
|
||||
}
|
||||
return checkFunc, checker.isResponseChecker, nil
|
||||
return checkFunc, phase, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user