Files
godoxy-yusing/internal/route/rules/on.go
yusing a0adc51269 feat(rules): support multiline or |
treat lines ending with unquoted `|` or `&` as continued
conditions in `do` block headers so nested blocks parse correctly
across line breaks.

update `on` condition splitting to avoid breaking on newlines that
follow an unescaped trailing pipe, while still respecting quotes,
escapes, and bracket nesting.

add coverage for multiline `|`/`&` continuations in `do` parsing,
`splitAnd`, `parseOn`, and HTTP flow nested block behavior.
2026-02-28 18:16:04 +08:00

740 lines
18 KiB
Go

package rules
import (
"net"
"net/http"
"slices"
"strings"
"github.com/yusing/godoxy/internal/route/routes"
gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http"
)
// RuleOn is the parsed form of a rule's `on` condition.
type RuleOn struct {
	// raw is the condition text exactly as passed to Parse; String and
	// MarshalText return it unchanged.
	raw string
	// checker evaluates the parsed condition; when nil, Check matches every
	// request.
	checker Checker
	// phase accumulates (via |=) the PhaseFlag bits required by the parsed
	// conditions.
	phase PhaseFlag
}
// Check reports whether the parsed condition holds for r. The checker gets
// the response modifier obtained from w, which lets response conditions
// inspect the response headers and status. A RuleOn without a checker (for
// example the zero value) matches everything.
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
	if c := on.checker; c != nil {
		return c.Check(httputils.GetInitResponseModifier(w), r)
	}
	return true
}
// Condition keywords whose checks use only the incoming request (OnDefault
// takes no arguments and always matches). Each is a key of checkers, which
// defines its arguments and how it is evaluated.
const (
	OnDefault   = "default"    // always matches; selects the fallback rule
	OnHeader    = "header"     // request header presence or value
	OnQuery     = "query"      // URL query parameter presence or value
	OnCookie    = "cookie"     // cookie presence or value
	OnForm      = "form"       // form value via r.FormValue
	OnPostForm  = "postform"   // POST form value via r.PostFormValue
	OnProto     = "proto"      // protocol: http, https, h1, h2, h2c, h3
	OnMethod    = "method"     // HTTP method
	OnHost      = "host"       // r.Host
	OnPath      = "path"       // request URL path
	OnRemote    = "remote"     // remote IP or CIDR
	OnBasicAuth = "basic_auth" // basic auth username and bcrypt-hashed password
	OnRoute     = "route"      // route (upstream) name
)

