refactor(rules): introduce block DSL, phase-based execution, and flow validation

- add block syntax parser/scanner with nested @blocks and elif/else support
- restructure rule execution into explicit pre/post phases with phase flags
- classify commands by phase and termination behavior
- enforce flow semantics (default rule handling, dead-rule detection)
- expand HTTP flow coverage with block + YAML parity tests and benches
- refresh rules README/spec and update playground/docs integration
This commit is contained in:
yusing
2026-02-23 22:24:15 +08:00
parent 0850ea3918
commit faecbab2cb
34 changed files with 4691 additions and 1057 deletions

View File

@@ -12,19 +12,19 @@ import (
)
type RuleOn struct {
raw string
checker Checker
isResponseChecker bool
}
func (on *RuleOn) IsResponseChecker() bool {
return on.isResponseChecker
raw string
checker Checker
phase PhaseFlag
}
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
return on.checker.Check(w, r)
if on.checker == nil {
return true
}
return on.checker.Check(httputils.GetInitResponseModifier(w), r)
}
// on request
const (
OnDefault = "default"
OnHeader = "header"
@@ -39,35 +39,36 @@ const (
OnRemote = "remote"
OnBasicAuth = "basic_auth"
OnRoute = "route"
)
// on response
// on response
const (
OnResponseHeader = "resp_header"
OnStatus = "status"
)
var checkers = map[string]struct {
help Help
validate ValidateFunc
builder func(args any) CheckFunc
isResponseChecker bool
help Help
validate ValidateFunc
builder func(args any) CheckFunc
}{
OnDefault: {
help: Help{
command: OnDefault,
description: makeLines(
"The default rule is matched when no other rules are matched.",
"Select the default (baseline) rule.",
),
args: map[string]string{},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 0 {
return nil, ErrExpectNoArg
return phase, nil, ErrExpectNoArg
}
//nolint:nilnil
return nil, nil
return phase, nil, nil
},
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
builder: func(args any) CheckFunc {
return func(w *httputils.ResponseModifier, r *http.Request) bool { return false }
}, // this should never be called
},
OnHeader: {
help: Help{
@@ -83,21 +84,23 @@ var checkers = map[string]struct {
"[value]": "the header value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return len(r.Header[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return slices.ContainsFunc(r.Header[k], matcher)
}
},
},
OnResponseHeader: {
isResponseChecker: true,
help: Help{
command: OnResponseHeader,
description: makeLines(
@@ -111,16 +114,20 @@ var checkers = map[string]struct {
"[value]": "the response header value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePost
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return len(w.Header()[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return slices.ContainsFunc(w.Header()[k], matcher)
}
},
},
@@ -138,16 +145,19 @@ var checkers = map[string]struct {
"[value]": "the query value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return len(w.SharedData().GetQueries(r)[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
}
},
},
@@ -165,12 +175,15 @@ var checkers = map[string]struct {
"[value]": "the cookie value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
cookies := httputils.GetSharedData(w).GetCookies(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
cookies := w.SharedData().GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k {
return true
@@ -179,8 +192,8 @@ var checkers = map[string]struct {
return false
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
cookies := httputils.GetSharedData(w).GetCookies(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
cookies := w.SharedData().GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k {
if matcher(cookie.Value) {
@@ -192,6 +205,7 @@ var checkers = map[string]struct {
}
},
},
//nolint:dupl
OnForm: {
help: Help{
command: OnForm,
@@ -206,15 +220,18 @@ var checkers = map[string]struct {
"[value]": "the form value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.FormValue(k) != ""
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(r.FormValue(k))
}
},
@@ -233,15 +250,18 @@ var checkers = map[string]struct {
"[value]": "the form value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.PostFormValue(k) != ""
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(r.PostFormValue(k))
}
},
@@ -250,32 +270,46 @@ var checkers = map[string]struct {
help: Help{
command: OnProto,
args: map[string]string{
"proto": "the http protocol (http, https, h3)",
"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
}
proto := args[0]
if proto != "http" && proto != "https" && proto != "h3" {
return nil, ErrInvalidArguments.Withf("proto: %q", proto)
switch proto {
case "http", "https", "h1", "h2", "h2c", "h3":
return phase, proto, nil
default:
return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
}
return proto, nil
},
builder: func(args any) CheckFunc {
proto := args.(string)
switch proto {
case "http":
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS == nil
}
case "https":
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS != nil
}
case "h1":
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS == nil && r.ProtoMajor == 1
}
case "h2":
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS != nil && r.ProtoMajor == 2
}
case "h2c":
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS == nil && r.ProtoMajor == 2
}
default: // h3
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS != nil && r.ProtoMajor == 3
}
}
@@ -288,10 +322,13 @@ var checkers = map[string]struct {
"method": "the http method",
},
},
validate: validateMethod,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateMethod(args)
return
},
builder: func(args any) CheckFunc {
method := args.(string)
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.Method == method
}
},
@@ -310,10 +347,13 @@ var checkers = map[string]struct {
"host": "the host name",
},
},
validate: validateSingleMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateSingleMatcher(args)
return
},
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(r.Host)
}
},
@@ -332,10 +372,13 @@ var checkers = map[string]struct {
"path": "the request path",
},
},
validate: validateURLPathMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateURLPathMatcher(args)
return
},
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
@@ -351,22 +394,25 @@ var checkers = map[string]struct {
"ip|cidr": "the remote ip or cidr",
},
},
validate: validateCIDR,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateCIDR(args)
return
},
builder: func(args any) CheckFunc {
ipnet := args.(*net.IPNet)
// for /32 (IPv4) or /128 (IPv6), just compare the IP
if ones, bits := ipnet.Mask.Size(); ones == bits {
wantIP := ipnet.IP
return func(w http.ResponseWriter, r *http.Request) bool {
ip := httputils.GetSharedData(w).GetRemoteIP(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
ip := w.SharedData().GetRemoteIP(r)
if ip == nil {
return false
}
return ip.Equal(wantIP)
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
ip := httputils.GetSharedData(w).GetRemoteIP(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
ip := w.SharedData().GetRemoteIP(r)
if ip == nil {
return false
}
@@ -382,11 +428,14 @@ var checkers = map[string]struct {
"password": "the password encrypted with bcrypt",
},
},
validate: validateUserBCryptPassword,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateUserBCryptPassword(args)
return
},
builder: func(args any) CheckFunc {
cred := args.(*HashedCrendentials)
return func(w http.ResponseWriter, r *http.Request) bool {
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return cred.Match(w.SharedData().GetBasicAuth(r))
}
},
},
@@ -403,16 +452,18 @@ var checkers = map[string]struct {
"route": "the route name",
},
},
validate: validateSingleMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateSingleMatcher(args)
return
},
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(_ http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(routes.TryGetUpstreamName(r))
}
},
},
OnStatus: {
isResponseChecker: true,
help: Help{
command: OnStatus,
description: makeLines(
@@ -429,16 +480,20 @@ var checkers = map[string]struct {
"status": "the status code range",
},
},
validate: validateStatusRange,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePost
parsedArgs, err = validateStatusRange(args)
return
},
builder: func(args any) CheckFunc {
beg, end := args.(*IntTuple).Unpack()
if beg == end {
return func(w http.ResponseWriter, _ *http.Request) bool {
return httputils.GetInitResponseModifier(w).StatusCode() == beg
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
return w.StatusCode() == beg
}
}
return func(w http.ResponseWriter, _ *http.Request) bool {
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
statusCode := w.StatusCode()
return statusCode >= beg && statusCode <= end
}
},
@@ -515,85 +570,119 @@ func splitPipe(s string) []string {
return []string{}
}
var result []string
var current strings.Builder
escaped := false
quote := rune(0)
result := make([]string, 0, 2)
quote := byte(0)
brackets := 0
start := 0
for _, r := range s {
if escaped {
current.WriteRune(r)
escaped = false
continue
}
switch r {
for i := 0; i < len(s); i++ {
switch s[i] {
case '\\':
escaped = true
current.WriteRune(r)
// Skip escaped character.
if i+1 < len(s) {
i++
}
case '"', '\'', '`':
if quote == 0 && brackets == 0 {
quote = r
} else if r == quote {
quote = s[i]
} else if s[i] == quote {
quote = 0
}
current.WriteRune(r)
case '(':
brackets++
current.WriteRune(r)
case ')':
if brackets > 0 {
brackets--
}
current.WriteRune(r)
case '|':
if quote == 0 && brackets == 0 {
// Found a pipe outside quotes/brackets, split here
result = append(result, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(r)
result = append(result, strings.TrimSpace(s[start:i]))
start = i + 1
}
default:
current.WriteRune(r)
}
}
// Add the last part
if current.Len() > 0 {
result = append(result, strings.TrimSpace(current.String()))
// drop trailing empty part.
if start < len(s) {
result = append(result, strings.TrimSpace(s[start:]))
}
return result
}
func forEachAndPart(s string, fn func(part string)) {
start := 0
for i := 0; i <= len(s); i++ {
if i < len(s) && andSeps[s[i]] == 0 {
continue
}
part := strings.TrimSpace(s[start:i])
if part != "" {
fn(part)
}
start = i + 1
}
}
func forEachPipePart(s string, fn func(part string)) {
quote := byte(0)
brackets := 0
start := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '\\':
if i+1 < len(s) {
i++
}
case '"', '\'', '`':
if quote == 0 && brackets == 0 {
quote = s[i]
} else if s[i] == quote {
quote = 0
}
case '(':
brackets++
case ')':
if brackets > 0 {
brackets--
}
case '|':
if quote == 0 && brackets == 0 {
fn(strings.TrimSpace(s[start:i]))
start = i + 1
}
}
}
if start < len(s) {
fn(strings.TrimSpace(s[start:]))
}
}
// Parse implements strutils.Parser.
func (on *RuleOn) Parse(v string) error {
on.raw = v
rules := splitAnd(v)
checkAnd := make(CheckMatchAll, 0, len(rules))
ruleCount := 0
forEachAndPart(v, func(_ string) {
ruleCount++
})
checkAnd := make(CheckMatchAll, 0, ruleCount)
errs := gperr.NewBuilder("rule.on syntax errors")
isResponseChecker := false
for i, rule := range rules {
if rule == "" {
continue
}
parsed, isResp, err := parseOn(rule)
i := 0
forEachAndPart(v, func(rule string) {
i++
parsed, phase, err := parseOn(rule)
if err != nil {
errs.AddSubjectf(err, "line %d", i+1)
continue
}
if isResp {
isResponseChecker = true
errs.AddSubjectf(err, "line %d", i)
return
}
on.phase |= phase
checkAnd = append(checkAnd, parsed)
}
})
on.checker = checkAnd
on.isResponseChecker = isResponseChecker
return errs.Error()
}
@@ -605,33 +694,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
return []byte(on.String()), nil
}
func parseOn(line string) (Checker, bool, error) {
ors := splitPipe(line)
if len(ors) > 1 {
func parseOn(line string) (Checker, PhaseFlag, error) {
orCount := 0
forEachPipePart(line, func(_ string) {
orCount++
})
if orCount > 1 {
var phase PhaseFlag
errs := gperr.NewBuilder("rule.on syntax errors")
checkOr := make(CheckMatchSingle, len(ors))
isResponseChecker := false
for i, or := range ors {
curCheckers, isResp, err := parseOn(or)
checkOr := make(CheckMatchSingle, orCount)
i := 0
forEachPipePart(line, func(or string) {
i++
checkFunc, req, err := parseOnAtom(or)
if err != nil {
errs.Add(err)
continue
errs.AddSubjectf(err, "or[%d]", i)
return
}
if isResp {
isResponseChecker = true
}
checkOr[i] = curCheckers.(CheckFunc)
}
checkOr[i-1] = checkFunc
phase |= req
})
if err := errs.Error(); err != nil {
return nil, false, err
return nil, phase, err
}
return checkOr, isResponseChecker, nil
return checkOr, phase, nil
}
return parseOnAtom(line)
}
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
var phase PhaseFlag
subject, args, err := parse(line)
if err != nil {
return nil, false, err
return nil, phase, err
}
negate := false
@@ -642,20 +738,21 @@ func parseOn(line string) (Checker, bool, error) {
checker, ok := checkers[subject]
if !ok {
return nil, false, ErrInvalidOnTarget.Subject(subject)
return nil, phase, ErrInvalidOnTarget.Subject(subject)
}
validArgs, err := checker.validate(args)
req, validArgs, err := checker.validate(args)
if err != nil {
return nil, false, gperr.Wrap(err).With(checker.help.Error())
return nil, phase, gperr.Wrap(err).With(checker.help.Error())
}
phase |= req
checkFunc := checker.builder(validArgs)
if negate {
origCheckFunc := checkFunc
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool {
return !origCheckFunc(w, r)
}
}
return checkFunc, checker.isResponseChecker, nil
return checkFunc, phase, nil
}