refactor(rules): introduce block DSL, phase-based execution, and flow validation

- add block syntax parser/scanner with nested @blocks and elif/else support
- restructure rule execution into explicit pre/post phases with phase flags
- classify commands by phase and termination behavior
- enforce flow semantics (default rule handling, dead-rule detection)
- expand HTTP flow coverage with block + YAML parity tests and benches
- refresh rules README/spec and update playground/docs integration
This commit is contained in:
yusing
2026-02-23 22:24:15 +08:00
parent 0850ea3918
commit faecbab2cb
34 changed files with 4691 additions and 1057 deletions

View File

@@ -1,7 +1,6 @@
package rules
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -24,17 +23,17 @@ import (
type (
Command struct {
raw string
exec CommandHandler
isResponseHandler bool
raw string
pre Commands // runs before w.WriteHeader
post Commands
}
)
func (cmd *Command) IsResponseHandler() bool {
return cmd.isResponseHandler
}
const (
CommandUpstream = "upstream"
CommandUpstreamOld = "bypass"
CommandUpstreamOld2 = "pass"
CommandRequireAuth = "require_auth"
CommandRewrite = "rewrite"
CommandServe = "serve"
@@ -48,8 +47,6 @@ const (
CommandRemove = "remove"
CommandLog = "log"
CommandNotify = "notify"
CommandPass = "pass"
CommandPassAlt = "bypass"
)
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
@@ -60,36 +57,57 @@ func InitAuthHandler(handler AuthHandler) {
authHandler = handler
}
func init() {
commands[CommandUpstreamOld] = commands[CommandUpstream]
commands[CommandUpstreamOld2] = commands[CommandUpstream]
}
var commands = map[string]struct {
help Help
validate ValidateFunc
build func(args any) CommandHandler
isResponseHandler bool
help Help
validate ValidateFunc
build func(args any) HandlerFunc
terminate bool
}{
CommandUpstream: {
help: Help{
command: CommandUpstream,
description: makeLines("Pass the request to the upstream"),
args: map[string]string{},
},
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 0 {
return phase, nil, ErrExpectNoArg
}
return phase, nil, nil
},
build: func(args any) HandlerFunc {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
upstream(w, r)
return errTerminateRule
}
},
terminate: true,
},
CommandRequireAuth: {
help: Help{
command: CommandRequireAuth,
description: makeLines("Require HTTP authentication for incoming requests"),
args: map[string]string{},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 0 {
return nil, ErrExpectNoArg
return phase, nil, ErrExpectNoArg
}
//nolint:nilnil
return nil, nil
return phase, nil, nil
},
build: func(args any) CommandHandler {
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
if authHandler == nil {
http.Error(w, "Auth handler not initialized", http.StatusInternalServerError)
return errTerminated
}
if !authHandler(w, r) {
return errTerminated
build: func(args any) HandlerFunc {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
if proceed := authHandler(w, r); !proceed {
return errTerminateRule
}
return nil
})
}
},
},
CommandRewrite: {
@@ -104,26 +122,27 @@ var commands = map[string]struct {
"to": "the path to rewrite to, must start with /",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 2 {
return nil, ErrExpectTwoArgs
return phase, nil, ErrExpectTwoArgs
}
path1, err1 := validateURLPath(args[:1])
path2, err2 := validateURLPath(args[1:])
if err1 != nil {
err1 = gperr.PrependSubject(err1, "from")
err1 = gperr.Errorf("from: %w", err1)
}
if err2 != nil {
err2 = gperr.PrependSubject(err2, "to")
err2 = gperr.Errorf("to: %w", err2)
}
if err1 != nil || err2 != nil {
return nil, gperr.Join(err1, err2)
return phase, nil, gperr.Join(err1, err2)
}
return &StrTuple{path1.(string), path2.(string)}, nil
return phase, &StrTuple{path1.(string), path2.(string)}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
orig, repl := args.(*StrTuple).Unpack()
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
path := r.URL.Path
if len(path) > 0 && path[0] != '/' {
path = "/" + path
@@ -133,10 +152,10 @@ var commands = map[string]struct {
}
path = repl + path[len(orig):]
r.URL.Path = path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
r.URL.RawPath = ""
r.RequestURI = ""
return nil
})
}
},
},
CommandServe: {
@@ -150,14 +169,19 @@ var commands = map[string]struct {
"root": "the file system path to serve, must be an existing directory",
},
},
validate: validateFSPath,
build: func(args any) CommandHandler {
root := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
return nil
})
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
parsedArgs, err = validateFSPath(args)
return
},
build: func(args any) HandlerFunc {
root := args.(string)
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
return errTerminateRule
}
},
terminate: true,
},
CommandRedirect: {
help: Help{
@@ -170,14 +194,19 @@ var commands = map[string]struct {
"to": "the url to redirect to, can be relative or absolute URL",
},
},
validate: validateURL,
build: func(args any) CommandHandler {
target := args.(*nettypes.URL).String()
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
return nil
})
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
parsedArgs, err = validateURL(args)
return
},
build: func(args any) HandlerFunc {
target := args.(*nettypes.URL).String()
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
return errTerminateRule
}
},
terminate: true,
},
CommandRoute: {
help: Help{
@@ -190,15 +219,16 @@ var commands = map[string]struct {
"route": "the route to route to",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 1 {
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
}
return args[0], nil
return phase, args[0], nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
route := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
return func(w *httputils.ResponseModifier, req *http.Request, upstream http.HandlerFunc) error {
ep := entrypoint.FromCtx(req.Context())
r, ok := ep.HTTPRoutes().Get(route)
if !ok {
@@ -212,9 +242,10 @@ var commands = map[string]struct {
} else {
http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound)
}
return nil
})
return errTerminateRule
}
},
terminate: true,
},
CommandError: {
help: Help{
@@ -228,34 +259,40 @@ var commands = map[string]struct {
"text": "the error message to return",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 2 {
return nil, ErrExpectTwoArgs
return phase, nil, ErrExpectTwoArgs
}
codeStr, text := args[0], args[1]
code, err := strconv.Atoi(codeStr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
return phase, nil, ErrInvalidArguments.With(err)
}
if !httputils.IsStatusCodeValid(code) {
return nil, ErrInvalidArguments.Subject(codeStr)
return phase, nil, ErrInvalidArguments.Subject(codeStr)
}
textTmpl, err := validateTemplate(text, true)
tmplReq, textTmpl, err := validateTemplate(text, true)
if err != nil {
return nil, ErrInvalidArguments.With(err)
return phase, nil, ErrInvalidArguments.With(err)
}
return &Tuple[int, templateString]{code, textTmpl}, nil
phase |= tmplReq
return phase, &Tuple[int, templateString]{code, textTmpl}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
// error command should overwrite the response body
httputils.GetInitResponseModifier(w).ResetBody()
w.ResetBody()
w.WriteHeader(code)
err := textTmpl.ExpandVars(w, r, w)
return err
})
_, err := textTmpl.ExpandVars(w, r, w.BodyBuffer())
if err != nil {
return err
}
return errTerminateRule
}
},
terminate: true,
},
CommandRequireBasicAuth: {
help: Help{
@@ -268,20 +305,22 @@ var commands = map[string]struct {
"realm": "the authentication realm",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) == 1 {
return args[0], nil
return phase, args[0], nil
}
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
realm := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, realm))
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return nil
})
return errTerminateRule
}
},
terminate: true,
},
CommandProxy: {
help: Help{
@@ -294,14 +333,19 @@ var commands = map[string]struct {
"to": "the url to proxy to, must be an absolute URL",
},
},
validate: validateURL,
build: func(args any) CommandHandler {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
parsedArgs, err = validateURL(args)
return
},
build: func(args any) HandlerFunc {
target := args.(*nettypes.URL)
if target.Scheme == "" {
target.Scheme = "http"
}
if target.Host == "" {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
rawPath := target.EscapedPath()
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
url := target.URL
url.Host = routes.TryGetUpstreamHostPort(r)
if url.Host == "" {
@@ -309,18 +353,19 @@ var commands = map[string]struct {
}
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
r.URL.Path = target.Path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
r.URL.RawPath = rawPath
r.RequestURI = ""
rp.ServeHTTP(w, r)
return nil
})
return errTerminateRule
}
}
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
rp.ServeHTTP(w, r)
return nil
})
return errTerminateRule
}
},
terminate: true,
},
CommandSet: {
help: Help{
@@ -335,11 +380,11 @@ var commands = map[string]struct {
"value": "the value to set",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
return validateModField(ModFieldSet, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
build: func(args any) HandlerFunc {
return args.(HandlerFunc)
},
},
CommandAdd: {
@@ -355,11 +400,11 @@ var commands = map[string]struct {
"value": "the value to add",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
return validateModField(ModFieldAdd, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
build: func(args any) HandlerFunc {
return args.(HandlerFunc)
},
},
CommandRemove: {
@@ -374,15 +419,14 @@ var commands = map[string]struct {
"field": "the field to remove",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
return validateModField(ModFieldRemove, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
build: func(args any) HandlerFunc {
return args.(HandlerFunc)
},
},
CommandLog: {
isResponseHandler: true,
help: Help{
command: CommandLog,
description: makeLines(
@@ -399,28 +443,28 @@ var commands = map[string]struct {
"template": "the template to log",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 3 {
return nil, ErrExpectThreeArgs
return phase, nil, ErrExpectThreeArgs
}
tmpl, err := validateTemplate(args[2], true)
phase, tmpl, err := validateTemplate(args[2], true)
if err != nil {
return nil, err
return phase, nil, err
}
level, err := validateLevel(args[0])
if err != nil {
return nil, err
return phase, nil, err
}
// NOTE: file will stay opened forever
// it leverages accesslog.NewFileIO so
// it will be opened only once for the same path
f, err := openFile(args[1])
if err != nil {
return nil, err
return phase, nil, err
}
return &onLogArgs{level, f, tmpl}, nil
return phase, &onLogArgs{level, f, tmpl}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
level, f, tmpl := args.(*onLogArgs).Unpack()
var logger io.Writer
if f == stdout || f == stderr {
@@ -428,17 +472,16 @@ var commands = map[string]struct {
} else {
logger = f
}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
err := tmpl.ExpandVars(w, r, logger)
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
_, err := tmpl.ExpandVars(w, r, logger)
if err != nil {
return err
}
return nil
})
}
},
},
CommandNotify: {
isResponseHandler: true,
help: Help{
command: CommandNotify,
description: makeLines(
@@ -456,22 +499,24 @@ var commands = map[string]struct {
"body": "the body of the notification",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 4 {
return nil, ErrExpectFourArgs
return phase, nil, ErrExpectFourArgs
}
titleTmpl, err := validateTemplate(args[2], false)
req1, titleTmpl, err := validateTemplate(args[2], false)
if err != nil {
return nil, err
return phase, nil, err
}
bodyTmpl, err := validateTemplate(args[3], false)
req2, bodyTmpl, err := validateTemplate(args[3], false)
if err != nil {
return nil, err
return phase, nil, err
}
level, err := validateLevel(args[0])
if err != nil {
return nil, err
return phase, nil, err
}
phase |= req1 | req2
// TODO: validate provider
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
@@ -480,34 +525,34 @@ var commands = map[string]struct {
// if err != nil {
// return nil, err
// }
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
return phase, &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
to := []string{provider}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
var respBuf strings.Builder
err := titleTmpl.ExpandVars(w, r, respBuf)
_, err := titleTmpl.ExpandVars(w, r, &respBuf)
if err != nil {
return err
}
titleLen := respBuf.Len()
err = bodyTmpl.ExpandVars(w, r, respBuf)
_, err = bodyTmpl.ExpandVars(w, r, &respBuf)
if err != nil {
return err
}
b := respBuf.Bytes()
s := respBuf.String()
notif.Notify(&notif.LogMessage{
Level: level,
Title: string(b[:titleLen]),
Body: notif.MessageBodyBytes(b[titleLen:]),
Title: s[:titleLen],
Body: notif.MessageBodyBytes(s[titleLen:]),
To: to,
})
return nil
})
}
},
},
}
@@ -519,121 +564,29 @@ type (
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error {
executors := make([]CommandHandler, 0)
isResponseHandler := false
for line := range strings.SplitSeq(v, "\n") {
if line == "" {
continue
}
directive, args, err := parse(line)
if err != nil {
return err
}
if directive == CommandPass || directive == CommandPassAlt {
if len(args) != 0 {
return ErrExpectNoArg
}
executors = append(executors, BypassCommand{})
continue
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
validArgs, err := builder.validate(args)
if err != nil {
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
return gperr.PrependSubject(err, directive).With(builder.help.Error())
}
handler := builder.build(validArgs)
executors = append(executors, handler)
if builder.isResponseHandler || handler.IsResponseHandler() {
isResponseHandler = true
}
executors, parseErr := parseDoWithBlocks(v)
if parseErr != nil {
return parseErr
}
if len(executors) == 0 {
cmd.raw = v
cmd.exec = nil
cmd.isResponseHandler = false
cmd.pre = nil
cmd.post = nil
return nil
}
exec, err := buildCmd(executors)
if err != nil {
return err
}
cmd.raw = v
cmd.exec = exec
if exec.IsResponseHandler() {
isResponseHandler = true
for _, executor := range executors {
if executor.Phase().IsPostRule() {
cmd.post = append(cmd.post, executor)
} else {
cmd.pre = append(cmd.pre, executor)
}
}
cmd.isResponseHandler = isResponseHandler
return nil
}
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
// Validate the execution order.
//
// This allows sequences like:
// route ws-api
// log info /dev/stdout "..."
// where the first command is request-phase and the last is response-phase.
lastNonResp := -1
seenResp := false
for i, exec := range executors {
if exec.IsResponseHandler() {
seenResp = true
continue
}
if seenResp {
return nil, ErrInvalidCommandSequence.Withf("response handlers must be the last commands")
}
lastNonResp = i
}
for i, exec := range executors {
if i > lastNonResp {
break // response-handler tail
}
switch exec.(type) {
case TerminatingCommand, BypassCommand:
if i != lastNonResp {
return nil, ErrInvalidCommandSequence.
Withf("a response handler or terminating/bypass command must be the last command")
}
}
}
return Commands(executors), nil
}
// Command is purely "bypass" or empty.
func (cmd *Command) isBypass() bool {
if cmd == nil {
return true
}
switch cmd := cmd.exec.(type) {
case BypassCommand:
return true
case Commands:
// bypass command is always the last one
_, ok := cmd[len(cmd)-1].(BypassCommand)
return ok
default:
return false
}
}
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
return cmd.exec.Handle(w, r)
}
func (cmd *Command) String() string {
return cmd.raw
}