// Condition keywords that inspect the response; their validators request
// PhasePost.
const (
	OnResponseHeader = "resp_header" // response header presence or value
	OnStatus         = "status"      // response status code or range
)
// checkers maps each `on` condition keyword to its definition.
//
// For every entry:
//   - help documents the keyword and its arguments; parseOnAtom attaches
//     help.Error() to validation errors.
//   - validate checks the raw argument list when the rule is parsed and
//     returns the PhaseFlag bits the condition needs plus the parsed
//     arguments handed to builder.
//   - builder turns those parsed arguments into the CheckFunc that is
//     evaluated for each request.
var checkers = map[string]struct {
	help     Help
	validate ValidateFunc
	builder  func(args any) CheckFunc
}{
	// default: no arguments, always matches.
	OnDefault: {
		help: Help{
			command: OnDefault,
			description: makeLines(
				"Select the default (fallback) rule.",
			),
			args: map[string]string{},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			if len(args) != 0 {
				return phase, nil, ErrExpectNoArg
			}
			return phase, nil, nil
		},
		builder: func(args any) CheckFunc {
			return func(w *httputils.ResponseModifier, r *http.Request) bool { return true }
		},
	},
	// header <key> [value]: with only a key, matches when the request has at
	// least one value for that header; with a value matcher, matches when any
	// of the header's values matches. The key is used as-is to index
	// r.Header (NOTE(review): presumably canonicalized by
	// toKVOptionalVMatcher — not visible here).
	OnHeader: {
		help: Help{
			command: OnHeader,
			description: makeLines(
				"Value supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnHeader, "username", "user"),
				helpExample(OnHeader, "username", helpFuncCall("glob", "user*")),
				helpExample(OnHeader, "username", helpFuncCall("regex", "user.*")),
			),
			args: map[string]string{
				"key":     "the header key",
				"[value]": "the header value",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = toKVOptionalVMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			k, matcher := args.(*MapValueMatcher).Unpack()
			if matcher == nil {
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return len(r.Header[k]) > 0
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return slices.ContainsFunc(r.Header[k], matcher)
			}
		},
	},
	// resp_header <key> [value]: like header, but reads the response headers
	// from w.Header(); validate sets PhasePost.
	OnResponseHeader: {
		help: Help{
			command: OnResponseHeader,
			description: makeLines(
				"Value supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnResponseHeader, "username", "user"),
				helpExample(OnResponseHeader, "username", helpFuncCall("glob", "user*")),
				helpExample(OnResponseHeader, "username", helpFuncCall("regex", "user.*")),
			),
			args: map[string]string{
				"key":     "the response header key",
				"[value]": "the response header value",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			phase = PhasePost
			parsedArgs, err = toKVOptionalVMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			k, matcher := args.(*MapValueMatcher).Unpack()
			if matcher == nil {
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return len(w.Header()[k]) > 0
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return slices.ContainsFunc(w.Header()[k], matcher)
			}
		},
	},
	// query <key> [value]: presence of, or any matching value for, a URL query
	// parameter, read from the queries cached in w.SharedData().
	OnQuery: {
		help: Help{
			command: OnQuery,
			description: makeLines(
				"Value supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnQuery, "username", "user"),
				helpExample(OnQuery, "username", helpFuncCall("glob", "user*")),
				helpExample(OnQuery, "username", helpFuncCall("regex", "user.*")),
			),
			args: map[string]string{
				"key":     "the query key",
				"[value]": "the query value",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = toKVOptionalVMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			k, matcher := args.(*MapValueMatcher).Unpack()
			if matcher == nil {
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return len(w.SharedData().GetQueries(r)[k]) > 0
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
			}
		},
	},
	// cookie <key> [value]: with only a key, matches when any cookie has that
	// name; with a value matcher, matches when any cookie with that name has a
	// matching value. Cookies come from w.SharedData().
	OnCookie: {
		help: Help{
			command: OnCookie,
			description: makeLines(
				"Value supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnCookie, "username", "user"),
				helpExample(OnCookie, "username", helpFuncCall("glob", "user*")),
				helpExample(OnCookie, "username", helpFuncCall("regex", "user.*")),
			),
			args: map[string]string{
				"key":     "the cookie key",
				"[value]": "the cookie value",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = toKVOptionalVMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			k, matcher := args.(*MapValueMatcher).Unpack()
			if matcher == nil {
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					cookies := w.SharedData().GetCookies(r)
					for _, cookie := range cookies {
						if cookie.Name == k {
							return true
						}
					}
					return false
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				cookies := w.SharedData().GetCookies(r)
				for _, cookie := range cookies {
					if cookie.Name == k {
						if matcher(cookie.Value) {
							return true
						}
					}
				}
				return false
			}
		},
	},
	// form <key> [value]: checks r.FormValue(k), the first value from the
	// parsed form (net/http parses the query and, for POST/PUT/PATCH, the
	// body on first use). With only a key, a present but empty value counts
	// as absent.
	//nolint:dupl
	OnForm: {
		help: Help{
			command: OnForm,
			description: makeLines(
				"Value supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnForm, "username", "user"),
				helpExample(OnForm, "username", helpFuncCall("glob", "user*")),
				helpExample(OnForm, "username", helpFuncCall("regex", "user.*")),
			),
			args: map[string]string{
				"key":     "the form key",
				"[value]": "the form value",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = toKVOptionalVMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			k, matcher := args.(*MapValueMatcher).Unpack()
			if matcher == nil {
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.FormValue(k) != ""
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return matcher(r.FormValue(k))
			}
		},
	},
	// postform <key> [value]: like form, but uses r.PostFormValue, which only
	// considers body parameters. With only a key, an empty value counts as
	// absent.
	OnPostForm: {
		help: Help{
			command: OnPostForm,
			description: makeLines(
				"Value supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnPostForm, "username", "user"),
				helpExample(OnPostForm, "username", helpFuncCall("glob", "user*")),
				helpExample(OnPostForm, "username", helpFuncCall("regex", "user.*")),
			),
			args: map[string]string{
				"key":     "the form key",
				"[value]": "the form value",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = toKVOptionalVMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			k, matcher := args.(*MapValueMatcher).Unpack()
			if matcher == nil {
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.PostFormValue(k) != ""
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return matcher(r.PostFormValue(k))
			}
		},
	},
	// proto <proto>: exactly one of http, https, h1, h2, h2c, h3.
	//   - http / https: r.TLS is nil / non-nil.
	//   - h1, h2c: no TLS and ProtoMajor 1 / 2.
	//   - h2, h3: TLS and ProtoMajor 2 / 3.
	// NOTE(review): "h1" requires r.TLS == nil, so HTTP/1.x over TLS matches
	// "https" but never "h1" — confirm this is intended.
	OnProto: {
		help: Help{
			command: OnProto,
			args: map[string]string{
				"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			if len(args) != 1 {
				return phase, nil, ErrExpectOneArg
			}
			proto := args[0]
			switch proto {
			case "http", "https", "h1", "h2", "h2c", "h3":
				return phase, proto, nil
			default:
				return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
			}
		},
		builder: func(args any) CheckFunc {
			proto := args.(string)
			switch proto {
			case "http":
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.TLS == nil
				}
			case "https":
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.TLS != nil
				}
			case "h1":
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.TLS == nil && r.ProtoMajor == 1
				}
			case "h2":
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.TLS != nil && r.ProtoMajor == 2
				}
			case "h2c":
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.TLS == nil && r.ProtoMajor == 2
				}
			default: // h3
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					return r.TLS != nil && r.ProtoMajor == 3
				}
			}
		},
	},
	// method <method>: exact (case-sensitive) comparison with r.Method
	// against the value returned by validateMethod (NOTE(review): any
	// normalization happens there — not visible here).
	OnMethod: {
		help: Help{
			command: OnMethod,
			args: map[string]string{
				"method": "the http method",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = validateMethod(args)
			return
		},
		builder: func(args any) CheckFunc {
			method := args.(string)
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return r.Method == method
			}
		},
	},
	// host <matcher>: matches r.Host as-is (net/http allows it to carry a
	// ":port" suffix).
	OnHost: {
		help: Help{
			command: OnHost,
			description: makeLines(
				"Supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnHost, "example.com"),
				helpExample(OnHost, helpFuncCall("glob", "example*.com")),
				helpExample(OnHost, helpFuncCall("regex", `(example\w+\.com)`)),
				helpExample(OnHost, helpFuncCall("regex", `example\.com$`)),
			),
			args: map[string]string{
				"host": "the host name",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = validateSingleMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			matcher := args.(Matcher)
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return matcher(r.Host)
			}
		},
	},
	// path <matcher>: matches r.URL.Path, with a leading "/" prepended when a
	// non-empty path lacks one.
	OnPath: {
		help: Help{
			command: OnPath,
			description: makeLines(
				"Supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnPath, "/path/to"),
				helpExample(OnPath, helpFuncCall("glob", "/path/to/*")),
				helpExample(OnPath, helpFuncCall("regex", `^/path/to/.*$`)),
				helpExample(OnPath, helpFuncCall("regex", `/path/[A-Z]+/`)),
			),
			args: map[string]string{
				"path": "the request path",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = validateURLPathMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			matcher := args.(Matcher)
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				reqPath := r.URL.Path
				if len(reqPath) > 0 && reqPath[0] != '/' {
					reqPath = "/" + reqPath
				}
				return matcher(reqPath)
			}
		},
	},
	// remote <ip|cidr>: matches the remote IP from w.SharedData() against the
	// network. A request whose remote IP is unknown (nil) never matches.
	OnRemote: {
		help: Help{
			command: OnRemote,
			args: map[string]string{
				"ip|cidr": "the remote ip or cidr",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = validateCIDR(args)
			return
		},
		builder: func(args any) CheckFunc {
			ipnet := args.(*net.IPNet)
			// for /32 (IPv4) or /128 (IPv6), just compare the IP
			if ones, bits := ipnet.Mask.Size(); ones == bits {
				wantIP := ipnet.IP
				return func(w *httputils.ResponseModifier, r *http.Request) bool {
					ip := w.SharedData().GetRemoteIP(r)
					if ip == nil {
						return false
					}
					return ip.Equal(wantIP)
				}
			}
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				ip := w.SharedData().GetRemoteIP(r)
				if ip == nil {
					return false
				}
				return ipnet.Contains(ip)
			}
		},
	},
	// basic_auth <username> <bcrypt hash>: compares the request's basic-auth
	// credentials (from w.SharedData()) with the configured credentials.
	OnBasicAuth: {
		help: Help{
			command: OnBasicAuth,
			args: map[string]string{
				"username": "the username",
				"password": "the password encrypted with bcrypt",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = validateUserBCryptPassword(args)
			return
		},
		builder: func(args any) CheckFunc {
			cred := args.(*HashedCrendentials)
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return cred.Match(w.SharedData().GetBasicAuth(r))
			}
		},
	},
	// route <matcher>: matches the upstream name from
	// routes.TryGetUpstreamName(r).
	OnRoute: {
		help: Help{
			command: OnRoute,
			description: makeLines(
				"Supports string, glob pattern, or regex pattern, e.g.:",
				helpExample(OnRoute, "example"),
				helpExample(OnRoute, helpFuncCall("glob", "example*")),
				helpExample(OnRoute, helpFuncCall("regex", "example\\w+")),
			),
			args: map[string]string{
				"route": "the route name",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			parsedArgs, err = validateSingleMatcher(args)
			return
		},
		builder: func(args any) CheckFunc {
			matcher := args.(Matcher)
			return func(w *httputils.ResponseModifier, r *http.Request) bool {
				return matcher(routes.TryGetUpstreamName(r))
			}
		},
	},
	// status <range>: matches w.StatusCode() within the inclusive range
	// returned by validateStatusRange; validate sets PhasePost.
	OnStatus: {
		help: Help{
			command: OnStatus,
			description: makeLines(
				"Supported formats are:",
				helpExample(OnStatus, "<status>"),
				helpExample(OnStatus, "<status>-<status>"),
				helpExample(OnStatus, "1xx"),
				helpExample(OnStatus, "2xx"),
				helpExample(OnStatus, "3xx"),
				helpExample(OnStatus, "4xx"),
				helpExample(OnStatus, "5xx"),
			),
			args: map[string]string{
				"status": "the status code range",
			},
		},
		validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
			phase = PhasePost
			parsedArgs, err = validateStatusRange(args)
			return
		},
		builder: func(args any) CheckFunc {
			beg, end := args.(*IntTuple).Unpack()
			// Single status code: one comparison instead of a range check.
			if beg == end {
				return func(w *httputils.ResponseModifier, _ *http.Request) bool {
					return w.StatusCode() == beg
				}
			}
			return func(w *httputils.ResponseModifier, _ *http.Request) bool {
				statusCode := w.StatusCode()
				return statusCode >= beg && statusCode <= end
			}
		},
	},
}
// Byte lookup tables for the condition splitters; a non-zero entry means the
// byte belongs to the set.
var (
	// asciiSpace marks the ASCII whitespace bytes '\t', '\n', '\v', '\f',
	// '\r' and ' '; lineEndsWithUnescapedPipe skips them when looking for a
	// trailing '|'.
	asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
	// andSeps marks the bytes that can separate AND parts ('&' and newline);
	// see forEachAndPart.
	andSeps = [256]uint8{'&': 1, '\n': 1}
)
// splitAnd returns the AND parts of a condition string, in order.
//
// The splitting rules are those of forEachAndPart. '&' and newline separate
// parts unless they are quoted, inside parentheses or backslash-escaped. A
// newline that follows an unescaped trailing '|' (an OR continuation) stays
// inside the part. Parts are whitespace-trimmed and empty parts are dropped.
// The returned slice is never nil, even when there are no parts.
func splitAnd(s string) []string {
	parts := make([]string, 0)
	forEachAndPart(s, func(p string) {
		parts = append(parts, p)
	})
	return parts
}
// lineEndsWithUnescapedPipe reports whether the last non-whitespace byte of
// s[start:end] is a '|' that is not escaped. The pipe counts as escaped when
// an odd number of backslashes sits directly before it. Backslashes before
// start are ignored. Whitespace is whatever asciiSpace marks, which includes
// newlines, so blank lines after a trailing '|' are skipped.
func lineEndsWithUnescapedPipe(s string, start, end int) bool {
	last := end - 1
	for last >= start && asciiSpace[s[last]] != 0 {
		last--
	}
	if last < start || s[last] != '|' {
		return false
	}
	// Backslashes pair up and escape each other; only an odd run escapes the
	// pipe itself.
	backslashes := 0
	for k := last - 1; k >= start && s[k] == '\\'; k-- {
		backslashes++
	}
	return backslashes%2 == 0
}
// advanceSplitState feeds the byte at s[*i] to the quote/parenthesis/escape
// tracker shared by the AND and pipe splitters.
//
// It returns true when the tracker consumed the byte, which must then not be
// treated as a separator:
//   - Inside a quote, every byte is consumed. A backslash with a following
//     byte also consumes that byte by advancing *i. The byte stored in *quote
//     closes the quote.
//   - Outside quotes, a backslash with a following byte consumes both by
//     advancing *i. '"', '\'' and '`' open a quote. '(' and ')' adjust
//     *brackets, which never drops below zero.
//
// For any other byte it returns false and leaves *i unchanged. A backslash
// that is the last byte of s outside quotes also returns false.
func advanceSplitState(s string, i *int, quote *byte, brackets *int) bool {
	pos := *i
	ch := s[pos]
	hasNext := pos+1 < len(s)

	if open := *quote; open != 0 {
		switch {
		case ch == '\\' && hasNext:
			*i = pos + 1
		case ch == open:
			*quote = 0
		}
		return true
	}

	switch ch {
	case '\\':
		if !hasNext {
			return false
		}
		*i = pos + 1
	case '"', '\'', '`':
		*quote = ch
	case '(':
		*brackets++
	case ')':
		if *brackets > 0 {
			*brackets--
		}
	default:
		return false
	}
	return true
}
// splitPipe returns the OR alternatives of s, in order.
//
// s is split on '|', except where the pipe is quoted, inside parentheses or
// backslash-escaped (see forEachPipePart). Alternatives are
// whitespace-trimmed and empty ones are dropped. The returned slice is never
// nil, even when there are no alternatives.
func splitPipe(s string) []string {
	alternatives := make([]string, 0)
	forEachPipePart(s, func(p string) {
		alternatives = append(alternatives, p)
	})
	return alternatives
}
func forEachAndPart(s string, fn func(part string)) {
quote := byte(0)
brackets := 0
start := 0
for i := 0; i <= len(s); i++ {
if i < len(s) {
c := s[i]
if advanceSplitState(s, &i, &quote, &brackets) {
continue
}
if c == '\n' {
if brackets > 0 || lineEndsWithUnescapedPipe(s, start, i) {
continue
}
} else if c != '&' || brackets > 0 {
continue
}
}
if i < len(s) && andSeps[s[i]] == 0 {
continue
}
part := strings.TrimSpace(s[start:i])
if part != "" {
fn(part)
}
start = i + 1
}
}
func forEachPipePart(s string, fn func(part string)) {
quote := byte(0)
brackets := 0
start := 0
for i := 0; i < len(s); i++ {
if advanceSplitState(s, &i, &quote, &brackets) {
continue
}
if s[i] == '|' && brackets == 0 {
if part := strings.TrimSpace(s[start:i]); part != "" {
fn(part)
}
start = i + 1
}
}
if start < len(s) {
if part := strings.TrimSpace(s[start:]); part != "" {
fn(part)
}
}
}
// Parse implements strutils.Parser.
//
// v is stored verbatim for String and MarshalText. It is then split into AND
// parts (see splitAnd), and each part is compiled with parseOn. The compiled
// parts form a CheckMatchAll, and their phase flags are OR-ed into on.phase.
// Each failing part adds one error to the returned aggregate and is left out
// of the checker.
func (on *RuleOn) Parse(v string) error {
	on.raw = v

	parts := splitAnd(v)
	all := make(CheckMatchAll, 0, len(parts))
	errs := gperr.NewBuilder("rule.on syntax errors")
	for n, part := range parts {
		checker, phase, err := parseOn(part)
		if err != nil {
			// NOTE(review): n counts AND parts, not source lines. With '&'
			// separators or '|' continuations, this "line" number can differ
			// from the line in the config.
			errs.AddSubjectf(err, "line %d", n+1)
			continue
		}
		// NOTE(review): on.phase is OR-ed into and never reset, so it
		// accumulates if Parse is called more than once on the same value.
		on.phase |= phase
		all = append(all, checker)
	}
	on.checker = all
	return errs.Error()
}
// String returns the condition text exactly as it was given to Parse.
func (on *RuleOn) String() string { return on.raw }

