mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-18 15:34:38 +01:00
- add block syntax parser/scanner with nested @blocks and elif/else support - restructure rule execution into explicit pre/post phases with phase flags - classify commands by phase and termination behavior - enforce flow semantics (default rule handling, dead-rule detection) - expand HTTP flow coverage with block + YAML parity tests and benches - refresh rules README/spec and update playground/docs integration
414 lines
9.3 KiB
Go
414 lines
9.3 KiB
Go
package rules
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/quic-go/quic-go/http3"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/yusing/godoxy/internal/serialization"
|
|
gperr "github.com/yusing/goutils/errs"
|
|
httputils "github.com/yusing/goutils/http"
|
|
"golang.org/x/net/http2"
|
|
|
|
_ "unsafe"
|
|
)
|
|
|
|
type (
|
|
/*
|
|
Rules is a list of rules.
|
|
|
|
Example:
|
|
|
|
proxy.app1.rules: |
|
|
- name: default
|
|
do: |
|
|
rewrite / /index.html
|
|
serve /var/www/goaccess
|
|
- name: ws
|
|
on: |
|
|
header Connection Upgrade
|
|
header Upgrade websocket
|
|
do: bypass
|
|
|
|
proxy.app2.rules: |
|
|
- name: default
|
|
do: bypass
|
|
- name: block POST and PUT
|
|
on: method POST | method PUT
|
|
do: error 403 Forbidden
|
|
*/
|
|
//nolint:recvcheck
|
|
Rules []Rule
|
|
// Rule represents a reverse proxy rule.
|
|
// The `Do` field is executed when `On` matches.
|
|
//
|
|
// - A rule may have multiple lines in the `On` section.
|
|
// - All `On` lines must match for the rule to trigger.
|
|
// - Each line can have several checks—one match per line is enough for that line.
|
|
Rule struct {
|
|
Name string `json:"name"`
|
|
On RuleOn `json:"on" swaggertype:"string"`
|
|
Do Command `json:"do" swaggertype:"string"`
|
|
}
|
|
)
|
|
|
|
func isDefaultRule(rule Rule) bool {
|
|
return rule.Name == "default" || rule.On.raw == OnDefault
|
|
}
|
|
|
|
func (rules Rules) Validate() gperr.Error {
|
|
var defaultRulesFound []int
|
|
for i := range rules {
|
|
rule := rules[i]
|
|
if isDefaultRule(rule) {
|
|
defaultRulesFound = append(defaultRulesFound, i)
|
|
}
|
|
if rules[i].Name == "" {
|
|
// set name to index if name is empty
|
|
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
|
}
|
|
}
|
|
if len(defaultRulesFound) > 1 {
|
|
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
|
}
|
|
for i := range rules {
|
|
r1 := rules[i]
|
|
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
|
continue
|
|
}
|
|
sig1, ok := matcherSignature(r1.On.raw)
|
|
if !ok {
|
|
continue
|
|
}
|
|
for j := i + 1; j < len(rules); j++ {
|
|
r2 := rules[j]
|
|
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
|
continue
|
|
}
|
|
sig2, ok := matcherSignature(r2.On.raw)
|
|
if !ok || sig1 != sig2 {
|
|
continue
|
|
}
|
|
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (rule Rule) doesTerminateInPre() bool {
|
|
for _, cmd := range rule.Do.pre {
|
|
handler, ok := cmd.(Handler)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if handler.Terminates() {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func matcherSignature(raw string) (string, bool) {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw == "" {
|
|
return "", false
|
|
}
|
|
|
|
andParts := splitAnd(raw)
|
|
if len(andParts) == 0 {
|
|
return "", false
|
|
}
|
|
|
|
canonAnd := make([]string, 0, len(andParts))
|
|
for _, andPart := range andParts {
|
|
orParts := splitPipe(andPart)
|
|
if len(orParts) == 0 {
|
|
continue
|
|
}
|
|
canonOr := make([]string, 0, len(orParts))
|
|
for _, atom := range orParts {
|
|
subject, args, err := parse(strings.TrimSpace(atom))
|
|
if err != nil || subject == "" {
|
|
return "", false
|
|
}
|
|
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
|
}
|
|
slices.Sort(canonOr)
|
|
canonOr = slices.Compact(canonOr)
|
|
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
|
}
|
|
|
|
slices.Sort(canonAnd)
|
|
canonAnd = slices.Compact(canonAnd)
|
|
if len(canonAnd) == 0 {
|
|
return "", false
|
|
}
|
|
return strings.Join(canonAnd, "&"), true
|
|
}
|
|
|
|
// Parse parses a rule configuration string.
|
|
// It first tries the block syntax (if the string contains a top-level '{'),
|
|
// then falls back to YAML syntax.
|
|
func (rules *Rules) Parse(config string) error {
|
|
config = strings.TrimSpace(config)
|
|
if config == "" {
|
|
return nil
|
|
}
|
|
|
|
// Prefer block syntax if it looks like block syntax.
|
|
if hasTopLevelLBrace(config) {
|
|
blockRules, err := parseBlockRules(config)
|
|
if err == nil {
|
|
*rules = blockRules
|
|
return nil
|
|
}
|
|
// Fall through to YAML (backward compatibility).
|
|
}
|
|
|
|
// YAML fallback
|
|
var anySlice []any
|
|
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
|
if yamlErr == nil {
|
|
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
|
}
|
|
|
|
// If YAML fails and we didn't try block syntax yet, try it now.
|
|
blockRules, err := parseBlockRules(config)
|
|
if err == nil {
|
|
*rules = blockRules
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// hasTopLevelLBrace reports whether s contains a '{' that is not inside a
// quoted string ('...', "...", `...`) or a comment (# ..., // ..., /* ... */).
// Parse uses it to decide whether to try the block syntax first.
//
// Inside '...' and "..." a backslash escapes the next byte; backtick strings
// have no escapes. Line comments end at '\n'. An unterminated quote or
// comment hides the rest of the input.
func hasTopLevelLBrace(s string) bool {
	const (
		stNormal = iota
		stLineComment
		stBlockComment
		stQuoted
	)

	state := stNormal
	var quoteCh byte
	n := len(s)

	for i := 0; i < n; i++ {
		ch := s[i]
		switch state {
		case stLineComment:
			if ch == '\n' {
				state = stNormal
			}
		case stBlockComment:
			if ch == '*' && i+1 < n && s[i+1] == '/' {
				state = stNormal
				i++ // consume the closing '/'
			}
		case stQuoted:
			switch {
			case ch == '\\' && quoteCh != '`' && i+1 < n:
				i++ // skip the escaped byte
			case ch == quoteCh:
				state = stNormal
			}
		default:
			switch {
			case unicode.IsSpace(rune(ch)):
				// Whitespace never changes the scanner state.
			case ch == '{':
				return true
			case ch == '\'' || ch == '"' || ch == '`':
				state, quoteCh = stQuoted, ch
			case ch == '#':
				state = stLineComment
			case ch == '/' && i+1 < n:
				switch s[i+1] {
				case '/':
					state = stLineComment
					i++
				case '*':
					state = stBlockComment
					i++
				}
			}
		}
	}
	return false
}
|
|
|
|
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
|
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
|
if len(rules) == 0 {
|
|
return up
|
|
}
|
|
|
|
var defaultRule *Rule
|
|
|
|
var nonDefaultRules Rules
|
|
for _, rule := range rules {
|
|
if isDefaultRule(rule) {
|
|
r := rule
|
|
defaultRule = &r
|
|
} else {
|
|
nonDefaultRules = append(nonDefaultRules, rule)
|
|
}
|
|
}
|
|
|
|
if len(nonDefaultRules) == 0 {
|
|
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
|
return up
|
|
}
|
|
}
|
|
|
|
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
|
return cmd.pre.ServeHTTP(w, r, up)
|
|
}
|
|
|
|
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
|
return cmd.post.ServeHTTP(w, r, up)
|
|
}
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
rm := httputils.NewResponseModifier(w)
|
|
defer func() {
|
|
if _, err := rm.FlushRelease(); err != nil {
|
|
logError(err, r)
|
|
}
|
|
}()
|
|
|
|
var hasError bool
|
|
|
|
preRules := make(Rules, 0, len(nonDefaultRules)+1)
|
|
if defaultRule != nil {
|
|
preRules = append(preRules, *defaultRule)
|
|
}
|
|
preRules = append(preRules, nonDefaultRules...)
|
|
|
|
executedPre := make([]bool, len(preRules))
|
|
terminatedInPre := make([]bool, len(preRules))
|
|
preTerminated := false
|
|
for i, rule := range preRules {
|
|
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
|
continue
|
|
}
|
|
if preTerminated {
|
|
// Preserve post-only commands (e.g. logging) even after
|
|
// pre-phase termination.
|
|
if len(rule.Do.pre) == 0 {
|
|
executedPre[i] = true
|
|
}
|
|
continue
|
|
}
|
|
|
|
executedPre[i] = true
|
|
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
terminatedInPre[i] = true
|
|
preTerminated = true
|
|
continue
|
|
}
|
|
logError(err, r)
|
|
hasError = true
|
|
}
|
|
}
|
|
|
|
if !rm.HasStatus() {
|
|
if hasError {
|
|
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
|
up(rm, r)
|
|
}
|
|
}
|
|
|
|
// Run post commands for rules that actually executed in pre phase,
|
|
// unless that same rule terminated in pre phase.
|
|
for i, rule := range preRules {
|
|
if !executedPre[i] || terminatedInPre[i] {
|
|
continue
|
|
}
|
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
logError(err, r)
|
|
}
|
|
}
|
|
|
|
// Run true post-matcher rules after response is available.
|
|
for _, rule := range nonDefaultRules {
|
|
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
|
continue
|
|
}
|
|
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
|
// as "pre" for requirement purposes still need to run in this phase.
|
|
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
logError(err, r)
|
|
}
|
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
|
if errors.Is(err, errTerminateRule) {
|
|
continue
|
|
}
|
|
logError(err, r)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rule *Rule) String() string {
|
|
return rule.Name
|
|
}
|
|
|
|
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
|
if rule.On.checker == nil {
|
|
return true
|
|
}
|
|
return rule.On.Check(w, r)
|
|
}
|
|
|
|
// Unexported sentinel errors from golang.org/x/net/http2, bound via
// go:linkname (which relies on this file's blank "unsafe" import) so that
// logError can recognize them with errors.Is.
var (
	//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
	errStreamClosed error

	//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
	errClientDisconnected error
)
|
|
|
|
func logError(err error, r *http.Request) {
|
|
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
|
return
|
|
}
|
|
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
|
// ignore these errors
|
|
if h2Err.Code == http2.ErrCodeStreamClosed {
|
|
return
|
|
}
|
|
}
|
|
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
|
// ignore these errors
|
|
switch h3Err.ErrorCode {
|
|
case
|
|
http3.ErrCodeNoError,
|
|
http3.ErrCodeRequestCanceled:
|
|
return
|
|
}
|
|
}
|
|
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
|
}
|