mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 01:38:50 +02:00
refactor(rules): introduce block DSL, phase-based execution, and flow validation
- add block syntax parser/scanner with nested @blocks and elif/else support - restructure rule execution into explicit pre/post phases with phase flags - classify commands by phase and termination behavior - enforce flow semantics (default rule handling, dead-rule detection) - expand HTTP flow coverage with block + YAML parity tests and benches - refresh rules README/spec and update playground/docs integration
This commit is contained in:
@@ -4,9 +4,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
@@ -15,37 +22,36 @@ import (
|
||||
|
||||
type (
|
||||
/*
|
||||
Example:
|
||||
Rules is a list of rules.
|
||||
|
||||
proxy.app1.rules: |
|
||||
- name: default
|
||||
do: |
|
||||
rewrite / /index.html
|
||||
serve /var/www/goaccess
|
||||
- name: ws
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
Example:
|
||||
|
||||
proxy.app2.rules: |
|
||||
- name: default
|
||||
do: bypass
|
||||
- name: block POST and PUT
|
||||
on: method POST | method PUT
|
||||
do: error 403 Forbidden
|
||||
proxy.app1.rules: |
|
||||
- name: default
|
||||
do: |
|
||||
rewrite / /index.html
|
||||
serve /var/www/goaccess
|
||||
- name: ws
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
|
||||
proxy.app2.rules: |
|
||||
- name: default
|
||||
do: bypass
|
||||
- name: block POST and PUT
|
||||
on: method POST | method PUT
|
||||
do: error 403 Forbidden
|
||||
*/
|
||||
//nolint:recvcheck
|
||||
Rules []Rule
|
||||
/*
|
||||
Rule is a rule for a reverse proxy.
|
||||
It do `Do` when `On` matches.
|
||||
|
||||
A rule can have multiple lines of on.
|
||||
|
||||
All lines of on must match,
|
||||
but each line can have multiple checks that
|
||||
one match means this line is matched.
|
||||
*/
|
||||
// Rule represents a reverse proxy rule.
|
||||
// The `Do` field is executed when `On` matches.
|
||||
//
|
||||
// - A rule may have multiple lines in the `On` section.
|
||||
// - All `On` lines must match for the rule to trigger.
|
||||
// - Each line can have several checks—one match per line is enough for that line.
|
||||
Rule struct {
|
||||
Name string `json:"name"`
|
||||
On RuleOn `json:"on" swaggertype:"string"`
|
||||
@@ -53,103 +59,230 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (rule *Rule) IsResponseRule() bool {
|
||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
||||
func isDefaultRule(rule Rule) bool {
|
||||
return rule.Name == "default" || rule.On.raw == OnDefault
|
||||
}
|
||||
|
||||
func (rules Rules) Validate() error {
|
||||
func (rules Rules) Validate() gperr.Error {
|
||||
var defaultRulesFound []int
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
if isDefaultRule(rule) {
|
||||
defaultRulesFound = append(defaultRulesFound, i)
|
||||
}
|
||||
if rules[i].Name == "" {
|
||||
// set name to index if name is empty
|
||||
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
}
|
||||
if len(defaultRulesFound) > 1 {
|
||||
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
||||
}
|
||||
for i := range rules {
|
||||
r1 := rules[i]
|
||||
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
||||
continue
|
||||
}
|
||||
sig1, ok := matcherSignature(r1.On.raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(rules); j++ {
|
||||
r2 := rules[j]
|
||||
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
||||
continue
|
||||
}
|
||||
sig2, ok := matcherSignature(r2.On.raw)
|
||||
if !ok || sig1 != sig2 {
|
||||
continue
|
||||
}
|
||||
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rule Rule) doesTerminateInPre() bool {
|
||||
for _, cmd := range rule.Do.pre {
|
||||
handler, ok := cmd.(Handler)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if handler.Terminates() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matcherSignature(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
andParts := splitAnd(raw)
|
||||
if len(andParts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
canonAnd := make([]string, 0, len(andParts))
|
||||
for _, andPart := range andParts {
|
||||
orParts := splitPipe(andPart)
|
||||
if len(orParts) == 0 {
|
||||
continue
|
||||
}
|
||||
canonOr := make([]string, 0, len(orParts))
|
||||
for _, atom := range orParts {
|
||||
subject, args, err := parse(strings.TrimSpace(atom))
|
||||
if err != nil || subject == "" {
|
||||
return "", false
|
||||
}
|
||||
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
||||
}
|
||||
slices.Sort(canonOr)
|
||||
canonOr = slices.Compact(canonOr)
|
||||
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
||||
}
|
||||
|
||||
slices.Sort(canonAnd)
|
||||
canonAnd = slices.Compact(canonAnd)
|
||||
if len(canonAnd) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(canonAnd, "&"), true
|
||||
}
|
||||
|
||||
// Parse parses a rule configuration string.
|
||||
// It first tries the block syntax (if the string contains a top-level '{'),
|
||||
// then falls back to YAML syntax.
|
||||
func (rules *Rules) Parse(config string) error {
|
||||
config = strings.TrimSpace(config)
|
||||
if config == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prefer block syntax if it looks like block syntax.
|
||||
if hasTopLevelLBrace(config) {
|
||||
blockRules, err := parseBlockRules(config)
|
||||
if err == nil {
|
||||
*rules = blockRules
|
||||
return nil
|
||||
}
|
||||
// Fall through to YAML (backward compatibility).
|
||||
}
|
||||
|
||||
// YAML fallback
|
||||
var anySlice []any
|
||||
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
||||
if yamlErr == nil {
|
||||
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
||||
}
|
||||
|
||||
// If YAML fails and we didn't try block syntax yet, try it now.
|
||||
blockRules, err := parseBlockRules(config)
|
||||
if err == nil {
|
||||
*rules = blockRules
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments.
|
||||
// Used to decide whether to prioritize the block syntax.
|
||||
func hasTopLevelLBrace(s string) bool {
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '*' && i+1 < len(s) && s[i+1] == '/' {
|
||||
inBlock = false
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
if quote != '`' && c == '\\' && i+1 < len(s) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if rune(c) == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch c {
|
||||
case '\'', '"', '`':
|
||||
quote = rune(c)
|
||||
continue
|
||||
case '{':
|
||||
return true
|
||||
case '#':
|
||||
inLine = true
|
||||
continue
|
||||
case '/':
|
||||
if i+1 < len(s) && s[i+1] == '/' {
|
||||
inLine = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if i+1 < len(s) && s[i+1] == '*' {
|
||||
inBlock = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
default:
|
||||
if unicode.IsSpace(rune(c)) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
||||
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
||||
if len(rules) == 0 {
|
||||
return up
|
||||
}
|
||||
|
||||
defaultRule := Rule{
|
||||
Name: "default",
|
||||
Do: Command{
|
||||
raw: "pass",
|
||||
exec: BypassCommand{},
|
||||
},
|
||||
}
|
||||
var defaultRule *Rule
|
||||
|
||||
var nonDefaultRules Rules
|
||||
hasDefaultRule := false
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
defaultRule = rule
|
||||
hasDefaultRule = true
|
||||
for _, rule := range rules {
|
||||
if isDefaultRule(rule) {
|
||||
r := rule
|
||||
defaultRule = &r
|
||||
} else {
|
||||
// set name to index if name is empty
|
||||
if rule.Name == "" {
|
||||
rule.Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
nonDefaultRules = append(nonDefaultRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonDefaultRules) == 0 {
|
||||
if defaultRule.Do.isBypass() {
|
||||
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
||||
return up
|
||||
}
|
||||
if defaultRule.IsResponseRule() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
}
|
||||
}()
|
||||
w = rm
|
||||
up(w, r)
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
}
|
||||
}()
|
||||
w = rm
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err == nil {
|
||||
up(w, r)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preRules := make(Rules, 0, len(nonDefaultRules))
|
||||
postRules := make(Rules, 0, len(nonDefaultRules))
|
||||
for _, rule := range nonDefaultRules {
|
||||
if rule.IsResponseRule() {
|
||||
postRules = append(postRules, rule)
|
||||
} else {
|
||||
preRules = append(preRules, rule)
|
||||
}
|
||||
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||
return cmd.pre.ServeHTTP(w, r, up)
|
||||
}
|
||||
|
||||
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
|
||||
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
||||
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||
return cmd.post.ServeHTTP(w, r, up)
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
@@ -159,104 +292,84 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}()
|
||||
|
||||
w = rm
|
||||
var hasError bool
|
||||
|
||||
shouldCallUpstream := true
|
||||
preMatched := false
|
||||
preRules := make(Rules, 0, len(nonDefaultRules)+1)
|
||||
if defaultRule != nil {
|
||||
preRules = append(preRules, *defaultRule)
|
||||
}
|
||||
preRules = append(preRules, nonDefaultRules...)
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
executedPre := make([]bool, len(preRules))
|
||||
terminatedInPre := make([]bool, len(preRules))
|
||||
preTerminated := false
|
||||
for i, rule := range preRules {
|
||||
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||
continue
|
||||
}
|
||||
if preTerminated {
|
||||
// Preserve post-only commands (e.g. logging) even after
|
||||
// pre-phase termination.
|
||||
if len(rule.Do.pre) == 0 {
|
||||
executedPre[i] = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
for _, rule := range preRules {
|
||||
if rule.Check(w, r) {
|
||||
preMatched = true
|
||||
if rule.Do.isBypass() {
|
||||
break // post rules should still execute
|
||||
}
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &rule, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
break
|
||||
}
|
||||
executedPre[i] = true
|
||||
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
terminatedInPre[i] = true
|
||||
preTerminated = true
|
||||
continue
|
||||
}
|
||||
logError(err, r)
|
||||
hasError = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
return
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
if !rm.HasStatus() {
|
||||
if hasError {
|
||||
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
||||
up(rm, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Run post commands for rules that actually executed in pre phase,
|
||||
// unless that same rule terminated in pre phase.
|
||||
for i, rule := range preRules {
|
||||
if !executedPre[i] || terminatedInPre[i] {
|
||||
continue
|
||||
}
|
||||
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
logError(err, r)
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
up(w, r)
|
||||
}
|
||||
|
||||
// if no post rules, we are done here
|
||||
if len(postRules) == 0 && !isDefaultRulePost {
|
||||
return
|
||||
}
|
||||
|
||||
for _, rule := range postRules {
|
||||
if rule.Check(w, r) {
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &rule, err)
|
||||
}
|
||||
return
|
||||
// Run true post-matcher rules after response is available.
|
||||
for _, rule := range nonDefaultRules {
|
||||
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||
continue
|
||||
}
|
||||
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
||||
// as "pre" for requirement purposes still need to run in this phase.
|
||||
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
logError(err, r)
|
||||
}
|
||||
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
logError(err, r)
|
||||
}
|
||||
}
|
||||
|
||||
if isDefaultRulePost {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) {
|
||||
// rm.AppendError("rule: %s, error: %w", rule.Name, err)
|
||||
}
|
||||
|
||||
func isTerminatingHandler(handler CommandHandler) bool {
|
||||
switch h := handler.(type) {
|
||||
case TerminatingCommand:
|
||||
return true
|
||||
case Commands:
|
||||
if len(h) == 0 {
|
||||
return false
|
||||
}
|
||||
return isTerminatingHandler(h[len(h)-1])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,34 +377,30 @@ func (rule *Rule) String() string {
|
||||
return rule.Name
|
||||
}
|
||||
|
||||
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
if rule.On.checker == nil {
|
||||
return true
|
||||
}
|
||||
v := rule.On.checker.Check(w, r)
|
||||
return v
|
||||
}
|
||||
|
||||
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return rule.Do.exec.Handle(w, r)
|
||||
return rule.On.Check(w, r)
|
||||
}
|
||||
|
||||
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
|
||||
var errStreamClosed error
|
||||
|
||||
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
|
||||
var errClientDisconnected error
|
||||
|
||||
func logError(err error, r *http.Request) {
|
||||
if errors.Is(err, errStreamClosed) {
|
||||
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
||||
return
|
||||
}
|
||||
var h2Err http2.StreamError
|
||||
if errors.As(err, &h2Err) {
|
||||
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
||||
// ignore these errors
|
||||
if h2Err.Code == http2.ErrCodeStreamClosed {
|
||||
return
|
||||
}
|
||||
}
|
||||
var h3Err *http3.Error
|
||||
if errors.As(err, &h3Err) {
|
||||
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
||||
// ignore these errors
|
||||
switch h3Err.ErrorCode {
|
||||
case
|
||||
|
||||
Reference in New Issue
Block a user