// MarshalText implements encoding.TextMarshaler by emitting the text Parse
// received; it never fails.
func (on *RuleOn) MarshalText() ([]byte, error) {
	text := on.String()
	return []byte(text), nil
}
// parseOn compiles one AND part of a condition.
//
// The part is split into OR alternatives with splitPipe. When there are two
// or more, each is compiled with parseOnAtom and the result is a
// CheckMatchSingle over them; errors are reported per alternative as
// "or[N]". Otherwise the whole line is compiled by parseOnAtom. The returned
// PhaseFlag is the union of the flags of the alternatives that compiled.
func parseOn(line string) (Checker, PhaseFlag, error) {
	alternatives := splitPipe(line)
	if len(alternatives) < 2 {
		return parseOnAtom(line)
	}

	var phase PhaseFlag
	checkOr := make(CheckMatchSingle, len(alternatives))
	errs := gperr.NewBuilder("rule.on syntax errors")
	for n, alt := range alternatives {
		check, req, err := parseOnAtom(alt)
		if err != nil {
			errs.AddSubjectf(err, "or[%d]", n+1)
			continue
		}
		checkOr[n] = check
		phase |= req
	}
	if err := errs.Error(); err != nil {
		return nil, phase, err
	}
	return checkOr, phase, nil
}
// parseOnAtom compiles a single condition such as `path /api` or
// `!method POST` into a CheckFunc.
//
// A leading "!" on the subject negates the condition. The remaining subject
// must be a key of checkers; otherwise ErrInvalidOnTarget is returned. The
// entry's validate function checks the arguments and reports the required
// PhaseFlag bits. Validation errors are wrapped together with the entry's
// help text. The entry's builder then produces the CheckFunc.
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
	var phase PhaseFlag
	subject, args, err := parse(line)
	if err != nil {
		return nil, phase, err
	}

	subject, negated := strings.CutPrefix(subject, "!")
	def, found := checkers[subject]
	if !found {
		return nil, phase, ErrInvalidOnTarget.Subject(subject)
	}

	required, parsedArgs, err := def.validate(args)
	if err != nil {
		return nil, phase, gperr.Wrap(err).With(def.help.Error())
	}
	phase |= required

	check := def.builder(parsedArgs)
	if !negated {
		return check, phase, nil
	}
	return func(w *httputils.ResponseModifier, r *http.Request) bool {
		return !check(w, r)
	}, phase, nil
}