refactor(rules): introduce block DSL, phase-based execution, and flow validation

- add block syntax parser/scanner with nested @blocks and elif/else support
- restructure rule execution into explicit pre/post phases with phase flags
- classify commands by phase and termination behavior
- enforce flow semantics (default rule handling, dead-rule detection)
- expand HTTP flow coverage with block + YAML parity tests and benches
- refresh rules README/spec and update playground/docs integration
This commit is contained in:
yusing
2026-02-23 22:24:15 +08:00
parent 0850ea3918
commit faecbab2cb
34 changed files with 4691 additions and 1057 deletions

View File

@@ -298,7 +298,7 @@ func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
On: onStr,
Do: doStr,
ValidationError: validationErr,
IsResponseRule: rule.IsResponseRule(),
// IsResponseRule: rule.Requirement()&rules.RequirementFlagResponse != 0,
})
// Only add valid rules to execution list

View File

@@ -36,15 +36,15 @@ type Rule struct {
}
type RuleOn struct {
raw string
checker Checker
isResponseChecker bool
raw string
checker Checker
phase PhaseFlag
}
type Command struct {
raw string
exec CommandHandler
isResponseHandler bool
raw string
pre Commands
post Commands
}
```
@@ -59,6 +59,9 @@ func ParseRules(config string) (Rules, error)
// ValidateRules validates rule syntax
func ValidateRules(config string) error
// Validate validates rule semantics (e.g., prevents multiple default rules)
func (rules Rules) Validate() gperr.Error
```
## Architecture
@@ -122,16 +125,52 @@ sequenceDiagram
Pre->>Pre: Execute handler
alt Terminating action
Pre-->>Req: Response
Return-->>Req: Return immediately
Note right of Pre: Stop remaining pre commands
end
end
Req->>Proxy: Forward request
Proxy-->>Req: Response
Req->>Post: Check post-rules
Post->>Post: Execute handlers
Post-->>Req: Modified response
opt No pre termination
Req->>Proxy: Forward request
Proxy-->>Req: Response
end
Req->>Post: Run scheduled post commands
Req->>Post: Evaluate response matchers
Post->>Post: Execute matched post handlers
Post-->>Req: Final response
```
### Execution Model (Authoritative)
Rules run in two phases:
1. **Pre phase**
- Evaluate only request-based matchers (`path`, `method`, `header`, `remote`, etc.) in declaration order.
- Execute matched rule `do` pre-commands in order.
- If a default rule exists (`name: default` or `on: default`), it is evaluated first as a baseline rule.
- If a terminating action runs, stop:
- remaining commands in that rule
- all later pre-phase commands.
- Exception: rules that only contain post commands (no pre commands) are still scheduled for post phase.
2. **Upstream phase**
- Upstream is called only if pre phase did not terminate.
3. **Post phase**
- Run post-commands for rules whose pre phase executed, except rules that terminated in pre.
- Then evaluate response-based matchers (`status`, `resp_header`) and execute their `do` commands.
- Response-based rules run even when the response was produced in pre phase.
**Important:** termination is explicit by command semantics, not inferred from status-code mutation.
### Phase Flags
Rule and command parsing tracks phase requirements via `PhaseFlag`:
- `PhasePre`
- `PhasePost`
- `PhasePre | PhasePost` (combined)
Combined flags are expected for nested/compound commands and variable templates that may need both request and response context.
### Condition Matchers
| Matcher | Type | Description |
@@ -166,22 +205,22 @@ path regex("/api/v[0-9]+/.*") // regex pattern
**Terminating Actions** (stop processing):
| Command | Description |
| ------------------------ | ---------------------- |
| `error <code> <message>` | Return HTTP error |
| `redirect <url>` | Redirect to URL |
| `serve <path>` | Serve local files |
| `route <name>` | Route to another route |
| `proxy <url>` | Proxy to upstream |
| Command | Description |
| ------------------------------ | ------------------------------------- |
| `upstream` / `bypass` / `pass` | Call upstream and terminate pre-phase |
| `error <code> <message>` | Return HTTP error |
| `redirect <url>` | Redirect to URL |
| `serve <path>` | Serve local files |
| `route <name>` | Route to another route |
| `proxy <url>` | Proxy to upstream |
| `require_basic_auth <realm>` | Return 401 challenge |
**Non-Terminating Actions** (modify and continue):
| Command | Description |
| ------------------------------ | ---------------------- |
| `pass` / `bypass` | Pass through unchanged |
| `rewrite <from> <to>` | Rewrite request path |
| `require_auth` | Require authentication |
| `require_basic_auth <realm>` | Basic auth challenge |
| `set <target> <field> <value>` | Set header/variable |
| `add <target> <field> <value>` | Add header/variable |
| `remove <target> <field>` | Remove header/variable |
@@ -208,6 +247,166 @@ rules:
action2
```
### Rule Configuration (Block Syntax)
This is an alternative (and will eventually be the primary) syntax for rules that avoids YAML.
It keeps the **inner** `on` and `do` DSLs exactly the same (same matchers, same commands, same optional quotes), but wraps each rule in a `{ ... }` block.
#### Key ideas
- A rule is:
- `default { <do...> }`,
- `{ <do...> }`, or
- `<on-expr> { <do...> }`
- Comments are supported:
- line comment: `// ...` (to end of line)
- line comment: `# ...` (to end of line, for YAML familiarity)
- block comment: `/* ... */` (may span multiple lines)
- Comments are ignored **only when outside quotes** (`"`, `'` or backticks).
- Environment variable syntax: `${NAME}` is supported by the inner DSL parser in [`parse()`](internal/route/rules/parser.go:34).
Block-syntax rule:
- In `on` (rule header): `${...}` must be inside quotes/backticks.
- In `do` (rule body): `${...}` may be unquoted; the outer parser must treat `${...}` as an opaque token so braces inside it are not structural.
#### Grammar sketch (EBNF-ish)
```text
file := { ws | comment | rule }
rule := default_rule | unconditional_rule | conditional_rule
default_rule := 'default' ws* block
unconditional_rule := ws* block
conditional_rule := on_expr ws* block
block := '{' do_body '}'
// on_expr and do_body are raw text regions.
// The outer parser only needs to:
// - find the top-level '{' to start a rule block
// - find the matching top-level '}' to end it
// while respecting quotes and comments.
```
#### Elif/Else Chain Grammar
```text
// Elif/Else chains can appear in do_body
do_stmt := command_line | nested_block | elif_else_chain
elif_else_chain := nested_block { elif_clause } [else_clause]
elif_clause := 'elif' ws* on_expr ws* '{' do_body '}'
else_clause := 'else' ws* '{' do_body '}'
```
#### Nested blocks (inline conditionals inside `do`)
Inside a rule body (`do_body`), you can write **nested blocks** that start with `@`:
```text
do_stmt := command_line | nested_block | elif_else_chain
nested_block := '@' on_expr ws* '{' do_body '}'
```
Notes:
- A nested block is only recognized when `@` is the **first non-space character on a line**.
- `on_expr` uses the same syntax as rule `on` (supports `|`, `&`, quoting/backticks, matcher functions, etc.).
- The nested block executes **in sequence**, at the point where it appears in the parent `do` list.
- Nested blocks are evaluated in the same phase the parent rule runs (no special phase promotion).
- Nested blocks can be chained with `elif`/`else` for conditional execution (see Elif/Else Chains section).
Example:
```go
default {
remove resp_header X-Secret
add resp_header X-Custom-Header custom-value
}
header X-Test-Header {
set header X-Remote-Type public
@remote 127.0.0.1 | remote 192.168.0.0/16 {
set header X-Remote-Type private
}
}
```
#### Elif/Else Chains
You can chain multiple conditions using `elif` and provide a fallback with `else`.
The `elif`/`else` keywords must appear on the same line as the preceding closing brace (`}`).
```go
header X-Test-Header {
@method GET {
set header X-Mode get
} elif method POST {
set header X-Mode post
} else {
set header X-Mode other
}
}
```
Notes:
- `elif` and `else` must be on the same line as the preceding `}`.
- Multiple `elif` branches are allowed; only one `else` is allowed.
- The entire chain is evaluated in sequence; the first matching branch executes.
- Elif/else chains can only be used within nested blocks (starting with `@`).
- Each `elif` clause must have its own condition expression and block.
- The `else` clause is optional and provides a default action when no conditions match.
#### Examples
Basic default rule:
```go
default {
bypass
}
```
WebSocket upgrade routing:
```bash
# WebSocket requests
header Connection Upgrade &
header Upgrade websocket {
route ws-api
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
}
```
Block comments:
```go
/* protect admin area */
path glob("/admin/*") {
require_auth
}
```
Always log the request
```bash
{
log info /dev/stdout "Request $req_method $req_path"
}
```
#### Notes and constraints
- The block syntax uses `{` and `}` as structure delimiters at **top-level** (outside quotes/comments).
- Braces inside quoted strings (including backticks) are not structural.
- `${...}` handling:
- `on`: must be quoted/backticked
- `do`: may be unquoted
Preferred style: always write env vars as `${NAME}` rather than a bare `$NAME`.
- If you need literal `{` or `}` outside quotes/backticks (for example unquoted templates like `{{ ... }}`), wrap that argument in quotes/backticks so the outer parser does not treat it as structure.
- Rule naming remains minimal: if no explicit name is provided by the syntax, it will behave like the current YAML behavior (empty name becomes `rule[index]` in [`Rules.BuildHandler()`](internal/route/rules/rules.go:75)).
- YAML remains supported as a fallback for backward compatibility.
### Condition Syntax
```yaml
@@ -215,12 +414,13 @@ rules:
on: path /api/users
# Multiple conditions (AND)
on: |
header Authorization Bearer
& path /api/admin/*
on: header Authorization Bearer & path glob("/api/admin/*")
# Negation
on: !path /public/*
on: !path glob("/public/*")
# Negation on matcher
on: path !glob("/public/*")
# OR within a line
on: method GET | method POST
@@ -228,21 +428,21 @@ on: method GET | method POST
### Variable Substitution
```go
// Static variables
$req_method // Request method
$req_host // Request host
$req_path // Request path
$status_code // Response status
$remote_host // Client IP
```bash
# Static variables
$req_method # Request method
$req_host # Request host
$req_path # Request path
$status_code # Response status
$remote_host # Client IP
// Dynamic variables
$header(Name) // Request header
$header(Name, index) // Header at index
$arg(Name) // Query argument
$form(Name) // Form field
# Dynamic variables
$header(Name) # Request header
$header(Name, index) # Header at index
$arg(Name) # Query argument
$form(Name) # Form field
// Environment variables
# Environment variables
${ENV_VAR}
```
@@ -277,12 +477,13 @@ Log context includes: `rule`, `alias`, `match_result`
## Failure Modes and Recovery
| Failure | Behavior | Recovery |
| ------------------- | ------------------------- | ---------------------------------- |
| Invalid rule syntax | Route validation fails | Fix YAML syntax |
| Missing variables | Variable renders as empty | Check variable sources |
| Rule timeout | Request times out | Increase timeout or simplify rules |
| Auth failure | Returns 401/403 | Fix credentials |
| Failure | Behavior | Recovery |
| ---------------------- | ------------------------- | ---------------------------------- |
| Invalid rule syntax | Route validation fails | Fix YAML syntax |
| Multiple default rules | Route validation fails | Remove duplicate default rules |
| Missing variables | Variable renders as empty | Check variable sources |
| Rule timeout | Request times out | Increase timeout or simplify rules |
| Auth failure | Returns 401/403 | Fix credentials |
## Usage Examples
@@ -297,11 +498,11 @@ Log context includes: `rule`, `alias`, `match_result`
```yaml
- name: api proxy
on: path /api/*
on: path glob("/api/*")
do: proxy http://api-backend:8080
- name: static files
on: path /static/*
on: path glob("/static/*")
do: serve /var/www/static
```
@@ -309,11 +510,11 @@ Log context includes: `rule`, `alias`, `match_result`
```yaml
- name: admin protection
on: path /admin/*
on: path glob("/admin/*")
do: require_auth
- name: basic auth for API
on: path /api/*
on: path glob("/api/*")
do: require_basic_auth "API Access"
```
@@ -321,7 +522,7 @@ Log context includes: `rule`, `alias`, `match_result`
```yaml
- name: rewrite API v1
on: path /v1/*
on: path glob("/v1/*")
do: |
rewrite /v1 /api/v1
proxy http://backend:8080
@@ -351,6 +552,27 @@ Log context includes: `rule`, `alias`, `match_result`
do: bypass
```
### Default Rule (Baseline)
```yaml
# Default runs first and can provide baseline behavior
- name: default
do: |
remove resp_header X-Internal
add resp_header X-Powered-By godoxy
# Specific rules can override or add to baseline behavior
- name: api routes
on: path glob("/api/*")
do: proxy http://api:8080
- name: api marker
on: path glob("/api/*")
do: set resp_header X-API true
```
Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors.
## Testing Notes
- Unit tests for all matchers and actions

View File

@@ -0,0 +1,406 @@
package rules
import (
"strings"
"unicode"
"github.com/yusing/goutils/env"
gperr "github.com/yusing/goutils/errs"
)
func getStringBuffer(size int) *strings.Builder {
var buf strings.Builder
if size > 0 {
buf.Grow(size)
}
return &buf
}
// expandEnvVarsRaw expands ${NAME} in-place using env.LookupEnv (prefix-aware).
func expandEnvVarsRaw(v string) (string, gperr.Error) {
buf := getStringBuffer(len(v))
envVar := getStringBuffer(0)
var missingEnvVars []string
inEnvVar := false
expectingBrace := false
for _, r := range v {
if expectingBrace && r != '{' && r != '$' {
buf.WriteRune('$')
expectingBrace = false
}
switch r {
case '$':
if expectingBrace {
buf.WriteRune('$')
expectingBrace = false
} else {
expectingBrace = true
}
case '{':
if expectingBrace {
inEnvVar = true
expectingBrace = false
envVar.Reset()
} else {
buf.WriteRune(r)
}
case '}':
if inEnvVar {
envValue, ok := env.LookupEnv(envVar.String())
if !ok {
missingEnvVars = append(missingEnvVars, envVar.String())
} else {
buf.WriteString(envValue)
}
inEnvVar = false
} else {
buf.WriteRune(r)
}
default:
if expectingBrace {
buf.WriteRune('$')
expectingBrace = false
}
if inEnvVar {
envVar.WriteRune(r)
} else {
buf.WriteRune(r)
}
}
}
if expectingBrace {
buf.WriteRune('$')
}
var err gperr.Error
if inEnvVar {
err = ErrUnterminatedEnvVar
}
if len(missingEnvVars) > 0 {
err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...)))
}
return buf.String(), err
}
// parseBlockRules parses the block-syntax rule format.
// Grammar:
//
// file := { ws | comment | rule }
// rule := default_rule | conditional_rule
// default_rule := 'default' ws* block
// conditional_rule := on_expr ws* block
// block := '{' do_body '}'
//
// Where:
// - on_expr is passed verbatim to RuleOn.Parse()
// - do_body is passed verbatim to Command.Parse()
//
// Comments (ignored outside quotes/backticks):
// - line comment: // ... or # ...
// - block comment: /* ... */
//
// Brace handling:
// - Braces inside quotes/backticks are ignored
// - Braces inside ${...} (env vars) are ignored in do_body
// - Braces in on_expr are not ignored (env vars must be quoted in on_expr)
//
//nolint:dupword
func parseBlockRules(src string) (Rules, gperr.Error) {
var rules Rules
var errs gperr.Builder
pos := 0
length := len(src)
t := newTokenizer(src)
for pos < length {
// Skip whitespace/comments between rules.
newPos, skipErr := t.skipComments(pos, true, true)
if skipErr != nil {
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
}
pos = newPos
if pos >= length {
break
}
// Stray closing brace at top-level: keep parsing but mark invalid so Rules.Validate() fails.
if src[pos] == '}' {
return nil, ErrInvalidBlockSyntax.Withf("unmatched '}' at position %d", pos)
}
// Parse rule header (default, unconditional, or on_expr)
headerStart := pos
header := parseRuleHeader(&t, src, &pos, length)
headerStr := src[headerStart:pos]
// Skip whitespace/comments before '{' (default header may end before '{').
newPos, skipErr = t.skipComments(pos, false, true)
if skipErr != nil {
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
}
pos = newPos
if pos >= length || src[pos] != '{' {
errs.AddSubjectf(ErrInvalidBlockSyntax, "expected '{' after rule header %q", headerStr)
return nil, errs.Error()
}
// Find matching '}' (respecting quotes and env vars in do_body)
bodyStart := pos + 1
bodyEnd, err := t.findMatchingBrace(bodyStart)
if err != nil {
errs.AddSubjectf(err, "rule header %q", headerStr)
return nil, errs.Error()
}
pos = bodyEnd + 1
onExpr := header
doBody := ""
if bodyStart < bodyEnd {
doBody = src[bodyStart:bodyEnd]
}
// Normalize do body for the inner DSL parser:
// - strip comments (outside quotes/backticks)
// - trim block whitespace/indentation
// - expand ${ENV} in-place so cmd.raw is usable/debuggable
doBody, err = preprocessDoBody(doBody)
if err != nil {
errs.AddSubjectf(err, "rule header %q", headerStr)
return nil, errs.Error()
}
rule := Rule{
Name: "", // auto-generate if empty
On: RuleOn{},
Do: Command{},
}
// Header semantics:
// - "default" => default rule (matched when no other rules are matched)
// - "" => unconditional rule (always matches)
// - otherwise => conditional rule (on expression)
switch onExpr {
case "default":
rule.On.raw = OnDefault
case "":
// leave rule.On as zero value => checker=nil => always matches
default:
if parseErr := rule.On.Parse(onExpr); parseErr != nil {
errs.AddSubjectf(parseErr, "on")
}
}
if doBody != "" {
if parseErr := rule.Do.Parse(doBody); parseErr != nil {
errs.AddSubjectf(parseErr, "do")
}
}
if errs.HasError() {
return nil, errs.Error()
}
rules = append(rules, rule)
}
return rules, nil
}
func preprocessDoBody(doBody string) (string, gperr.Error) {
doBody = strings.TrimSpace(doBody)
if doBody == "" {
return "", nil
}
normalized := doBody
// If comments are possible, strip them first while preserving line breaks.
if strings.ContainsAny(normalized, "#/") {
stripped, err := stripCommentsPreserveNewlines(normalized)
if err != nil {
return "", err
}
normalized = stripped
}
// Drop lines that are empty after trimming, while preserving indentation of non-empty lines.
out := getStringBuffer(len(normalized))
lineStart := 0
wroteLine := false
for i := 0; i <= len(normalized); i++ {
if i < len(normalized) && normalized[i] != '\n' {
continue
}
line := normalized[lineStart:i]
if strings.TrimSpace(line) != "" {
if wroteLine {
out.WriteByte('\n')
}
out.WriteString(line)
wroteLine = true
}
lineStart = i + 1
}
if !wroteLine {
return "", nil
}
normalized = out.String()
// Expand env vars to keep Command.raw consistent with parsed semantics.
if !strings.Contains(normalized, "${") {
return normalized, nil
}
expanded, err := expandEnvVarsRaw(normalized)
if err != nil {
return "", err
}
return expanded, nil
}
// stripCommentsPreserveNewlines removes //, #, and /* */ comments outside quotes/backticks.
// It preserves newlines so command line boundaries remain intact.
func stripCommentsPreserveNewlines(src string) (string, gperr.Error) {
if !strings.ContainsAny(src, "#/") {
return src, nil
}
out := getStringBuffer(len(src))
quote := rune(0)
inLine := false
inBlock := false
atLineStart := true
prevIsSpace := true
for i := 0; i < len(src); {
c := src[i]
if inLine {
if c == '\n' {
inLine = false
out.WriteByte('\n')
atLineStart = true
prevIsSpace = true
}
i++
continue
}
if inBlock {
if c == '\n' {
out.WriteByte('\n')
atLineStart = true
prevIsSpace = true
i++
continue
}
if c == '*' && i+1 < len(src) && src[i+1] == '/' {
inBlock = false
i += 2
continue
}
i++
continue
}
if quote != 0 {
out.WriteByte(c)
if c == '\\' && i+1 < len(src) {
// Write next char and skip it (escape sequence)
i++
out.WriteByte(src[i])
atLineStart = false
prevIsSpace = false
i++
continue
}
if rune(c) == quote {
quote = 0
}
if c == '\n' {
atLineStart = true
prevIsSpace = true
} else {
atLineStart = false
prevIsSpace = unicode.IsSpace(rune(c))
}
i++
continue
}
// Not in quote/comment.
switch c {
case '\'', '"', '`':
quote = rune(c)
out.WriteByte(c)
atLineStart = false
prevIsSpace = false
i++
continue
case '#':
if atLineStart || prevIsSpace {
inLine = true
i++
continue
}
case '/':
if i+1 < len(src) {
n := src[i+1]
if (atLineStart || prevIsSpace) && n == '/' {
inLine = true
i += 2
continue
}
if (atLineStart || prevIsSpace) && n == '*' {
inBlock = true
i += 2
continue
}
}
}
out.WriteByte(c)
if c == '\n' {
atLineStart = true
prevIsSpace = true
} else {
atLineStart = false
prevIsSpace = unicode.IsSpace(rune(c))
}
i++
}
if inBlock {
return "", ErrInvalidBlockSyntax.Withf("unterminated block comment")
}
return out.String(), nil
}
// parseRuleHeader parses the rule header (default or on expression).
// Returns the header string, or "" if parsing failed.
func parseRuleHeader(t *Tokenizer, src string, pos *int, length int) string {
start := *pos
// Check for 'default' keyword
if *pos+7 <= length && src[*pos:*pos+7] == "default" {
next := *pos + 7
if next >= length || unicode.IsSpace(rune(src[next])) {
*pos = next
return "default"
}
}
// Parse on expression until we hit '{' outside quotes.
bracePos, err := t.scanToBrace(*pos)
if err != nil {
*pos = length
return strings.TrimSpace(src[start:*pos])
}
*pos = bracePos
return strings.TrimSpace(src[start:*pos])
}

View File

@@ -0,0 +1,48 @@
package rules
import "testing"
func BenchmarkParseBlockRules(b *testing.B) {
const rulesString = `
default {
remove resp_header X-Secret
add resp_header X-Custom-Header custom-value
}
header X-Test-Header {
set header X-Remote-Type public
@remote 127.0.0.1 | remote 192.168.0.0/16 {
set header X-Remote-Type private
}
}
path glob(/api/admin/*) {
@cookie session-id {
set header X-Session-ID $cookie(session-id)
}
}
!remote 192.168.0.0/16 {
@!header X-User-Role admin & !header X-User-Role user {
error 403 "Access denied"
} elif remote 127.0.0.1 {
@header X-User-Role staff {
set header X-User-Role staff
}
} else {
error 403 "Access denied"
}
}
`
var rules Rules
err := rules.Parse(rulesString)
if err != nil {
b.Fatal(err)
}
for b.Loop() {
var rules Rules
_ = rules.Parse(rulesString)
}
}

View File

@@ -0,0 +1,339 @@
package rules
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/serialization"
httputils "github.com/yusing/goutils/http"
)
func testParseRules(t *testing.T, data string) Rules {
t.Helper()
var rules Rules
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
require.True(t, convertible)
require.NoError(t, err)
return rules
}
func testParseRulesError(t *testing.T, data string) error {
t.Helper()
var rules Rules
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
require.True(t, convertible)
return err
}
func TestParseBlockRules_DefaultRule(t *testing.T) {
rules := testParseRules(t, `default {
upstream
}`)
require.Len(t, rules, 1)
assert.Equal(t, OnDefault, rules[0].On.raw)
assert.Equal(t, "upstream", rules[0].Do.raw)
assert.True(t, rules[0].Do.raw == CommandUpstream)
}
func TestParseBlockRules_ConditionalRule(t *testing.T) {
rules := testParseRules(t, `path glob(/api/*) {
proxy http://localhost:8080
}`)
require.Len(t, rules, 1)
assert.Equal(t, "path glob(/api/*)", rules[0].On.raw)
assert.Equal(t, "proxy http://localhost:8080", rules[0].Do.raw)
require.Len(t, rules[0].Do.pre, 1)
_, ok := rules[0].Do.pre[0].(Handler)
require.True(t, ok)
require.Len(t, rules[0].Do.post, 0)
}
func TestParseBlockRules_MultipleRules(t *testing.T) {
rules := testParseRules(t, `default {
bypass
}
path /api/* {
proxy http://localhost:8080
}
header Connection Upgrade &
header Upgrade websocket {
route ws-api
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
}`)
require.Len(t, rules, 3)
// Default rule
assert.Equal(t, OnDefault, rules[0].On.raw)
assert.Equal(t, "bypass", rules[0].Do.raw)
// API rule
assert.Equal(t, "path /api/*", rules[1].On.raw)
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
// WebSocket rule
assert.Equal(t, "header Connection Upgrade &\nheader Upgrade websocket", rules[2].On.raw)
assert.Equal(t, `route ws-api
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"`, rules[2].Do.raw)
require.Len(t, rules[2].Do.pre, 2)
_, ok := rules[2].Do.pre[0].(Handler)
require.True(t, ok)
_, ok = rules[2].Do.pre[1].(Handler)
require.True(t, ok)
require.Len(t, rules[2].Do.post, 0)
}
func TestParseBlockRules_Comments(t *testing.T) {
rules := testParseRules(t, `// This is a comment
default {
bypass // inline comment
}
/* Block comment
spanning multiple lines */
path /admin/* {
require_auth
}`)
require.Len(t, rules, 2)
assert.Equal(t, OnDefault, rules[0].On.raw)
assert.Equal(t, "path /admin/*", rules[1].On.raw)
assert.Equal(t, "require_auth", rules[1].Do.raw)
}
func TestParseBlockRules_HashComment(t *testing.T) {
rules := testParseRules(t, `# YAML-style comment
default {
bypass
}`)
require.Len(t, rules, 1)
assert.Equal(t, OnDefault, rules[0].On.raw)
assert.Equal(t, "bypass", rules[0].Do.raw)
}
func TestParseBlockRules_EnvVars(t *testing.T) {
t.Setenv("CUSTOM_HEADER", "test-header")
rules := testParseRules(t, `path /api/* {
set header X-Custom "${CUSTOM_HEADER}"
}`)
require.Len(t, rules, 1)
assert.Equal(t, "path /api/*", rules[0].On.raw)
assert.Equal(t, `set header X-Custom "test-header"`, rules[0].Do.raw)
require.Len(t, rules[0].Do.pre, 1)
_, ok := rules[0].Do.pre[0].(Handler)
require.True(t, ok)
require.Len(t, rules[0].Do.post, 0)
}
func TestParseBlockRules_YAMLFallback(t *testing.T) {
rules := testParseRules(t, `- name: default
do: bypass
- name: api
on: path glob(/api/*)
do: proxy http://localhost:8080`)
require.Len(t, rules, 2)
assert.Equal(t, "default", rules[0].Name)
assert.Equal(t, "bypass", rules[0].Do.raw)
assert.Equal(t, "api", rules[1].Name)
assert.Equal(t, "path glob(/api/*)", rules[1].On.raw)
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
require.Len(t, rules[1].Do.pre, 1)
_, ok := rules[1].Do.pre[0].(Handler)
require.True(t, ok)
require.Len(t, rules[1].Do.post, 0)
}
func TestParseBlockRules_UnmatchedBrace(t *testing.T) {
t.Run("unquoted", func(t *testing.T) {
err := testParseRulesError(t, `path /api/* {
proxy http://localhost:8080}
}`)
require.Error(t, err)
})
t.Run("quoted", func(t *testing.T) {
_ = testParseRules(t, `path /api/* {
error 403 "some message}"
}`)
})
}
func TestParseBlockRules_UnterminatedBlockComment(t *testing.T) {
err := testParseRulesError(t, `/* unterminated block comment
default {
bypass
}`)
require.Error(t, err)
}
func TestParseBlockRules_NestedBlocks(t *testing.T) {
rules := testParseRules(t, `
header X-Test-Header {
set header X-Remote-Type public
@remote 127.0.0.1 | remote 192.168.0.0/16 {
set header X-Remote-Type private
}
}`)
require.Len(t, rules, 1)
assert.Equal(t, "header X-Test-Header", rules[0].On.raw)
require.Len(t, rules[0].Do.pre, 2)
_, ok := rules[0].Do.pre[0].(Handler)
require.True(t, ok)
require.Len(t, rules[0].Do.post, 0)
ifCmd, ok := rules[0].Do.pre[1].(IfBlockCommand)
require.True(t, ok)
assert.Equal(t, "remote 127.0.0.1 | remote 192.168.0.0/16", ifCmd.On.raw)
require.Len(t, ifCmd.Do, 1)
upstream := func(http.ResponseWriter, *http.Request) {}
t.Run("condition matched executes nested content", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Test-Header", "1")
req.RemoteAddr = "127.0.0.1:12345"
w := httptest.NewRecorder()
rm := httputils.NewResponseModifier(w)
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
require.NoError(t, err)
assert.Equal(t, "private", req.Header.Get("X-Remote-Type"))
})
t.Run("condition not matched skips nested content", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Test-Header", "1")
req.RemoteAddr = "10.0.0.1:12345"
w := httptest.NewRecorder()
rm := httputils.NewResponseModifier(w)
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
require.NoError(t, err)
assert.Equal(t, "public", req.Header.Get("X-Remote-Type"))
})
}
func TestParseBlockRules_NestedBlocks_ElifElse(t *testing.T) {
rules := testParseRules(t, `
header X-Test-Header {
set header X-Mode outer
@method GET {
set header X-Mode get
} elif method POST {
set header X-Mode post
} else {
set header X-Mode other
}
}`)
require.Len(t, rules, 1)
require.Len(t, rules[0].Do.pre, 2)
ifCmd, ok := rules[0].Do.pre[1].(IfElseBlockCommand)
require.True(t, ok)
require.Len(t, ifCmd.Ifs, 2)
assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw)
assert.Equal(t, "method POST", ifCmd.Ifs[1].On.raw)
require.NotNil(t, ifCmd.Else)
upstream := func(http.ResponseWriter, *http.Request) {}
cases := []struct {
name string
method string
want string
}{
{name: "get branch", method: http.MethodGet, want: "get"},
{name: "post branch", method: http.MethodPost, want: "post"},
{name: "else branch", method: http.MethodPut, want: "other"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.method, "/", nil)
req.Header.Set("X-Test-Header", "1")
w := httptest.NewRecorder()
rm := httputils.NewResponseModifier(w)
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
require.NoError(t, err)
assert.Equal(t, tc.want, req.Header.Get("X-Mode"))
})
}
}
func TestParseBlockRules_DefaultRule_CommentBeforeBrace(t *testing.T) {
rules := testParseRules(t, `default /* comment between header and brace */ {
bypass
}`)
require.Len(t, rules, 1)
assert.Equal(t, OnDefault, rules[0].On.raw)
assert.Equal(t, "bypass", rules[0].Do.raw)
}
func TestParseBlockRules_StrayClosingBraceAtTopLevel(t *testing.T) {
err := testParseRulesError(t, `}
default {
bypass
}`)
require.Error(t, err)
}
func TestParseBlockRules_NestedBlocks_ElifMustBeSameLine(t *testing.T) {
err := testParseRulesError(t, `header X-Test-Header {
@method GET {
set header X-Mode get
}
elif method POST {
set header X-Mode post
}
}`)
require.Error(t, err)
}
func TestParseBlockRules_NestedBlocks_ElseMustBeLastOnLine(t *testing.T) {
err := testParseRulesError(t, `header X-Test-Header {
@method GET {
set header X-Mode get
} else {
set header X-Mode other
} set header X-After else
}`)
require.Error(t, err)
assert.Contains(t, err.Error(), "unexpected token after else block")
}
func TestParseBlockRules_NestedBlocks_MultipleElse(t *testing.T) {
err := testParseRulesError(t, `header X-Test-Header {
@method GET {
set header X-Mode get
} else {
set header X-Mode other
} else {
set header X-Mode other2
}
}`)
require.Error(t, err)
// assert.Contains(t, err.Error(), "multiple 'else' branches")
assert.Contains(t, err.Error(), "unexpected token after else block")
}
func TestParseBlockRules_NestedBlocks_ElifMissingOnExpr(t *testing.T) {
err := testParseRulesError(t, `header X-Test-Header {
@method GET {
set header X-Mode get
} elif {
set header X-Mode post
}
}`)
require.Error(t, err)
assert.Contains(t, err.Error(), "expected on-expr after 'elif'")
}

View File

@@ -1,21 +1,25 @@
package rules
import "net/http"
import (
"net/http"
httputils "github.com/yusing/goutils/http"
)
type (
CheckFunc func(w http.ResponseWriter, r *http.Request) bool
CheckFunc func(w *httputils.ResponseModifier, r *http.Request) bool
Checker interface {
Check(w http.ResponseWriter, r *http.Request) bool
Check(w *httputils.ResponseModifier, r *http.Request) bool
}
CheckMatchSingle []Checker
CheckMatchAll []Checker
)
func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool {
func (checker CheckFunc) Check(w *httputils.ResponseModifier, r *http.Request) bool {
return checker(w, r)
}
func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool {
func (checkers CheckMatchSingle) Check(w *httputils.ResponseModifier, r *http.Request) bool {
for _, check := range checkers {
if check.Check(w, r) {
return true
@@ -24,7 +28,7 @@ func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) b
return false
}
func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool {
func (checkers CheckMatchAll) Check(w *httputils.ResponseModifier, r *http.Request) bool {
for _, check := range checkers {
if !check.Check(w, r) {
return false

View File

@@ -1,79 +1,62 @@
package rules
import "net/http"
import (
"errors"
"net/http"
httputils "github.com/yusing/goutils/http"
)
var errTerminateRule = errors.New("terminate rule")
type (
handlerFunc func(w http.ResponseWriter, r *http.Request) error
HandlerFunc func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
Handler struct {
fn HandlerFunc
phase PhaseFlag
terminate bool
}
CommandHandler interface {
// CommandHandler can read and modify the values
// then handle the request
// finally proceed to next command (or return) base on situation
Handle(w http.ResponseWriter, r *http.Request) error
IsResponseHandler() bool
ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
Phase() PhaseFlag
}
// NonTerminatingCommand will run then proceed to next command or reverse proxy.
NonTerminatingCommand handlerFunc
// TerminatingCommand will run then return immediately.
TerminatingCommand handlerFunc
// OnResponseCommand will run then return based on the response.
OnResponseCommand handlerFunc
// BypassCommand will skip all the following commands
// and directly return to reverse proxy.
BypassCommand struct{}
// Commands is a slice of CommandHandler.
Commands []CommandHandler
)
func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
return c(w, r)
func (h Handler) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
return h.fn(w, r, upstream)
}
func (c NonTerminatingCommand) IsResponseHandler() bool {
return false
func (h Handler) Phase() PhaseFlag {
return h.phase
}
func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
if err := c(w, r); err != nil {
return err
}
return errTerminated
func (h Handler) Terminates() bool {
return h.terminate
}
func (c TerminatingCommand) IsResponseHandler() bool {
return false
}
func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error {
return c(w, r)
}
func (c OnResponseCommand) IsResponseHandler() bool {
return true
}
func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error {
return errTerminated
}
func (c BypassCommand) IsResponseHandler() bool {
return false
}
func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error {
func (c Commands) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
for _, cmd := range c {
if err := cmd.Handle(w, r); err != nil {
err := cmd.ServeHTTP(w, r, upstream)
if err != nil {
// Terminating actions stop the command chain immediately.
// Will be handled by the caller.
return err
}
}
return nil
}
func (c Commands) IsResponseHandler() bool {
func (c Commands) Phase() PhaseFlag {
req := PhaseNone
for _, cmd := range c {
if cmd.IsResponseHandler() {
return true
}
req |= cmd.Phase()
}
return false
return req
}

View File

@@ -1,7 +1,6 @@
package rules
import (
"bytes"
"fmt"
"io"
"net/http"
@@ -24,17 +23,17 @@ import (
type (
Command struct {
raw string
exec CommandHandler
isResponseHandler bool
raw string
pre Commands // runs before w.WriteHeader
post Commands
}
)
func (cmd *Command) IsResponseHandler() bool {
return cmd.isResponseHandler
}
const (
CommandUpstream = "upstream"
CommandUpstreamOld = "bypass"
CommandUpstreamOld2 = "pass"
CommandRequireAuth = "require_auth"
CommandRewrite = "rewrite"
CommandServe = "serve"
@@ -48,8 +47,6 @@ const (
CommandRemove = "remove"
CommandLog = "log"
CommandNotify = "notify"
CommandPass = "pass"
CommandPassAlt = "bypass"
)
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
@@ -60,36 +57,57 @@ func InitAuthHandler(handler AuthHandler) {
authHandler = handler
}
func init() {
commands[CommandUpstreamOld] = commands[CommandUpstream]
commands[CommandUpstreamOld2] = commands[CommandUpstream]
}
var commands = map[string]struct {
help Help
validate ValidateFunc
build func(args any) CommandHandler
isResponseHandler bool
help Help
validate ValidateFunc
build func(args any) HandlerFunc
terminate bool
}{
CommandUpstream: {
help: Help{
command: CommandUpstream,
description: makeLines("Pass the request to the upstream"),
args: map[string]string{},
},
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 0 {
return phase, nil, ErrExpectNoArg
}
return phase, nil, nil
},
build: func(args any) HandlerFunc {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
upstream(w, r)
return errTerminateRule
}
},
terminate: true,
},
CommandRequireAuth: {
help: Help{
command: CommandRequireAuth,
description: makeLines("Require HTTP authentication for incoming requests"),
args: map[string]string{},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 0 {
return nil, ErrExpectNoArg
return phase, nil, ErrExpectNoArg
}
//nolint:nilnil
return nil, nil
return phase, nil, nil
},
build: func(args any) CommandHandler {
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
if authHandler == nil {
http.Error(w, "Auth handler not initialized", http.StatusInternalServerError)
return errTerminated
}
if !authHandler(w, r) {
return errTerminated
build: func(args any) HandlerFunc {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
if proceed := authHandler(w, r); !proceed {
return errTerminateRule
}
return nil
})
}
},
},
CommandRewrite: {
@@ -104,26 +122,27 @@ var commands = map[string]struct {
"to": "the path to rewrite to, must start with /",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 2 {
return nil, ErrExpectTwoArgs
return phase, nil, ErrExpectTwoArgs
}
path1, err1 := validateURLPath(args[:1])
path2, err2 := validateURLPath(args[1:])
if err1 != nil {
err1 = gperr.PrependSubject(err1, "from")
err1 = gperr.Errorf("from: %w", err1)
}
if err2 != nil {
err2 = gperr.PrependSubject(err2, "to")
err2 = gperr.Errorf("to: %w", err2)
}
if err1 != nil || err2 != nil {
return nil, gperr.Join(err1, err2)
return phase, nil, gperr.Join(err1, err2)
}
return &StrTuple{path1.(string), path2.(string)}, nil
return phase, &StrTuple{path1.(string), path2.(string)}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
orig, repl := args.(*StrTuple).Unpack()
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
path := r.URL.Path
if len(path) > 0 && path[0] != '/' {
path = "/" + path
@@ -133,10 +152,10 @@ var commands = map[string]struct {
}
path = repl + path[len(orig):]
r.URL.Path = path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
r.URL.RawPath = ""
r.RequestURI = ""
return nil
})
}
},
},
CommandServe: {
@@ -150,14 +169,19 @@ var commands = map[string]struct {
"root": "the file system path to serve, must be an existing directory",
},
},
validate: validateFSPath,
build: func(args any) CommandHandler {
root := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
return nil
})
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
parsedArgs, err = validateFSPath(args)
return
},
build: func(args any) HandlerFunc {
root := args.(string)
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
return errTerminateRule
}
},
terminate: true,
},
CommandRedirect: {
help: Help{
@@ -170,14 +194,19 @@ var commands = map[string]struct {
"to": "the url to redirect to, can be relative or absolute URL",
},
},
validate: validateURL,
build: func(args any) CommandHandler {
target := args.(*nettypes.URL).String()
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
return nil
})
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
parsedArgs, err = validateURL(args)
return
},
build: func(args any) HandlerFunc {
target := args.(*nettypes.URL).String()
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
return errTerminateRule
}
},
terminate: true,
},
CommandRoute: {
help: Help{
@@ -190,15 +219,16 @@ var commands = map[string]struct {
"route": "the route to route to",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 1 {
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
}
return args[0], nil
return phase, args[0], nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
route := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
return func(w *httputils.ResponseModifier, req *http.Request, upstream http.HandlerFunc) error {
ep := entrypoint.FromCtx(req.Context())
r, ok := ep.HTTPRoutes().Get(route)
if !ok {
@@ -212,9 +242,10 @@ var commands = map[string]struct {
} else {
http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound)
}
return nil
})
return errTerminateRule
}
},
terminate: true,
},
CommandError: {
help: Help{
@@ -228,34 +259,40 @@ var commands = map[string]struct {
"text": "the error message to return",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) != 2 {
return nil, ErrExpectTwoArgs
return phase, nil, ErrExpectTwoArgs
}
codeStr, text := args[0], args[1]
code, err := strconv.Atoi(codeStr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
return phase, nil, ErrInvalidArguments.With(err)
}
if !httputils.IsStatusCodeValid(code) {
return nil, ErrInvalidArguments.Subject(codeStr)
return phase, nil, ErrInvalidArguments.Subject(codeStr)
}
textTmpl, err := validateTemplate(text, true)
tmplReq, textTmpl, err := validateTemplate(text, true)
if err != nil {
return nil, ErrInvalidArguments.With(err)
return phase, nil, ErrInvalidArguments.With(err)
}
return &Tuple[int, templateString]{code, textTmpl}, nil
phase |= tmplReq
return phase, &Tuple[int, templateString]{code, textTmpl}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
// error command should overwrite the response body
httputils.GetInitResponseModifier(w).ResetBody()
w.ResetBody()
w.WriteHeader(code)
err := textTmpl.ExpandVars(w, r, w)
return err
})
_, err := textTmpl.ExpandVars(w, r, w.BodyBuffer())
if err != nil {
return err
}
return errTerminateRule
}
},
terminate: true,
},
CommandRequireBasicAuth: {
help: Help{
@@ -268,20 +305,22 @@ var commands = map[string]struct {
"realm": "the authentication realm",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
if len(args) == 1 {
return args[0], nil
return phase, args[0], nil
}
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
realm := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, realm))
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return nil
})
return errTerminateRule
}
},
terminate: true,
},
CommandProxy: {
help: Help{
@@ -294,14 +333,19 @@ var commands = map[string]struct {
"to": "the url to proxy to, must be an absolute URL",
},
},
validate: validateURL,
build: func(args any) CommandHandler {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePre
parsedArgs, err = validateURL(args)
return
},
build: func(args any) HandlerFunc {
target := args.(*nettypes.URL)
if target.Scheme == "" {
target.Scheme = "http"
}
if target.Host == "" {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
rawPath := target.EscapedPath()
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
url := target.URL
url.Host = routes.TryGetUpstreamHostPort(r)
if url.Host == "" {
@@ -309,18 +353,19 @@ var commands = map[string]struct {
}
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
r.URL.Path = target.Path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
r.URL.RawPath = rawPath
r.RequestURI = ""
rp.ServeHTTP(w, r)
return nil
})
return errTerminateRule
}
}
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
rp.ServeHTTP(w, r)
return nil
})
return errTerminateRule
}
},
terminate: true,
},
CommandSet: {
help: Help{
@@ -335,11 +380,11 @@ var commands = map[string]struct {
"value": "the value to set",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
return validateModField(ModFieldSet, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
build: func(args any) HandlerFunc {
return args.(HandlerFunc)
},
},
CommandAdd: {
@@ -355,11 +400,11 @@ var commands = map[string]struct {
"value": "the value to add",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
return validateModField(ModFieldAdd, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
build: func(args any) HandlerFunc {
return args.(HandlerFunc)
},
},
CommandRemove: {
@@ -374,15 +419,14 @@ var commands = map[string]struct {
"field": "the field to remove",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
return validateModField(ModFieldRemove, args)
},
build: func(args any) CommandHandler {
return args.(CommandHandler)
build: func(args any) HandlerFunc {
return args.(HandlerFunc)
},
},
CommandLog: {
isResponseHandler: true,
help: Help{
command: CommandLog,
description: makeLines(
@@ -399,28 +443,28 @@ var commands = map[string]struct {
"template": "the template to log",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 3 {
return nil, ErrExpectThreeArgs
return phase, nil, ErrExpectThreeArgs
}
tmpl, err := validateTemplate(args[2], true)
phase, tmpl, err := validateTemplate(args[2], true)
if err != nil {
return nil, err
return phase, nil, err
}
level, err := validateLevel(args[0])
if err != nil {
return nil, err
return phase, nil, err
}
// NOTE: file will stay opened forever
// it leverages accesslog.NewFileIO so
// it will be opened only once for the same path
f, err := openFile(args[1])
if err != nil {
return nil, err
return phase, nil, err
}
return &onLogArgs{level, f, tmpl}, nil
return phase, &onLogArgs{level, f, tmpl}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
level, f, tmpl := args.(*onLogArgs).Unpack()
var logger io.Writer
if f == stdout || f == stderr {
@@ -428,17 +472,16 @@ var commands = map[string]struct {
} else {
logger = f
}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
err := tmpl.ExpandVars(w, r, logger)
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
_, err := tmpl.ExpandVars(w, r, logger)
if err != nil {
return err
}
return nil
})
}
},
},
CommandNotify: {
isResponseHandler: true,
help: Help{
command: CommandNotify,
description: makeLines(
@@ -456,22 +499,24 @@ var commands = map[string]struct {
"body": "the body of the notification",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 4 {
return nil, ErrExpectFourArgs
return phase, nil, ErrExpectFourArgs
}
titleTmpl, err := validateTemplate(args[2], false)
req1, titleTmpl, err := validateTemplate(args[2], false)
if err != nil {
return nil, err
return phase, nil, err
}
bodyTmpl, err := validateTemplate(args[3], false)
req2, bodyTmpl, err := validateTemplate(args[3], false)
if err != nil {
return nil, err
return phase, nil, err
}
level, err := validateLevel(args[0])
if err != nil {
return nil, err
return phase, nil, err
}
phase |= req1 | req2
// TODO: validate provider
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
@@ -480,34 +525,34 @@ var commands = map[string]struct {
// if err != nil {
// return nil, err
// }
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
return phase, &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
},
build: func(args any) CommandHandler {
build: func(args any) HandlerFunc {
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
to := []string{provider}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
var respBuf strings.Builder
err := titleTmpl.ExpandVars(w, r, respBuf)
_, err := titleTmpl.ExpandVars(w, r, &respBuf)
if err != nil {
return err
}
titleLen := respBuf.Len()
err = bodyTmpl.ExpandVars(w, r, respBuf)
_, err = bodyTmpl.ExpandVars(w, r, &respBuf)
if err != nil {
return err
}
b := respBuf.Bytes()
s := respBuf.String()
notif.Notify(&notif.LogMessage{
Level: level,
Title: string(b[:titleLen]),
Body: notif.MessageBodyBytes(b[titleLen:]),
Title: s[:titleLen],
Body: notif.MessageBodyBytes(s[titleLen:]),
To: to,
})
return nil
})
}
},
},
}
@@ -519,121 +564,29 @@ type (
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error {
executors := make([]CommandHandler, 0)
isResponseHandler := false
for line := range strings.SplitSeq(v, "\n") {
if line == "" {
continue
}
directive, args, err := parse(line)
if err != nil {
return err
}
if directive == CommandPass || directive == CommandPassAlt {
if len(args) != 0 {
return ErrExpectNoArg
}
executors = append(executors, BypassCommand{})
continue
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
validArgs, err := builder.validate(args)
if err != nil {
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
return gperr.PrependSubject(err, directive).With(builder.help.Error())
}
handler := builder.build(validArgs)
executors = append(executors, handler)
if builder.isResponseHandler || handler.IsResponseHandler() {
isResponseHandler = true
}
executors, parseErr := parseDoWithBlocks(v)
if parseErr != nil {
return parseErr
}
if len(executors) == 0 {
cmd.raw = v
cmd.exec = nil
cmd.isResponseHandler = false
cmd.pre = nil
cmd.post = nil
return nil
}
exec, err := buildCmd(executors)
if err != nil {
return err
}
cmd.raw = v
cmd.exec = exec
if exec.IsResponseHandler() {
isResponseHandler = true
for _, executor := range executors {
if executor.Phase().IsPostRule() {
cmd.post = append(cmd.post, executor)
} else {
cmd.pre = append(cmd.pre, executor)
}
}
cmd.isResponseHandler = isResponseHandler
return nil
}
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
// Validate the execution order.
//
// This allows sequences like:
// route ws-api
// log info /dev/stdout "..."
// where the first command is request-phase and the last is response-phase.
lastNonResp := -1
seenResp := false
for i, exec := range executors {
if exec.IsResponseHandler() {
seenResp = true
continue
}
if seenResp {
return nil, ErrInvalidCommandSequence.Withf("response handlers must be the last commands")
}
lastNonResp = i
}
for i, exec := range executors {
if i > lastNonResp {
break // response-handler tail
}
switch exec.(type) {
case TerminatingCommand, BypassCommand:
if i != lastNonResp {
return nil, ErrInvalidCommandSequence.
Withf("a response handler or terminating/bypass command must be the last command")
}
}
}
return Commands(executors), nil
}
// Command is purely "bypass" or empty.
func (cmd *Command) isBypass() bool {
if cmd == nil {
return true
}
switch cmd := cmd.exec.(type) {
case BypassCommand:
return true
case Commands:
// bypass command is always the last one
_, ok := cmd[len(cmd)-1].(BypassCommand)
return ok
default:
return false
}
}
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
return cmd.exec.Handle(w, r)
}
func (cmd *Command) String() string {
return cmd.raw
}

View File

@@ -0,0 +1,386 @@
package rules
import (
"net/http"
"strings"
"unicode"
gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http"
)
// IfBlockCommand is an inline conditional block inside a do-body.
//
// Syntax (within a rule do block):
//
// @<on-expr> { <do...> }
//
// Semantics:
// - Evaluated in the same phase the parent rule runs.
// - If <on-expr> matches, run the nested commands in-order.
// - Otherwise do nothing.
//
// NOTE: Per current design decision, we keep this permissive:
// nested blocks may use response matchers and response commands; no extra phase validation is performed.
type IfBlockCommand struct {
On RuleOn
Do []CommandHandler
}
func (c IfBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
if c.Do == nil {
return nil
}
// If On.checker is nil, treat as unconditional (should not happen if parsed).
if c.On.checker == nil {
return Commands(c.Do).ServeHTTP(w, r, upstream)
}
if c.On.checker.Check(w, r) {
return Commands(c.Do).ServeHTTP(w, r, upstream)
}
return nil
}
func (c IfBlockCommand) Phase() PhaseFlag {
phase := c.On.phase
for _, cmd := range c.Do {
phase |= cmd.Phase()
}
return phase
}
// IfElseBlockCommand is a chained conditional block inside a do-body.
//
// Syntax (within a rule do block):
//
// @<on-expr> { <do...> } elif <on-expr> { <do...> } ... else { <do...> }
//
// NOTE: `elif`/`else` must appear on the same line as the preceding closing brace (`}`),
// e.g. `} elif ... {` and `} else {`.
type IfElseBlockCommand struct {
Ifs []IfBlockCommand
Else []CommandHandler
}
func (c IfElseBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
for _, br := range c.Ifs {
// If On.checker is nil, treat as unconditional.
if br.On.checker == nil {
if br.Do == nil {
continue
}
return Commands(br.Do).ServeHTTP(w, r, upstream)
}
if br.Do == nil {
continue
}
if br.On.checker.Check(w, r) {
return Commands(br.Do).ServeHTTP(w, r, upstream)
}
}
if len(c.Else) > 0 {
return Commands(c.Else).ServeHTTP(w, r, upstream)
}
return nil
}
func (c IfElseBlockCommand) Phase() PhaseFlag {
phase := PhaseNone
for _, br := range c.Ifs {
phase |= br.Phase()
}
if len(c.Else) > 0 {
phase |= Commands(c.Else).Phase()
}
return phase
}
func skipSameLineSpace(src string, pos int) int {
for pos < len(src) {
switch src[pos] {
case '\n':
return pos
case '\r':
pos++
continue
case ' ', '\t':
pos++
continue
default:
return pos
}
}
return pos
}
func parseAtBlockChain(src string, atPos int) (CommandHandler, int, error) {
length := len(src)
headerStart := atPos + 1
parseBranch := func(onExpr string, bodyStart int, bodyEnd int) (RuleOn, []CommandHandler, error) {
var on RuleOn
if err := on.Parse(onExpr); err != nil {
return RuleOn{}, nil, err
}
innerSrc := ""
if bodyStart < bodyEnd {
innerSrc = src[bodyStart:bodyEnd]
}
inner, err := parseDoWithBlocks(innerSrc)
if err != nil {
return RuleOn{}, nil, err
}
if len(inner) == 0 {
return on, nil, nil
}
return on, inner, nil
}
onExpr, bracePos, herr := parseHeaderToBrace(src, headerStart)
if herr != nil {
return nil, 0, herr
}
if bracePos >= length || src[bracePos] != '{' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after nested block header")
}
// Parse first @<on-expr> { ... }
p := bracePos
bodyStart := p + 1
bodyEnd, ferr := findMatchingBrace(src, &p, bodyStart)
if ferr != nil {
return nil, 0, ferr
}
firstOn, firstDo, berr := parseBranch(onExpr, bodyStart, bodyEnd)
if berr != nil {
return nil, 0, berr
}
ifs := []IfBlockCommand{{On: firstOn, Do: firstDo}}
var elseDo []CommandHandler
hasChain := false
hasElse := false
for {
q := skipSameLineSpace(src, p)
if q >= length || src[q] == '\n' {
break
}
// elif <on-expr> { ... }
if strings.HasPrefix(src[q:], "elif") {
next := q + len("elif")
if next >= length {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
}
if src[next] == '\n' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
}
if !unicode.IsSpace(rune(src[next])) {
if src[next] == '{' || src[next] == '}' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
}
return nil, 0, ErrInvalidBlockSyntax.Withf("expected whitespace after 'elif'")
}
next++
for next < length {
c := src[next]
if c == '\n' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
}
if c == '\r' {
next++
continue
}
if !unicode.IsSpace(rune(c)) {
break
}
next++
}
p2 := next
elifOnExpr, bracePos, herr := parseHeaderToBrace(src, p2)
if herr != nil {
return nil, 0, herr
}
if elifOnExpr == "" {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
}
if bracePos >= length || src[bracePos] != '{' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
}
p2 = bracePos
elifBodyStart := p2 + 1
elifBodyEnd, ferr := findMatchingBrace(src, &p2, elifBodyStart)
if ferr != nil {
return nil, 0, ferr
}
elifOn, elifDo, berr := parseBranch(elifOnExpr, elifBodyStart, elifBodyEnd)
if berr != nil {
return nil, 0, berr
}
ifs = append(ifs, IfBlockCommand{On: elifOn, Do: elifDo})
hasChain = true
p = p2
continue
}
// else { ... }
if strings.HasPrefix(src[q:], "else") {
if hasElse {
return nil, 0, ErrInvalidBlockSyntax.Withf("multiple 'else' branches")
}
next := q + len("else")
for next < length {
c := src[next]
if c == '\n' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
}
if c == '\r' {
next++
continue
}
if !unicode.IsSpace(rune(c)) {
break
}
next++
}
if next >= length || src[next] != '{' {
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
}
elseBodyStart := next + 1
p2 := next
elseBodyEnd, ferr := findMatchingBrace(src, &p2, elseBodyStart)
if ferr != nil {
return nil, 0, ferr
}
innerSrc := ""
if elseBodyStart < elseBodyEnd {
innerSrc = src[elseBodyStart:elseBodyEnd]
}
inner, ierr := parseDoWithBlocks(innerSrc)
if ierr != nil {
return nil, 0, ierr
}
if len(inner) == 0 {
elseDo = nil
} else {
elseDo = inner
}
hasChain = true
hasElse = true
p = p2
// else must be the last branch on that line.
for q2 := skipSameLineSpace(src, p); q2 < length && src[q2] != '\n'; q2 = skipSameLineSpace(src, q2) {
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after else block")
}
break
}
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after nested block; expected 'elif'/'else' or newline")
}
if hasChain {
return IfElseBlockCommand{Ifs: ifs, Else: elseDo}, p, nil
}
return IfBlockCommand{On: ifs[0].On, Do: ifs[0].Do}, p, nil
}
// parseDoWithBlocks parses a do-body containing plain command lines and nested @-blocks.
// It returns the outer command handlers and the require phase.
//
// A nested block is only recognized when '@' is the first non-space character on a line.
func parseDoWithBlocks(src string) (handlers []CommandHandler, err error) {
pos := 0
length := len(src)
lineStart := true
handlers = make([]CommandHandler, 0, strings.Count(src, "\n")+1)
appendLineCommand := func(line string) error {
line = strings.TrimSpace(line)
if line == "" {
return nil
}
directive, args, err := parse(line)
if err != nil {
return err
}
builder, ok := commands[directive]
if !ok {
return ErrUnknownDirective.Subject(directive)
}
phase, validArgs, err := builder.validate(args)
if err != nil {
return gperr.PrependSubject(err, directive).With(builder.help.Error())
}
h := builder.build(validArgs)
handlers = append(handlers, Handler{fn: h, phase: phase, terminate: builder.terminate})
return nil
}
for pos < length {
// Handle newlines
switch src[pos] {
case '\n':
pos++
lineStart = true
continue
case '\r':
// tolerate CRLF
pos++
continue
}
if lineStart {
// Find first non-space on the line.
linePos := pos
for linePos < length {
c := rune(src[linePos])
if c == '\n' {
break
}
if !unicode.IsSpace(c) {
break
}
linePos++
}
if linePos < length && src[linePos] == '@' {
h, next, err := parseAtBlockChain(src, linePos)
if err != nil {
return nil, err
}
handlers = append(handlers, h)
pos = next
lineStart = false
continue
}
// Not a nested block; parse the rest of this line as a command.
lineEnd := pos
for lineEnd < length && src[lineEnd] != '\n' {
lineEnd++
}
if lerr := appendLineCommand(src[pos:lineEnd]); lerr != nil {
return nil, lerr
}
pos = lineEnd
lineStart = true
continue
}
// Not at line start; advance to the next line boundary.
for pos < length && src[pos] != '\n' {
pos++
}
lineStart = true
}
return handlers, nil
}

View File

@@ -0,0 +1,73 @@
package rules
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
httputils "github.com/yusing/goutils/http"
)
func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoFallsThrough(t *testing.T) {
elseCalled := false
cmd := IfElseBlockCommand{
Ifs: []IfBlockCommand{
{
On: RuleOn{},
Do: nil,
},
},
Else: []CommandHandler{
Handler{
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
elseCalled = true
return nil
},
phase: PhaseNone,
},
},
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
rm := httputils.NewResponseModifier(w)
err := cmd.ServeHTTP(rm, req, nil)
require.NoError(t, err)
assert.True(t, elseCalled)
}
func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoFallsThrough(t *testing.T) {
elseCalled := false
cmd := IfElseBlockCommand{
Ifs: []IfBlockCommand{
{
On: RuleOn{
checker: CheckFunc(func(_ *httputils.ResponseModifier, _ *http.Request) bool {
return true
}),
},
Do: nil,
},
},
Else: []CommandHandler{
Handler{
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
elseCalled = true
return nil
},
phase: PhaseNone,
},
},
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
rm := httputils.NewResponseModifier(w)
err := cmd.ServeHTTP(rm, req, nil)
require.NoError(t, err)
assert.True(t, elseCalled)
}

View File

@@ -37,7 +37,7 @@ func parseRules(data string, target *Rules) error {
}
func TestLogCommand_TemporaryFile(t *testing.T) {
upstream := mockUpstreamWithHeaders(200, "success response", http.Header{
upstream := mockUpstreamWithHeaders(http.StatusOK, "success response", http.Header{
"Content-Type": []string{"application/json"},
})
@@ -45,10 +45,9 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
var rules Rules
err := parseRules(fmt.Sprintf(`
- name: log-request-response
do: |
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
`, logFile), &rules)
default {
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
}`, logFile), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
@@ -59,7 +58,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "success response", w.Body.String())
// Read and verify log content
@@ -74,12 +73,10 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
var rules Rules
err := parseRules(`
- name: log-stdout
do: |
log info /dev/stdout "stdout: $req_method $status_code"
- name: log-stderr
do: |
log error /dev/stderr "stderr: $req_path $status_code"
default {
log info /dev/stdout "stdout: $req_method $status_code"
log error /dev/stderr "stderr: $req_path $status_code"
}
`, &rules)
require.NoError(t, err)
@@ -90,7 +87,7 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, http.StatusOK, w.Code)
// Note: We can't easily capture stdout/stderr in unit tests,
// but we can verify no errors occurred and the handler completed
}
@@ -104,26 +101,22 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
var rules Rules
err := parseRules(fmt.Sprintf(`
- name: log-info
do: |
log info %s "INFO: $req_method $status_code"
- name: log-warn
do: |
log warn %s "WARN: $req_path $status_code"
- name: log-error
do: |
log error %s "ERROR: $req_method $req_path $status_code"
default {
log info %s "INFO: $req_method $status_code"
log warn %s "WARN: $req_path $status_code"
log error %s "ERROR: $req_method $req_path $status_code"
}
`, infoFile, warnFile, errorFile), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("DELETE", "/api/resource/123", nil)
req := httptest.NewRequest(http.MethodDelete, "/api/resource/123", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
// Verify each log file
infoContent := TestFileContent(infoFile)
@@ -148,22 +141,22 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
var rules Rules
err := parseRules(fmt.Sprintf(`
- name: log-with-templates
do: |
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
default {
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
}
`, tempFile), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("PUT", "/api/resource", nil)
req := httptest.NewRequest(http.MethodPut, "/api/resource", nil)
req.Header.Set("User-Agent", "test-client/1.0")
req.Host = "example.com"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code)
assert.Equal(t, http.StatusCreated, w.Code)
// Verify log content
content := TestFileContent(tempFile)
@@ -192,14 +185,12 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
var rules Rules
err := parseRules(fmt.Sprintf(`
- name: log-success
on: status 2xx
do: |
log info %q "SUCCESS: $req_method $req_path $status_code"
- name: log-error
on: status 4xx | status 5xx
do: |
log error %q "ERROR: $req_method $req_path $status_code"
status 2xx {
log info %q "SUCCESS: $req_method $req_path $status_code"
}
status 4xx | status 5xx {
log error %q "ERROR: $req_method $req_path $status_code"
}
`, successFile, errorFile), &rules)
require.NoError(t, err)
@@ -244,9 +235,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
var rules Rules
err := parseRules(fmt.Sprintf(`
- name: log-multiple
do: |
log info %q "$req_method $req_path $status_code"`, tempFile), &rules)
default {
log info %q "$req_method $req_path $status_code"
}
`, tempFile), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
@@ -256,10 +248,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
method string
path string
}{
{"GET", "/users"},
{"POST", "/users"},
{"PUT", "/users/1"},
{"DELETE", "/users/1"},
{http.MethodGet, "/users"},
{http.MethodPost, "/users"},
{http.MethodPost, "/users/1"},
{http.MethodDelete, "/users/1"},
}
for _, reqInfo := range requests {
@@ -287,8 +279,9 @@ func TestLogCommand_InvalidTemplate(t *testing.T) {
// Test with invalid template syntax
err := parseRules(`
- name: log-invalid
do: |
log info /dev/stdout "$invalid_var"`, &rules)
assert.ErrorIs(t, err, ErrUnexpectedVar)
default {
log info /dev/stdout "$invalid_var"
}
`, &rules)
require.ErrorIs(t, err, ErrUnexpectedVar)
}

View File

@@ -12,7 +12,7 @@ import (
type (
FieldHandler struct {
set, add, remove CommandHandler
set, add, remove HandlerFunc
}
FieldModifier string
)
@@ -49,30 +49,30 @@ var modFields = map[string]struct {
"value": "the header template",
},
},
validate: toKeyValueTemplate,
validate: validatePreRequestKVTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
r.Header[k] = []string{v}
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
},
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
r.Header[k] = append(r.Header[k], v)
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
},
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
delete(r.Header, k)
return nil
}),
},
}
},
},
@@ -84,30 +84,30 @@ var modFields = map[string]struct {
"value": "the response header template",
},
},
validate: toKeyValueTemplate,
validate: validatePostResponseKVTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
w.Header()[k] = []string{v}
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
},
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
w.Header()[k] = append(w.Header()[k], v)
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
},
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
delete(w.Header(), k)
return nil
}),
},
}
},
},
@@ -119,36 +119,36 @@ var modFields = map[string]struct {
"value": "the query template",
},
},
validate: toKeyValueTemplate,
validate: validatePreRequestKVTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
w.SharedData().UpdateQueries(r, func(queries url.Values) {
queries.Set(k, v)
})
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
},
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
w.SharedData().UpdateQueries(r, func(queries url.Values) {
queries.Add(k, v)
})
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
},
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
w.SharedData().UpdateQueries(r, func(queries url.Values) {
queries.Del(k)
})
return nil
}),
},
}
},
},
@@ -160,16 +160,16 @@ var modFields = map[string]struct {
"value": "the cookie value",
},
},
validate: toKeyValueTemplate,
validate: validatePreRequestKVTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
for i, c := range cookies {
if c.Name == k {
cookies[i].Value = v
@@ -179,19 +179,19 @@ var modFields = map[string]struct {
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := tmpl.ExpandVarsToString(w, r)
},
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
v, _, err := tmpl.ExpandVarsToString(w, r)
if err != nil {
return err
}
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
},
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
index := -1
for i, c := range cookies {
if c.Name == k {
@@ -208,7 +208,7 @@ var modFields = map[string]struct {
return cookies
})
return nil
}),
},
}
},
},
@@ -227,24 +227,27 @@ var modFields = map[string]struct {
"template": "the body template",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
return 0, nil, ErrExpectOneArg
}
return validateTemplate(args[0], true)
phase = PhasePre
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
phase |= tmplReq
return
},
builder: func(args any) *FieldHandler {
tmpl := args.(templateString)
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
if r.Body != nil {
r.Body.Close()
r.Body = nil
}
bufPool := httputils.GetInitResponseModifier(w).BufPool()
bufPool := w.BufPool()
b := bufPool.GetBuffer()
err := tmpl.ExpandVars(w, r, b)
_, err := tmpl.ExpandVars(w, r, b)
if err != nil {
return err
}
@@ -252,7 +255,7 @@ var modFields = map[string]struct {
bufPool.PutBuffer(b)
})
return nil
}),
},
}
},
},
@@ -272,20 +275,26 @@ var modFields = map[string]struct {
"template": "the response body template",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
return 0, nil, ErrExpectOneArg
}
return validateTemplate(args[0], true)
phase = PhasePost
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
phase |= tmplReq
return
},
builder: func(args any) *FieldHandler {
tmpl := args.(templateString)
return &FieldHandler{
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
rm := httputils.GetInitResponseModifier(w)
rm.ResetBody()
return tmpl.ExpandVars(w, r, rm)
}),
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
w.ResetBody()
_, err := tmpl.ExpandVars(w, r, w)
if err != nil {
return err
}
return nil
},
}
},
},
@@ -300,26 +309,27 @@ var modFields = map[string]struct {
"code": "the status code",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
}
phase = PhasePost
status, err := strconv.Atoi(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
return phase, nil, ErrInvalidArguments.With(err)
}
if status < 100 || status > 599 {
return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
return phase, nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
}
return status, nil
return phase, status, nil
},
builder: func(args any) *FieldHandler {
status := args.(int)
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
httputils.GetInitResponseModifier(w).WriteHeader(status)
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
w.WriteHeader(status)
return nil
}),
},
}
},
},

View File

@@ -5,7 +5,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"
@@ -72,12 +71,12 @@ func TestFieldHandler_Header(t *testing.T) {
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.value, false)
_, tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
var cmd HandlerFunc
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
@@ -87,7 +86,7 @@ func TestFieldHandler_Header(t *testing.T) {
cmd = handler.remove
}
err := cmd.Handle(w, req)
err := cmd(httputils.NewResponseModifier(w), req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
@@ -150,12 +149,12 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
tt.setup(w)
}
tmpl, tErr := validateTemplate(tt.value, false)
_, tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
var cmd HandlerFunc
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
@@ -165,7 +164,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
cmd = handler.remove
}
err := cmd.Handle(w, req)
err := cmd(httputils.NewResponseModifier(w), req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
@@ -237,12 +236,12 @@ func TestFieldHandler_Query(t *testing.T) {
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.value, false)
_, tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
var cmd HandlerFunc
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
@@ -252,7 +251,7 @@ func TestFieldHandler_Query(t *testing.T) {
cmd = handler.remove
}
err := cmd.Handle(w, req)
err := cmd(httputils.NewResponseModifier(w), req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
@@ -335,12 +334,12 @@ func TestFieldHandler_Cookie(t *testing.T) {
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.value, false)
_, tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
var cmd HandlerFunc
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
@@ -350,7 +349,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
cmd = handler.remove
}
err := cmd.Handle(w, req)
err := cmd(httputils.NewResponseModifier(w), req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
@@ -371,7 +370,7 @@ func TestFieldHandler_Body(t *testing.T) {
name: "set body with template",
template: "Hello $req_method $req_path",
setup: func(r *http.Request) {
r.Method = "POST"
r.Method = http.MethodPost
r.URL.Path = "/test"
},
verify: func(r *http.Request) {
@@ -399,15 +398,15 @@ func TestFieldHandler_Body(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
w := httputils.NewResponseModifier(httptest.NewRecorder())
tmpl, tErr := validateTemplate(tt.template, false)
_, tmpl, tErr := validateTemplate(tt.template, false)
if tErr != nil {
t.Fatalf("Failed to parse template: %v", tErr)
}
handler := modFields[FieldBody].builder(tmpl)
err := handler.set.Handle(w, req)
err := handler.set(w, req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
@@ -428,7 +427,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
name: "set response body with template",
template: "Response: $req_method $req_path",
setup: func(r *http.Request) {
r.Method = "GET"
r.Method = http.MethodGet
r.URL.Path = "/api/test"
},
verify: func(rm *httputils.ResponseModifier) {
@@ -443,23 +442,20 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
w := httputils.NewResponseModifier(httptest.NewRecorder())
// Create ResponseModifier wrapper
rm := httputils.NewResponseModifier(w)
tmpl, tErr := validateTemplate(tt.template, false)
_, tmpl, tErr := validateTemplate(tt.template, false)
if tErr != nil {
t.Fatalf("Failed to parse template: %v", tErr)
}
handler := modFields[FieldResponseBody].builder(tmpl)
err := handler.set.Handle(rm, req)
err := handler.set(w, req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(rm)
tt.verify(w)
})
}
}
@@ -472,23 +468,23 @@ func TestFieldHandler_StatusCode(t *testing.T) {
}{
{
name: "set status code 200",
status: 200,
status: http.StatusOK,
verify: func(w *httptest.ResponseRecorder) {
assert.Equal(t, 200, w.Code, "Expected status code 200")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
},
},
{
name: "set status code 404",
status: 404,
status: http.StatusNotFound,
verify: func(w *httptest.ResponseRecorder) {
assert.Equal(t, 404, w.Code, "Expected status code 404")
assert.Equal(t, http.StatusNotFound, w.Code, "Expected status code 404")
},
},
{
name: "set status code 500",
status: 500,
status: http.StatusInternalServerError,
verify: func(w *httptest.ResponseRecorder) {
assert.Equal(t, 500, w.Code, "Expected status code 500")
assert.Equal(t, http.StatusInternalServerError, w.Code, "Expected status code 500")
},
},
}
@@ -503,12 +499,11 @@ func TestFieldHandler_StatusCode(t *testing.T) {
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
err = cmd.ServeHTTP(rm, req)
err = cmd.post.ServeHTTP(rm, req, nil)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
rm.FlushRelease()
tt.verify(w)
})
}
@@ -600,7 +595,7 @@ func TestFieldValidation(t *testing.T) {
field, exists := modFields[tt.field]
assert.True(t, exists, "Field %s does not exist", tt.field)
_, err := field.validate(tt.args)
_, _, err := field.validate(tt.args)
if tt.wantError {
assert.Error(t, err, "Expected error but got none")
} else {
@@ -610,25 +605,6 @@ func TestFieldValidation(t *testing.T) {
}
}
func TestAllFields(t *testing.T) {
expectedFields := []string{
FieldHeader,
FieldResponseHeader,
FieldQuery,
FieldCookie,
FieldBody,
FieldResponseBody,
FieldStatusCode,
}
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
for _, expected := range expectedFields {
found := slices.Contains(AllFields, expected)
assert.True(t, found, "Expected field %s not found in AllFields", expected)
}
}
func TestModFields(t *testing.T) {
for fieldName, field := range modFields {
// Test that each field has required components

View File

@@ -14,8 +14,9 @@ var (
ErrEnvVarNotFound = gperr.New("env variable not found")
ErrInvalidArguments = gperr.New("invalid arguments")
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
ErrMultipleDefaultRules = gperr.New("multiple default rules")
ErrMultipleDefaultRules = gperr.New("multiple default rules")
ErrDeadRule = gperr.New("dead rule")
// vars errors
ErrNoArgProvided = gperr.New("no argument provided")
@@ -31,5 +32,5 @@ var (
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
errTerminated = gperr.New("terminated")
ErrInvalidBlockSyntax = gperr.New("invalid block syntax") // TODO: struct this error
)

View File

@@ -131,12 +131,13 @@ Error generates help string as error, e.g.
from: the path to rewrite, must start with /
to: the path to rewrite to, must start with /
*/
func (h *Help) Error() error {
var lines gperr.MultilineError
func (h *Help) Error() gperr.Error {
help := gperr.New(ansi.WithANSI(h.command, ansi.HighlightGreen))
for _, line := range h.description {
help = help.Withf("%s", line)
}
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
lines.AddStrings(h.description...)
lines.Adds(" args:")
args := gperr.New("args")
argKeys := make([]string, 0, len(h.args))
longestArg := 0
@@ -151,7 +152,9 @@ func (h *Help) Error() error {
slices.Sort(argKeys)
for _, arg := range argKeys {
desc := h.args[arg]
lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc)
paddedArg := fmt.Sprintf("%-"+strconv.Itoa(longestArg)+"s", arg)
args = args.Withf("%s%s", ansi.WithANSI(paddedArg, ansi.HighlightCyan)+": ", desc)
}
return &lines
return help.With(args)
}

File diff suppressed because it is too large Load Diff

View File

@@ -23,8 +23,9 @@ import (
)
// mockUpstream creates a simple upstream handler for testing
func mockUpstream(body string) http.HandlerFunc {
func mockUpstream(status int, body string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status)
w.Write([]byte(body))
}
}
@@ -47,7 +48,7 @@ func parseRules(data string, target *Rules) error {
return err
}
func TestHTTPFlow_BasicPreRules(t *testing.T) {
func TestHTTPFlow_BasicPreRulesYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
w.WriteHeader(http.StatusOK)
@@ -74,8 +75,8 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
}
func TestHTTPFlow_BypassRule(t *testing.T) {
upstream := mockUpstream("upstream response")
func TestHTTPFlow_BypassRuleYAML(t *testing.T) {
upstream := mockUpstream(http.StatusOK, "upstream response")
var rules Rules
err := parseRules(`
@@ -99,8 +100,8 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
assert.Equal(t, "upstream response", w.Body.String())
}
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
upstream := mockUpstream("should not be called")
func TestHTTPFlow_TerminatingCommandYAML(t *testing.T) {
upstream := mockUpstream(http.StatusOK, "should not be called")
var rules Rules
err := parseRules(`
@@ -120,13 +121,13 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "Forbidden\n", w.Body.String())
assert.Empty(t, w.Header().Get("X-Header"))
}
func TestHTTPFlow_RedirectFlow(t *testing.T) {
upstream := mockUpstream("should not be called")
func TestHTTPFlow_RedirectFlowYAML(t *testing.T) {
upstream := mockUpstream(http.StatusOK, "should not be called")
var rules Rules
err := parseRules(`
@@ -143,11 +144,11 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
assert.Equal(t, 307, w.Code) // TemporaryRedirect
assert.Equal(t, "/new-path", w.Header().Get("Location"))
}
func TestHTTPFlow_RewriteFlow(t *testing.T) {
func TestHTTPFlow_RewriteFlowYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("path: " + r.URL.Path))
@@ -172,7 +173,7 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
assert.Equal(t, "path: /v1/users", w.Body.String())
}
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
func TestHTTPFlow_MultiplePreRulesYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
@@ -201,7 +202,7 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
}
func TestHTTPFlow_PostResponseRule(t *testing.T) {
func TestHTTPFlow_PostResponseRuleYAML(t *testing.T) {
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
"X-Upstream": []string{"upstream-value"},
})
@@ -229,11 +230,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
// Check log file
content := TestFileContent(tempFile)
require.NoError(t, err)
assert.Equal(t, "GET 200\n", string(content))
}
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
func TestHTTPFlow_ResponseRuleWithStatusConditionYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/success" {
w.WriteHeader(http.StatusOK)
@@ -246,14 +246,17 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
var rules Rules
// Create a temporary file for logging
tempFile := TestRandomFileName()
errorLog := TestRandomFileName()
infoLog := TestRandomFileName()
err := parseRules(fmt.Sprintf(`
- name: log-errors
on: status 4xx
do: log error %s "$req_url returned $status_code"
`, tempFile), &rules)
status 4xx {
log error %s "$req_url returned $status_code"
}
status 200 {
log info %s "$req_url returned $status_code"
}
`, errorLog, infoLog), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
@@ -273,14 +276,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w2.Code)
// Check log file
content := TestFileContent(tempFile)
require.NoError(t, err)
content := TestFileContent(errorLog)
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
require.Len(t, lines, 1, "only 4xx requests should be logged")
assert.Equal(t, "/notfound returned 404", lines[0])
infoContent := TestFileContent(infoLog)
lines = strings.Split(strings.TrimSpace(string(infoContent)), "\n")
require.Len(t, lines, 1, "only 200 requests should be logged")
assert.Equal(t, "/success returned 200", lines[0])
}
func TestHTTPFlow_ConditionalRules(t *testing.T) {
func TestHTTPFlow_ConditionalRulesYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("hello " + r.Header.Get("X-Username")))
@@ -320,22 +327,21 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
}
func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
func TestHTTPFlow_ComplexFlowWithPreAndPostRulesYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate different responses based on path
if r.URL.Path == "/protected" {
if r.Header.Get("X-Auth") != "valid" {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("unauthorized"))
fmt.Fprint(w, "unauthorized")
return
}
}
w.Header().Set("X-Response-Time", "100ms")
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
fmt.Fprint(w, "success")
})
// Create temporary files for logging
logFile := TestRandomFileName()
errorLogFile := TestRandomFileName()
@@ -374,7 +380,7 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
handler.ServeHTTP(w2, req2)
assert.Equal(t, http.StatusUnauthorized, w2.Code)
assert.Equal(t, "Unauthorized\n", w2.Body.String())
assert.Equal(t, w2.Body.String(), "Unauthorized\n")
// Test authorized protected request
req3 := httptest.NewRequest(http.MethodGet, "/protected", nil)
@@ -402,8 +408,8 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
}
func TestHTTPFlow_DefaultRule(t *testing.T) {
upstream := mockUpstream("upstream response")
func TestHTTPFlow_DefaultRuleYAML(t *testing.T) {
upstream := mockUpstream(http.StatusOK, "upstream response")
var rules Rules
err := parseRules(`
@@ -436,11 +442,12 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
}
func TestHTTPFlow_HeaderManipulation(t *testing.T) {
func TestHTTPFlow_HeaderManipulationYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Echo back a header
headerValue := r.Header.Get("X-Test-Header")
w.Header().Set("X-Echoed-Header", headerValue)
w.Header().Set("X-Secret", "sensitive-data")
w.WriteHeader(http.StatusOK)
w.Write([]byte("header echoed"))
})
@@ -460,7 +467,6 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Secret", "secret-value")
req.Header.Set("X-Test-Header", "original-value")
w := httptest.NewRecorder()
@@ -469,11 +475,10 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
// Ensure the secret header was removed and not passed to upstream
// (we can't directly test this, but the upstream shouldn't see it)
assert.Empty(t, w.Header().Get("X-Secret"))
}
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
func TestHTTPFlow_QueryParameterHandlingYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
w.WriteHeader(http.StatusOK)
@@ -500,13 +505,15 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
assert.Equal(t, "query: added-value", w.Body.String())
}
func TestHTTPFlow_ServeCommand(t *testing.T) {
func TestHTTPFlow_ServeCommandYAML(t *testing.T) {
// Create a temporary directory with test files
tempDir := t.TempDir()
tempDir, err := os.MkdirTemp("", "test-serve-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create test files directly in the temp directory
testFile := filepath.Join(tempDir, "index.html")
err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
require.NoError(t, err)
var rules Rules
@@ -517,7 +524,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
`, tempDir), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(mockUpstream("should not be called"))
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
// Test serving a file - serve command serves files relative to the root directory
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
@@ -546,7 +553,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w2.Code)
}
func TestHTTPFlow_ProxyCommand(t *testing.T) {
func TestHTTPFlow_ProxyCommandYAML(t *testing.T) {
// Create a mock upstream server
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Upstream-Header", "upstream-value")
@@ -563,7 +570,7 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
`, upstreamServer.URL), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(mockUpstream("should not be called"))
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
w := httptest.NewRecorder()
@@ -576,11 +583,28 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
}
func TestHTTPFlow_NotifyCommand(t *testing.T) {
// TODO:
func TestHTTPFlow_NotifyCommandYAML(t *testing.T) {
upstream := mockUpstream(http.StatusOK, "ok")
var rules Rules
err := parseRules(`
- name: notify-rule
on: path /notify
do: notify info test-provider "title $req_method" "body $req_url $status_code"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest(http.MethodGet, "/notify", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "ok", w.Body.String())
}
func TestHTTPFlow_FormConditions(t *testing.T) {
func TestHTTPFlow_FormConditionsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("form processed"))
@@ -620,7 +644,7 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
}
func TestHTTPFlow_RemoteConditions(t *testing.T) {
func TestHTTPFlow_RemoteConditionsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("remote processed"))
@@ -654,11 +678,11 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, http.StatusForbidden, w2.Code)
assert.Equal(t, 403, w2.Code)
assert.Equal(t, "Private network blocked\n", w2.Body.String())
}
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
func TestHTTPFlow_BasicAuthConditionsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("auth processed"))
@@ -702,7 +726,7 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
}
func TestHTTPFlow_RouteConditions(t *testing.T) {
func TestHTTPFlow_RouteConditionsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("route processed"))
@@ -742,10 +766,10 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
}
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
func TestHTTPFlow_ResponseStatusConditionsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
w.Write([]byte("method not allowed"))
fmt.Fprint(w, "method not allowed")
})
var rules Rules
@@ -767,11 +791,11 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
assert.Equal(t, "error\n", w.Body.String())
}
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
func TestHTTPFlow_ResponseHeaderConditionsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Response-Header", "response header")
w.WriteHeader(http.StatusOK)
w.Write([]byte("processed"))
fmt.Fprint(w, "processed")
})
t.Run("any_value", func(t *testing.T) {
@@ -831,7 +855,65 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
})
}
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMatchersYAML(t *testing.T) {
upstreamCalled := false
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstreamCalled = true
w.WriteHeader(http.StatusOK)
w.Write([]byte("upstream"))
})
var rules Rules
err := parseRules(`
- on: path /
do: error 403 blocked
- on: path /
do: set resp_header X-Late should-not-run
- on: status 4xx
do: set resp_header X-Post true
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.False(t, upstreamCalled)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "blocked\n", w.Body.String())
assert.Equal(t, "should-not-run", w.Header().Get("X-Late"))
assert.Equal(t, "true", w.Header().Get("X-Post"))
}
func TestHTTPFlow_PostRuleTermination_StopsRemainingCommandsInRuleYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
})
var rules Rules
err := parseRules(`
- on: status 200
do: |
error 500 failed
set resp_header X-After should-not-run
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "failed\n", w.Body.String())
assert.Empty(t, w.Header().Get("X-After"))
}
func TestHTTPFlow_ComplexRuleCombinationsYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("complex processed"))
@@ -887,12 +969,12 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3)
assert.Equal(t, 200, w3.Code)
assert.Equal(t, http.StatusOK, w3.Code)
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
assert.Empty(t, w3.Header()["X-API-Version"])
}
func TestHTTPFlow_ResponseModifier(t *testing.T) {
func TestHTTPFlow_ResponseModifierYAML(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("original response"))

View File

@@ -10,6 +10,7 @@ import (
"github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/logging/accesslog"
gperr "github.com/yusing/goutils/errs"
)
type noopWriteCloser struct {
@@ -30,7 +31,7 @@ var (
testFilesLock sync.Mutex
)
func openFile(path string) (io.WriteCloser, error) {
func openFile(path string) (io.WriteCloser, gperr.Error) {
switch path {
case "/dev/stdout":
return stdout, nil

View File

@@ -5,6 +5,7 @@ import (
"strings"
"github.com/gobwas/glob"
"github.com/puzpuzpuz/xsync/v4"
gperr "github.com/yusing/goutils/errs"
)
@@ -13,6 +14,8 @@ type (
MatcherType string
)
var matcherCache = xsync.NewMap[string, Matcher]() // map[string]Matcher
const (
MatcherTypeString MatcherType = "string"
MatcherTypeGlob MatcherType = "glob"
@@ -59,7 +62,12 @@ func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Erro
}
func ParseMatcher(expr string) (Matcher, gperr.Error) {
if cached, ok := matcherCache.Load(expr); ok {
return cached, nil
}
negate := false
origExpr := expr
if strings.HasPrefix(expr, "!") {
negate = true
expr = expr[1:]
@@ -72,11 +80,23 @@ func ParseMatcher(expr string) (Matcher, gperr.Error) {
switch t {
case MatcherTypeString:
return StringMatcher(expr, negate)
m, err := StringMatcher(expr, negate)
if err == nil {
matcherCache.Store(origExpr, m)
}
return m, err
case MatcherTypeGlob:
return GlobMatcher(expr, negate)
m, err := GlobMatcher(expr, negate)
if err == nil {
matcherCache.Store(origExpr, m)
}
return m, err
case MatcherTypeRegex:
return RegexMatcher(expr, negate)
m, err := RegexMatcher(expr, negate)
if err == nil {
matcherCache.Store(origExpr, m)
}
return m, err
}
// won't reach here
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)

View File

@@ -12,19 +12,19 @@ import (
)
type RuleOn struct {
raw string
checker Checker
isResponseChecker bool
}
func (on *RuleOn) IsResponseChecker() bool {
return on.isResponseChecker
raw string
checker Checker
phase PhaseFlag
}
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
return on.checker.Check(w, r)
if on.checker == nil {
return true
}
return on.checker.Check(httputils.GetInitResponseModifier(w), r)
}
// on request
const (
OnDefault = "default"
OnHeader = "header"
@@ -39,35 +39,36 @@ const (
OnRemote = "remote"
OnBasicAuth = "basic_auth"
OnRoute = "route"
)
// on response
// on response
const (
OnResponseHeader = "resp_header"
OnStatus = "status"
)
var checkers = map[string]struct {
help Help
validate ValidateFunc
builder func(args any) CheckFunc
isResponseChecker bool
help Help
validate ValidateFunc
builder func(args any) CheckFunc
}{
OnDefault: {
help: Help{
command: OnDefault,
description: makeLines(
"The default rule is matched when no other rules are matched.",
"Select the default (baseline) rule.",
),
args: map[string]string{},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 0 {
return nil, ErrExpectNoArg
return phase, nil, ErrExpectNoArg
}
//nolint:nilnil
return nil, nil
return phase, nil, nil
},
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
builder: func(args any) CheckFunc {
return func(w *httputils.ResponseModifier, r *http.Request) bool { return false }
}, // this should never be called
},
OnHeader: {
help: Help{
@@ -83,21 +84,23 @@ var checkers = map[string]struct {
"[value]": "the header value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return len(r.Header[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return slices.ContainsFunc(r.Header[k], matcher)
}
},
},
OnResponseHeader: {
isResponseChecker: true,
help: Help{
command: OnResponseHeader,
description: makeLines(
@@ -111,16 +114,20 @@ var checkers = map[string]struct {
"[value]": "the response header value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePost
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return len(w.Header()[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return slices.ContainsFunc(w.Header()[k], matcher)
}
},
},
@@ -138,16 +145,19 @@ var checkers = map[string]struct {
"[value]": "the query value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return len(w.SharedData().GetQueries(r)[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
}
},
},
@@ -165,12 +175,15 @@ var checkers = map[string]struct {
"[value]": "the cookie value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
cookies := httputils.GetSharedData(w).GetCookies(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
cookies := w.SharedData().GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k {
return true
@@ -179,8 +192,8 @@ var checkers = map[string]struct {
return false
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
cookies := httputils.GetSharedData(w).GetCookies(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
cookies := w.SharedData().GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k {
if matcher(cookie.Value) {
@@ -192,6 +205,7 @@ var checkers = map[string]struct {
}
},
},
//nolint:dupl
OnForm: {
help: Help{
command: OnForm,
@@ -206,15 +220,18 @@ var checkers = map[string]struct {
"[value]": "the form value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.FormValue(k) != ""
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(r.FormValue(k))
}
},
@@ -233,15 +250,18 @@ var checkers = map[string]struct {
"[value]": "the form value",
},
},
validate: toKVOptionalVMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = toKVOptionalVMatcher(args)
return
},
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.PostFormValue(k) != ""
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(r.PostFormValue(k))
}
},
@@ -250,32 +270,46 @@ var checkers = map[string]struct {
help: Help{
command: OnProto,
args: map[string]string{
"proto": "the http protocol (http, https, h3)",
"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
},
},
validate: func(args []string) (any, error) {
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
return phase, nil, ErrExpectOneArg
}
proto := args[0]
if proto != "http" && proto != "https" && proto != "h3" {
return nil, ErrInvalidArguments.Withf("proto: %q", proto)
switch proto {
case "http", "https", "h1", "h2", "h2c", "h3":
return phase, proto, nil
default:
return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
}
return proto, nil
},
builder: func(args any) CheckFunc {
proto := args.(string)
switch proto {
case "http":
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS == nil
}
case "https":
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS != nil
}
case "h1":
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS == nil && r.ProtoMajor == 1
}
case "h2":
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS != nil && r.ProtoMajor == 2
}
case "h2c":
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS == nil && r.ProtoMajor == 2
}
default: // h3
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.TLS != nil && r.ProtoMajor == 3
}
}
@@ -288,10 +322,13 @@ var checkers = map[string]struct {
"method": "the http method",
},
},
validate: validateMethod,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateMethod(args)
return
},
builder: func(args any) CheckFunc {
method := args.(string)
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return r.Method == method
}
},
@@ -310,10 +347,13 @@ var checkers = map[string]struct {
"host": "the host name",
},
},
validate: validateSingleMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateSingleMatcher(args)
return
},
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(r.Host)
}
},
@@ -332,10 +372,13 @@ var checkers = map[string]struct {
"path": "the request path",
},
},
validate: validateURLPathMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateURLPathMatcher(args)
return
},
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(w http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
@@ -351,22 +394,25 @@ var checkers = map[string]struct {
"ip|cidr": "the remote ip or cidr",
},
},
validate: validateCIDR,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateCIDR(args)
return
},
builder: func(args any) CheckFunc {
ipnet := args.(*net.IPNet)
// for /32 (IPv4) or /128 (IPv6), just compare the IP
if ones, bits := ipnet.Mask.Size(); ones == bits {
wantIP := ipnet.IP
return func(w http.ResponseWriter, r *http.Request) bool {
ip := httputils.GetSharedData(w).GetRemoteIP(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
ip := w.SharedData().GetRemoteIP(r)
if ip == nil {
return false
}
return ip.Equal(wantIP)
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
ip := httputils.GetSharedData(w).GetRemoteIP(r)
return func(w *httputils.ResponseModifier, r *http.Request) bool {
ip := w.SharedData().GetRemoteIP(r)
if ip == nil {
return false
}
@@ -382,11 +428,14 @@ var checkers = map[string]struct {
"password": "the password encrypted with bcrypt",
},
},
validate: validateUserBCryptPassword,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateUserBCryptPassword(args)
return
},
builder: func(args any) CheckFunc {
cred := args.(*HashedCrendentials)
return func(w http.ResponseWriter, r *http.Request) bool {
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return cred.Match(w.SharedData().GetBasicAuth(r))
}
},
},
@@ -403,16 +452,18 @@ var checkers = map[string]struct {
"route": "the route name",
},
},
validate: validateSingleMatcher,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
parsedArgs, err = validateSingleMatcher(args)
return
},
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(_ http.ResponseWriter, r *http.Request) bool {
return func(w *httputils.ResponseModifier, r *http.Request) bool {
return matcher(routes.TryGetUpstreamName(r))
}
},
},
OnStatus: {
isResponseChecker: true,
help: Help{
command: OnStatus,
description: makeLines(
@@ -429,16 +480,20 @@ var checkers = map[string]struct {
"status": "the status code range",
},
},
validate: validateStatusRange,
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
phase = PhasePost
parsedArgs, err = validateStatusRange(args)
return
},
builder: func(args any) CheckFunc {
beg, end := args.(*IntTuple).Unpack()
if beg == end {
return func(w http.ResponseWriter, _ *http.Request) bool {
return httputils.GetInitResponseModifier(w).StatusCode() == beg
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
return w.StatusCode() == beg
}
}
return func(w http.ResponseWriter, _ *http.Request) bool {
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
statusCode := w.StatusCode()
return statusCode >= beg && statusCode <= end
}
},
@@ -515,85 +570,119 @@ func splitPipe(s string) []string {
return []string{}
}
var result []string
var current strings.Builder
escaped := false
quote := rune(0)
result := make([]string, 0, 2)
quote := byte(0)
brackets := 0
start := 0
for _, r := range s {
if escaped {
current.WriteRune(r)
escaped = false
continue
}
switch r {
for i := 0; i < len(s); i++ {
switch s[i] {
case '\\':
escaped = true
current.WriteRune(r)
// Skip escaped character.
if i+1 < len(s) {
i++
}
case '"', '\'', '`':
if quote == 0 && brackets == 0 {
quote = r
} else if r == quote {
quote = s[i]
} else if s[i] == quote {
quote = 0
}
current.WriteRune(r)
case '(':
brackets++
current.WriteRune(r)
case ')':
if brackets > 0 {
brackets--
}
current.WriteRune(r)
case '|':
if quote == 0 && brackets == 0 {
// Found a pipe outside quotes/brackets, split here
result = append(result, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(r)
result = append(result, strings.TrimSpace(s[start:i]))
start = i + 1
}
default:
current.WriteRune(r)
}
}
// Add the last part
if current.Len() > 0 {
result = append(result, strings.TrimSpace(current.String()))
// drop trailing empty part.
if start < len(s) {
result = append(result, strings.TrimSpace(s[start:]))
}
return result
}
func forEachAndPart(s string, fn func(part string)) {
start := 0
for i := 0; i <= len(s); i++ {
if i < len(s) && andSeps[s[i]] == 0 {
continue
}
part := strings.TrimSpace(s[start:i])
if part != "" {
fn(part)
}
start = i + 1
}
}
func forEachPipePart(s string, fn func(part string)) {
quote := byte(0)
brackets := 0
start := 0
for i := 0; i < len(s); i++ {
switch s[i] {
case '\\':
if i+1 < len(s) {
i++
}
case '"', '\'', '`':
if quote == 0 && brackets == 0 {
quote = s[i]
} else if s[i] == quote {
quote = 0
}
case '(':
brackets++
case ')':
if brackets > 0 {
brackets--
}
case '|':
if quote == 0 && brackets == 0 {
fn(strings.TrimSpace(s[start:i]))
start = i + 1
}
}
}
if start < len(s) {
fn(strings.TrimSpace(s[start:]))
}
}
// Parse implements strutils.Parser.
func (on *RuleOn) Parse(v string) error {
on.raw = v
rules := splitAnd(v)
checkAnd := make(CheckMatchAll, 0, len(rules))
ruleCount := 0
forEachAndPart(v, func(_ string) {
ruleCount++
})
checkAnd := make(CheckMatchAll, 0, ruleCount)
errs := gperr.NewBuilder("rule.on syntax errors")
isResponseChecker := false
for i, rule := range rules {
if rule == "" {
continue
}
parsed, isResp, err := parseOn(rule)
i := 0
forEachAndPart(v, func(rule string) {
i++
parsed, phase, err := parseOn(rule)
if err != nil {
errs.AddSubjectf(err, "line %d", i+1)
continue
}
if isResp {
isResponseChecker = true
errs.AddSubjectf(err, "line %d", i)
return
}
on.phase |= phase
checkAnd = append(checkAnd, parsed)
}
})
on.checker = checkAnd
on.isResponseChecker = isResponseChecker
return errs.Error()
}
@@ -605,33 +694,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
return []byte(on.String()), nil
}
func parseOn(line string) (Checker, bool, error) {
ors := splitPipe(line)
if len(ors) > 1 {
func parseOn(line string) (Checker, PhaseFlag, error) {
orCount := 0
forEachPipePart(line, func(_ string) {
orCount++
})
if orCount > 1 {
var phase PhaseFlag
errs := gperr.NewBuilder("rule.on syntax errors")
checkOr := make(CheckMatchSingle, len(ors))
isResponseChecker := false
for i, or := range ors {
curCheckers, isResp, err := parseOn(or)
checkOr := make(CheckMatchSingle, orCount)
i := 0
forEachPipePart(line, func(or string) {
i++
checkFunc, req, err := parseOnAtom(or)
if err != nil {
errs.Add(err)
continue
errs.AddSubjectf(err, "or[%d]", i)
return
}
if isResp {
isResponseChecker = true
}
checkOr[i] = curCheckers.(CheckFunc)
}
checkOr[i-1] = checkFunc
phase |= req
})
if err := errs.Error(); err != nil {
return nil, false, err
return nil, phase, err
}
return checkOr, isResponseChecker, nil
return checkOr, phase, nil
}
return parseOnAtom(line)
}
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
var phase PhaseFlag
subject, args, err := parse(line)
if err != nil {
return nil, false, err
return nil, phase, err
}
negate := false
@@ -642,20 +738,21 @@ func parseOn(line string) (Checker, bool, error) {
checker, ok := checkers[subject]
if !ok {
return nil, false, ErrInvalidOnTarget.Subject(subject)
return nil, phase, ErrInvalidOnTarget.Subject(subject)
}
validArgs, err := checker.validate(args)
req, validArgs, err := checker.validate(args)
if err != nil {
return nil, false, gperr.Wrap(err).With(checker.help.Error())
return nil, phase, gperr.Wrap(err).With(checker.help.Error())
}
phase |= req
checkFunc := checker.builder(validArgs)
if negate {
origCheckFunc := checkFunc
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool {
return !origCheckFunc(w, r)
}
}
return checkFunc, checker.isResponseChecker, nil
return checkFunc, phase, nil
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/yusing/godoxy/internal/route"
"github.com/yusing/godoxy/internal/route/routes"
. "github.com/yusing/godoxy/internal/route/rules"
httputils "github.com/yusing/goutils/http"
expect "github.com/yusing/goutils/testing"
"golang.org/x/crypto/bcrypt"
)
@@ -386,7 +387,7 @@ func TestOnCorrectness(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
w := httputils.NewResponseModifier(httptest.NewRecorder())
var on RuleOn
err := on.Parse(tt.checker)
expect.NoError(t, err)

View File

@@ -1,8 +1,7 @@
package rules
import (
"bytes"
"fmt"
"strings"
"unicode"
"github.com/yusing/goutils/env"
@@ -25,6 +24,76 @@ var quoteChars = [256]bool{
'`': true,
}
func parseSimple(v string) (subject string, args []string, err error, ok bool) {
brackets := 0
for i := range len(v) {
switch v[i] {
case '\\', '$', '"', '\'', '`', '\t', '\r', '\n':
return "", nil, nil, false
case '(':
brackets++
case ')':
if brackets == 0 {
return "", nil, ErrUnterminatedBrackets, true
}
brackets--
}
}
if brackets != 0 {
return "", nil, ErrUnterminatedBrackets, true
}
i := 0
for i < len(v) && v[i] == ' ' {
i++
}
if i >= len(v) {
return "", nil, nil, true
}
start := i
for i < len(v) && v[i] != ' ' {
i++
}
subject = v[start:i]
if i >= len(v) {
return subject, nil, nil, true
}
argCount := 0
for j := i; j < len(v); {
for j < len(v) && v[j] == ' ' {
j++
}
if j >= len(v) {
break
}
argCount++
for j < len(v) && v[j] != ' ' {
j++
}
}
if argCount == 0 {
return subject, nil, nil, true
}
args = make([]string, 0, argCount)
for i < len(v) {
for i < len(v) && v[i] == ' ' {
i++
}
if i >= len(v) {
break
}
start = i
for i < len(v) && v[i] != ' ' {
i++
}
args = append(args, v[start:i])
}
return subject, args, nil, true
}
// parse expression to subject and args
// with support for quotes, escaped chars, and env substitution, e.g.
//
@@ -32,14 +101,21 @@ var quoteChars = [256]bool{
// error 403 Forbidden\ \"foo\"\ \"bar\".
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
func parse(v string) (subject string, args []string, err error) {
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
if subject, args, err, ok := parseSimple(v); ok {
return subject, args, err
}
buf := getStringBuffer(len(v))
args = make([]string, 0, 4)
escaped := false
quote := rune(0)
brackets := 0
var envVar bytes.Buffer
var missingEnvVars []string
var (
envVar strings.Builder
missingEnvVars []string
)
inEnvVar := false
expectingBrace := false
@@ -71,7 +147,8 @@ func parse(v string) (subject string, args []string, err error) {
if ch, ok := escapedChars[r]; ok {
buf.WriteRune(ch)
} else {
fmt.Fprintf(buf, `\%c`, r)
buf.WriteRune('\\')
buf.WriteRune(r)
}
escaped = false
continue

View File

@@ -4,6 +4,7 @@ import (
"strconv"
"testing"
gperr "github.com/yusing/goutils/errs"
expect "github.com/yusing/goutils/testing"
)
@@ -13,6 +14,7 @@ func TestParser(t *testing.T) {
input string
subject string
args []string
wantErr gperr.Error
}{
{
name: "basic",
@@ -90,6 +92,10 @@ func TestParser(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subject, args, err := parse(tt.input)
if tt.wantErr != nil {
expect.ErrorIs(t, tt.wantErr, err)
return
}
// t.Log(subject, args, err)
expect.NoError(t, err)
expect.Equal(t, subject, tt.subject)

View File

@@ -0,0 +1,29 @@
package rules
import "strings"
type PhaseFlag uint8
const (
PhaseNone PhaseFlag = 0
PhasePre PhaseFlag = 1 << (iota - 1)
PhasePost
)
func (phase PhaseFlag) IsPostRule() bool {
return phase&PhasePost != 0
}
func (phase PhaseFlag) String() string {
if phase == PhaseNone {
return "none"
}
var flags []string
if phase&PhasePre != 0 {
flags = append(flags, "PhasePre")
}
if phase&PhasePost != 0 {
flags = append(flags, "PhasePost")
}
return strings.Join(flags, ",")
}

View File

@@ -4,9 +4,16 @@ import (
"errors"
"fmt"
"net/http"
"reflect"
"slices"
"strings"
"unicode"
"github.com/goccy/go-yaml"
"github.com/quic-go/quic-go/http3"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/serialization"
gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http"
"golang.org/x/net/http2"
@@ -15,37 +22,36 @@ import (
type (
/*
Example:
Rules is a list of rules.
proxy.app1.rules: |
- name: default
do: |
rewrite / /index.html
serve /var/www/goaccess
- name: ws
on: |
header Connection Upgrade
header Upgrade websocket
do: bypass
Example:
proxy.app2.rules: |
- name: default
do: bypass
- name: block POST and PUT
on: method POST | method PUT
do: error 403 Forbidden
proxy.app1.rules: |
- name: default
do: |
rewrite / /index.html
serve /var/www/goaccess
- name: ws
on: |
header Connection Upgrade
header Upgrade websocket
do: bypass
proxy.app2.rules: |
- name: default
do: bypass
- name: block POST and PUT
on: method POST | method PUT
do: error 403 Forbidden
*/
//nolint:recvcheck
Rules []Rule
/*
Rule is a rule for a reverse proxy.
It do `Do` when `On` matches.
A rule can have multiple lines of on.
All lines of on must match,
but each line can have multiple checks that
one match means this line is matched.
*/
// Rule represents a reverse proxy rule.
// The `Do` field is executed when `On` matches.
//
// - A rule may have multiple lines in the `On` section.
// - All `On` lines must match for the rule to trigger.
// - Each line can have several checks—one match per line is enough for that line.
Rule struct {
Name string `json:"name"`
On RuleOn `json:"on" swaggertype:"string"`
@@ -53,103 +59,230 @@ type (
}
)
func (rule *Rule) IsResponseRule() bool {
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
func isDefaultRule(rule Rule) bool {
return rule.Name == "default" || rule.On.raw == OnDefault
}
func (rules Rules) Validate() error {
func (rules Rules) Validate() gperr.Error {
var defaultRulesFound []int
for i, rule := range rules {
if rule.Name == "default" || rule.On.raw == OnDefault {
for i := range rules {
rule := rules[i]
if isDefaultRule(rule) {
defaultRulesFound = append(defaultRulesFound, i)
}
if rules[i].Name == "" {
// set name to index if name is empty
rules[i].Name = fmt.Sprintf("rule[%d]", i)
}
}
if len(defaultRulesFound) > 1 {
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
}
for i := range rules {
r1 := rules[i]
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
continue
}
sig1, ok := matcherSignature(r1.On.raw)
if !ok {
continue
}
for j := i + 1; j < len(rules); j++ {
r2 := rules[j]
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
continue
}
sig2, ok := matcherSignature(r2.On.raw)
if !ok || sig1 != sig2 {
continue
}
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
}
}
return nil
}
func (rule Rule) doesTerminateInPre() bool {
for _, cmd := range rule.Do.pre {
handler, ok := cmd.(Handler)
if !ok {
continue
}
if handler.Terminates() {
return true
}
}
return false
}
func matcherSignature(raw string) (string, bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", false
}
andParts := splitAnd(raw)
if len(andParts) == 0 {
return "", false
}
canonAnd := make([]string, 0, len(andParts))
for _, andPart := range andParts {
orParts := splitPipe(andPart)
if len(orParts) == 0 {
continue
}
canonOr := make([]string, 0, len(orParts))
for _, atom := range orParts {
subject, args, err := parse(strings.TrimSpace(atom))
if err != nil || subject == "" {
return "", false
}
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
}
slices.Sort(canonOr)
canonOr = slices.Compact(canonOr)
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
}
slices.Sort(canonAnd)
canonAnd = slices.Compact(canonAnd)
if len(canonAnd) == 0 {
return "", false
}
return strings.Join(canonAnd, "&"), true
}
// Parse parses a rule configuration string.
// It first tries the block syntax (if the string contains a top-level '{'),
// then falls back to YAML syntax.
func (rules *Rules) Parse(config string) error {
config = strings.TrimSpace(config)
if config == "" {
return nil
}
// Prefer block syntax if it looks like block syntax.
if hasTopLevelLBrace(config) {
blockRules, err := parseBlockRules(config)
if err == nil {
*rules = blockRules
return nil
}
// Fall through to YAML (backward compatibility).
}
// YAML fallback
var anySlice []any
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
if yamlErr == nil {
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
}
// If YAML fails and we didn't try block syntax yet, try it now.
blockRules, err := parseBlockRules(config)
if err == nil {
*rules = blockRules
return nil
}
return err
}
// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments.
// Used to decide whether to prioritize the block syntax.
func hasTopLevelLBrace(s string) bool {
quote := rune(0)
inLine := false
inBlock := false
for i := 0; i < len(s); i++ {
c := s[i]
if inLine {
if c == '\n' {
inLine = false
}
continue
}
if inBlock {
if c == '*' && i+1 < len(s) && s[i+1] == '/' {
inBlock = false
i++
}
continue
}
if quote != 0 {
if quote != '`' && c == '\\' && i+1 < len(s) {
i++
continue
}
if rune(c) == quote {
quote = 0
}
continue
}
switch c {
case '\'', '"', '`':
quote = rune(c)
continue
case '{':
return true
case '#':
inLine = true
continue
case '/':
if i+1 < len(s) && s[i+1] == '/' {
inLine = true
i++
continue
}
if i+1 < len(s) && s[i+1] == '*' {
inBlock = true
i++
continue
}
default:
if unicode.IsSpace(rune(c)) {
continue
}
}
}
return false
}
// BuildHandler returns a http.HandlerFunc that implements the rules.
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
if len(rules) == 0 {
return up
}
defaultRule := Rule{
Name: "default",
Do: Command{
raw: "pass",
exec: BypassCommand{},
},
}
var defaultRule *Rule
var nonDefaultRules Rules
hasDefaultRule := false
for i, rule := range rules {
if rule.Name == "default" || rule.On.raw == OnDefault {
defaultRule = rule
hasDefaultRule = true
for _, rule := range rules {
if isDefaultRule(rule) {
r := rule
defaultRule = &r
} else {
// set name to index if name is empty
if rule.Name == "" {
rule.Name = fmt.Sprintf("rule[%d]", i)
}
nonDefaultRules = append(nonDefaultRules, rule)
}
}
if len(nonDefaultRules) == 0 {
if defaultRule.Do.isBypass() {
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
return up
}
if defaultRule.IsResponseRule() {
return func(w http.ResponseWriter, r *http.Request) {
rm := httputils.NewResponseModifier(w)
defer func() {
if _, err := rm.FlushRelease(); err != nil {
logError(err, r)
}
}()
w = rm
up(w, r)
err := defaultRule.Do.exec.Handle(w, r)
if err != nil && !errors.Is(err, errTerminated) {
appendRuleError(rm, &defaultRule, err)
}
}
}
return func(w http.ResponseWriter, r *http.Request) {
rm := httputils.NewResponseModifier(w)
defer func() {
if _, err := rm.FlushRelease(); err != nil {
logError(err, r)
}
}()
w = rm
err := defaultRule.Do.exec.Handle(w, r)
if err == nil {
up(w, r)
return
}
if !errors.Is(err, errTerminated) {
appendRuleError(rm, &defaultRule, err)
}
}
}
preRules := make(Rules, 0, len(nonDefaultRules))
postRules := make(Rules, 0, len(nonDefaultRules))
for _, rule := range nonDefaultRules {
if rule.IsResponseRule() {
postRules = append(postRules, rule)
} else {
preRules = append(preRules, rule)
}
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
return cmd.pre.ServeHTTP(w, r, up)
}
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
return cmd.post.ServeHTTP(w, r, up)
}
return func(w http.ResponseWriter, r *http.Request) {
rm := httputils.NewResponseModifier(w)
@@ -159,104 +292,84 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
}
}()
w = rm
var hasError bool
shouldCallUpstream := true
preMatched := false
preRules := make(Rules, 0, len(nonDefaultRules)+1)
if defaultRule != nil {
preRules = append(preRules, *defaultRule)
}
preRules = append(preRules, nonDefaultRules...)
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
if defaultRule.Do.isBypass() {
// continue to upstream
} else {
err := defaultRule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
appendRuleError(rm, &defaultRule, err)
}
shouldCallUpstream = false
executedPre := make([]bool, len(preRules))
terminatedInPre := make([]bool, len(preRules))
preTerminated := false
for i, rule := range preRules {
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
continue
}
if preTerminated {
// Preserve post-only commands (e.g. logging) even after
// pre-phase termination.
if len(rule.Do.pre) == 0 {
executedPre[i] = true
}
continue
}
}
if shouldCallUpstream {
for _, rule := range preRules {
if rule.Check(w, r) {
preMatched = true
if rule.Do.isBypass() {
break // post rules should still execute
}
err := rule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
appendRuleError(rm, &rule, err)
}
shouldCallUpstream = false
break
}
executedPre[i] = true
if err := execPreCommand(rule.Do, rm, r); err != nil {
if errors.Is(err, errTerminateRule) {
terminatedInPre[i] = true
preTerminated = true
continue
}
logError(err, r)
hasError = true
}
}
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
if defaultRule.Do.isBypass() {
// continue to upstream
} else {
err := defaultRule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
appendRuleError(rm, &defaultRule, err)
return
}
shouldCallUpstream = false
if !rm.HasStatus() {
if hasError {
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
} else { // call upstream if no WriteHeader or Write was called and no error occurred
up(rm, r)
}
}
// Run post commands for rules that actually executed in pre phase,
// unless that same rule terminated in pre phase.
for i, rule := range preRules {
if !executedPre[i] || terminatedInPre[i] {
continue
}
if err := execPostCommand(rule.Do, rm, r); err != nil {
if errors.Is(err, errTerminateRule) {
continue
}
logError(err, r)
}
}
if shouldCallUpstream {
up(w, r)
}
// if no post rules, we are done here
if len(postRules) == 0 && !isDefaultRulePost {
return
}
for _, rule := range postRules {
if rule.Check(w, r) {
err := rule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
appendRuleError(rm, &rule, err)
}
return
// Run true post-matcher rules after response is available.
for _, rule := range nonDefaultRules {
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
continue
}
// Post-rule matchers are only evaluated after upstream, so commands parsed
// as "pre" for requirement purposes still need to run in this phase.
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
if errors.Is(err, errTerminateRule) {
continue
}
logError(err, r)
}
if err := execPostCommand(rule.Do, rm, r); err != nil {
if errors.Is(err, errTerminateRule) {
continue
}
logError(err, r)
}
}
if isDefaultRulePost {
err := defaultRule.Handle(w, r)
if err != nil && !errors.Is(err, errTerminated) {
appendRuleError(rm, &defaultRule, err)
}
}
}
}
func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) {
// rm.AppendError("rule: %s, error: %w", rule.Name, err)
}
func isTerminatingHandler(handler CommandHandler) bool {
switch h := handler.(type) {
case TerminatingCommand:
return true
case Commands:
if len(h) == 0 {
return false
}
return isTerminatingHandler(h[len(h)-1])
default:
return false
}
}
@@ -264,34 +377,30 @@ func (rule *Rule) String() string {
return rule.Name
}
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
if rule.On.checker == nil {
return true
}
v := rule.On.checker.Check(w, r)
return v
}
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
return rule.Do.exec.Handle(w, r)
return rule.On.Check(w, r)
}
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
var errStreamClosed error
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
var errClientDisconnected error
func logError(err error, r *http.Request) {
if errors.Is(err, errStreamClosed) {
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
return
}
var h2Err http2.StreamError
if errors.As(err, &h2Err) {
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
// ignore these errors
if h2Err.Code == http2.ErrCodeStreamClosed {
return
}
}
var h3Err *http3.Error
if errors.As(err, &h3Err) {
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
// ignore these errors
switch h3Err.ErrorCode {
case

View File

@@ -19,28 +19,71 @@ func TestRulesValidate(t *testing.T) {
{
name: "no default rule",
rules: `
- name: rule1
on: header Host example.com
do: pass
`,
header Host example.com {
pass
}`,
},
{
name: "multiple default rules",
rules: `
- name: default
do: pass
- name: rule1
on: default
do: pass
`,
default {
pass
}
default {
pass
}`,
want: ErrMultipleDefaultRules,
},
{
name: "multiple responses on same condition",
rules: `
header Host example.com {
error 404 "not found"
}
header Host example.com {
error 403 "forbidden"
}
`,
want: ErrDeadRule,
},
{
name: "same condition different formatting error then proxy",
rules: `
header Host example.com & method GET {
error 404 "not found"
}
method GET
header Host example.com {
proxy http://127.0.0.1:8080
}
`,
want: ErrDeadRule,
},
{
name: "same condition with non terminating first rule",
rules: `
header Host example.com {
set resp_header X-Test first
}
header Host example.com {
error 403 "forbidden"
}
`,
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var rules Rules
convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules))
require.True(t, convertible)
require.NoError(t, err)
err = rules.Validate()
if tt.want == nil {
assert.NoError(t, err)
@@ -50,3 +93,38 @@ func TestRulesValidate(t *testing.T) {
})
}
}
func TestHasTopLevelLBrace(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "escaped quote inside double quoted string",
in: `"test\"more{"`,
want: false,
},
{
name: "escaped quote inside single quoted string",
in: "'test\\'more{'",
want: false,
},
{
name: "top-level brace outside quoted string",
in: `"test\"more" {`,
want: true,
},
{
name: "backtick keeps existing behavior",
in: "`test\\`more{`",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, hasTopLevelLBrace(tt.in))
})
}
}

View File

@@ -0,0 +1,273 @@
package rules
import (
"strings"
"unicode"
gperr "github.com/yusing/goutils/errs"
)
// Tokenizer provides utilities for parsing rule syntax with proper handling
// of quotes, comments, and env vars.
//
// This is intentionally reusable by both the top-level rule block parser and
// the nested do-block parser.
type Tokenizer struct {
src string
length int
}
// newTokenizer creates a tokenizer for the given source.
func newTokenizer(src string) Tokenizer {
return Tokenizer{src: src, length: len(src)}
}
// skipComments skips whitespace, line comments, and block comments.
// It returns the new position and an error if a block comment is unterminated.
func (t *Tokenizer) skipComments(pos int, atLineStart bool, prevIsSpace bool) (int, gperr.Error) {
for pos < t.length {
c := t.src[pos]
// Skip whitespace
if unicode.IsSpace(rune(c)) {
pos++
atLineStart = false
prevIsSpace = true
continue
}
// Check for line comment: // or #
if c == '/' {
if pos+1 < t.length && t.src[pos+1] == '/' {
// Skip to end of line
for pos < t.length && t.src[pos] != '\n' {
pos++
}
atLineStart = true
prevIsSpace = true
continue
}
}
if c == '#' && (atLineStart || prevIsSpace) {
// Skip to end of line
for pos < t.length && t.src[pos] != '\n' {
pos++
}
atLineStart = true
prevIsSpace = true
continue
}
// Check for block comment: /*
if c == '/' && pos+1 < t.length && t.src[pos+1] == '*' {
pos += 2
closed := false
for pos+1 < t.length {
if t.src[pos] == '*' && t.src[pos+1] == '/' {
pos += 2
closed = true
break
}
pos++
}
if !closed {
return 0, ErrInvalidBlockSyntax.Withf("unterminated block comment")
}
atLineStart = false
prevIsSpace = true
continue
}
break
}
return pos, nil
}
// scanToBrace scans from pos until it finds '{' outside quotes, or returns an error.
func (t *Tokenizer) scanToBrace(pos int) (int, gperr.Error) {
quote := rune(0)
for pos < t.length {
c := rune(t.src[pos])
if quote != 0 {
if c == quote {
quote = 0
}
pos++
continue
}
if c == '"' || c == '\'' || c == '`' {
quote = c
pos++
continue
}
if c == '{' {
return pos, nil
}
if c == '}' {
return 0, ErrInvalidBlockSyntax.Withf("unmatched '}' in block header")
}
pos++
}
return 0, ErrInvalidBlockSyntax.Withf("expected '{' after block header")
}
// findMatchingBrace finds the matching '}' for a '{' starting at startPos.
// It respects quotes/backticks and ${...} env vars.
func (t *Tokenizer) findMatchingBrace(startPos int) (int, gperr.Error) {
pos := startPos
braceDepth := 1
quote := rune(0)
inLine := false
inBlock := false
atLineStart := true
prevIsSpace := true
for pos < t.length {
c := rune(t.src[pos])
if inLine {
if c == '\n' {
inLine = false
atLineStart = true
prevIsSpace = true
}
pos++
continue
}
if inBlock {
if c == '*' && pos+1 < t.length && t.src[pos+1] == '/' {
pos += 2
inBlock = false
continue
}
if c == '\n' {
atLineStart = true
prevIsSpace = true
}
pos++
continue
}
if quote != 0 {
if c == quote {
quote = 0
}
if c == '\n' {
atLineStart = true
prevIsSpace = true
} else {
atLineStart = false
prevIsSpace = unicode.IsSpace(c)
}
pos++
continue
}
if c == '"' || c == '\'' || c == '`' {
quote = c
atLineStart = false
prevIsSpace = false
pos++
continue
}
// Comments (only outside quotes) at token boundary
if c == '#' && (atLineStart || prevIsSpace) {
inLine = true
pos++
continue
}
if c == '/' && pos+1 < t.length {
n := rune(t.src[pos+1])
if (atLineStart || prevIsSpace) && n == '/' {
inLine = true
pos += 2
continue
}
if (atLineStart || prevIsSpace) && n == '*' {
inBlock = true
pos += 2
continue
}
}
if c == '$' && pos+1 < t.length && t.src[pos+1] == '{' {
// Skip env var ${...}
pos += 2
envBraceDepth := 1
envQuote := rune(0)
for pos < t.length {
ec := rune(t.src[pos])
if envQuote != 0 {
if ec == envQuote {
envQuote = 0
}
pos++
continue
}
if ec == '"' || ec == '\'' || ec == '`' {
envQuote = ec
pos++
continue
}
if ec == '{' {
envBraceDepth++
} else if ec == '}' {
envBraceDepth--
if envBraceDepth == 0 {
pos++ // Move past the closing '}'
break
}
}
pos++
}
continue
}
switch c {
case '{':
braceDepth++
case '}':
braceDepth--
if braceDepth == 0 {
return pos, nil
}
}
if c == '\n' {
atLineStart = true
prevIsSpace = true
} else {
atLineStart = false
prevIsSpace = unicode.IsSpace(c)
}
pos++
}
return 0, ErrInvalidBlockSyntax.Withf("unmatched '{' at position %d", startPos)
}
// parseHeaderToBrace parses an expression/header starting at start and returns:
// - header: trimmed src[start:bracePos]
// - bracePos: position of '{' (outside quotes/backticks)
func parseHeaderToBrace(src string, start int) (header string, bracePos int, err gperr.Error) {
t := newTokenizer(src)
bracePos, err = t.scanToBrace(start)
if err != nil {
return "", 0, err
}
return strings.TrimSpace(src[start:bracePos]), bracePos, nil
}
// findMatchingBrace finds the matching '}' for a '{' at position startPos.
// It respects quotes/backticks and ${...} env vars in do_body.
func findMatchingBrace(src string, pos *int, startPos int) (int, gperr.Error) {
t := newTokenizer(src)
endPos, err := t.findMatchingBrace(startPos)
if err != nil {
return 0, err
}
*pos = endPos + 1
return endPos, nil
}

View File

@@ -0,0 +1,39 @@
package rules
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestTokenizer_skipComments_UnterminatedBlockComment(t *testing.T) {
tok := newTokenizer("/* unterminated")
pos, err := tok.skipComments(0, true, true)
require.Error(t, err)
require.Equal(t, 0, pos)
}
func TestTokenizer_skipComments_SkipsLineAndBlockComments(t *testing.T) {
src := " // line\n /* block */\n # hash\n default"
tok := newTokenizer(src)
pos, err := tok.skipComments(0, true, true)
require.NoError(t, err)
require.Equal(t, strings.Index(src, "default"), pos)
}
func TestTokenizer_scanToBrace_IgnoresQuotedBraces(t *testing.T) {
src := "cond \"{\" {" // the brace inside quotes must be ignored
tok := newTokenizer(src)
bracePos, err := tok.scanToBrace(0)
require.NoError(t, err)
require.Equal(t, strings.LastIndex(src, "{"), bracePos)
}
func TestTokenizer_findMatchingBrace_IgnoresQuotedClosingBrace(t *testing.T) {
src := `{ "}" }`
tok := newTokenizer(src)
endPos, err := tok.findMatchingBrace(1) // body starts after the first '{'
require.NoError(t, err)
require.Equal(t, strings.LastIndex(src, "}"), endPos)
}

View File

@@ -4,13 +4,13 @@ import (
"io"
"net/http"
"strings"
"unsafe"
httputils "github.com/yusing/goutils/http"
)
type templateString struct {
string
isTemplate bool
}
@@ -23,32 +23,28 @@ func (tmpl *keyValueTemplate) Unpack() (string, templateString) {
return tmpl.key, tmpl.tmpl
}
func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error {
func (tmpl *templateString) ExpandVars(w *httputils.ResponseModifier, req *http.Request, dst io.Writer) (phase PhaseFlag, err error) {
if !tmpl.isTemplate {
_, err := dstW.Write(strtobNoCopy(tmpl.string))
return err
_, err := asBytesBufferLike(dst).WriteString(tmpl.string)
return PhaseNone, err
}
return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW)
return ExpandVars(w, req, tmpl.string, dst)
}
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
func (tmpl *templateString) ExpandVarsToString(w *httputils.ResponseModifier, r *http.Request) (string, PhaseFlag, error) {
if !tmpl.isTemplate {
return tmpl.string, nil
return tmpl.string, PhaseNone, nil
}
var buf strings.Builder
err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf)
phase, err := tmpl.ExpandVars(w, r, &buf)
if err != nil {
return "", err
return "", PhaseNone, err
}
return buf.String(), nil
return buf.String(), phase, nil
}
func (tmpl *templateString) Len() int {
return len(tmpl.string)
}
func strtobNoCopy(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}

View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
nettypes "github.com/yusing/godoxy/internal/net/types"
gperr "github.com/yusing/goutils/errs"
@@ -16,7 +17,7 @@ import (
)
type (
ValidateFunc func(args []string) (any, error)
ValidateFunc func(args []string) (phase PhaseFlag, parsedArgs any, err error)
Tuple[T1, T2 any] struct {
First T1
Second T2
@@ -37,6 +38,8 @@ type (
MapValueMatcher = Tuple[string, Matcher]
)
var cidrCache = xsync.NewMap[string, *net.IPNet]()
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
return t.First, t.Second
}
@@ -62,7 +65,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string {
}
// validateSingleMatcher returns Matcher with the matcher validated.
func validateSingleMatcher(args []string) (any, error) {
func validateSingleMatcher(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
@@ -70,7 +73,7 @@ func validateSingleMatcher(args []string) (any, error) {
}
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
func toKVOptionalVMatcher(args []string) (any, error) {
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
switch len(args) {
case 1:
return &MapValueMatcher{args[0], nil}, nil
@@ -85,20 +88,8 @@ func toKVOptionalVMatcher(args []string) (any, error) {
}
}
func toKeyValueTemplate(args []string) (any, error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
isTemplate, err := validateTemplate(args[1], false)
if err != nil {
return nil, err
}
return &keyValueTemplate{args[0], isTemplate}, nil
}
// validateURL returns types.URL with the URL validated.
func validateURL(args []string) (any, error) {
func validateURL(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
@@ -116,22 +107,27 @@ func validateURL(args []string) (any, error) {
}
// validateCIDR returns types.CIDR with the CIDR validated.
func validateCIDR(args []string) (any, error) {
func validateCIDR(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
if !strings.Contains(args[0], "/") {
args[0] += "/32"
cidr := args[0]
if !strings.Contains(cidr, "/") {
cidr += "/32"
}
_, ipnet, err := net.ParseCIDR(args[0])
if cached, ok := cidrCache.Load(cidr); ok {
return cached, nil
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
cidrCache.Store(cidr, ipnet)
return ipnet, nil
}
// validateURLPath returns string with the path validated.
func validateURLPath(args []string) (any, error) {
func validateURLPath(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
@@ -148,7 +144,7 @@ func validateURLPath(args []string) (any, error) {
return p, nil
}
func validateURLPathMatcher(args []string) (any, error) {
func validateURLPathMatcher(args []string) (any, gperr.Error) {
path, err := validateURLPath(args)
if err != nil {
return nil, err
@@ -157,7 +153,7 @@ func validateURLPathMatcher(args []string) (any, error) {
}
// validateFSPath returns string with the path validated.
func validateFSPath(args []string) (any, error) {
func validateFSPath(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
@@ -169,7 +165,7 @@ func validateFSPath(args []string) (any, error) {
}
// validateMethod returns string with the method validated.
func validateMethod(args []string) (any, error) {
func validateMethod(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
@@ -200,7 +196,7 @@ func validateStatusCode(status string) (int, error) {
// - 3xx
// - 4xx
// - 5xx
func validateStatusRange(args []string) (any, error) {
func validateStatusRange(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
@@ -232,7 +228,7 @@ func validateStatusRange(args []string) (any, error) {
}
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
func validateUserBCryptPassword(args []string) (any, error) {
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
@@ -240,64 +236,93 @@ func validateUserBCryptPassword(args []string) (any, error) {
}
// validateModField returns CommandHandler with the field validated.
func validateModField(mod FieldModifier, args []string) (CommandHandler, error) {
func validateModField(mod FieldModifier, args []string) (phase PhaseFlag, handler HandlerFunc, err error) {
if len(args) == 0 {
return nil, ErrExpectTwoOrThreeArgs
return phase, nil, ErrExpectTwoOrThreeArgs
}
setField, ok := modFields[args[0]]
if !ok {
return nil, ErrUnknownModField.Subject(args[0])
return phase, nil, ErrUnknownModField.Subject(args[0])
}
if mod == ModFieldRemove {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
return phase, nil, ErrExpectTwoArgs
}
// setField expect validateStrTuple
args = append(args, "")
}
validArgs, err := setField.validate(args[1:])
phase, validArgs, err := setField.validate(args[1:])
if err != nil {
return nil, gperr.Wrap(err).With(setField.help.Error())
return phase, nil, gperr.Wrap(err).With(setField.help.Error())
}
modder := setField.builder(validArgs)
switch mod {
case ModFieldAdd:
add := modder.add
if add == nil {
return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod)
return phase, nil, ErrInvalidArguments.Withf("add is not supported for field %s", args[0])
}
return add, nil
return phase, add, nil
case ModFieldRemove:
remove := modder.remove
if remove == nil {
return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod)
return phase, nil, ErrInvalidArguments.Withf("remove is not supported for field %s", args[0])
}
return remove, nil
return phase, remove, nil
}
set := modder.set
if set == nil {
return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod)
return phase, nil, ErrInvalidArguments.Withf("set is not supported for field %s", args[0])
}
return set, nil
return phase, set, nil
}
func validateTemplate(tmplStr string, newline bool) (templateString, error) {
func validateTemplate(tmplStr string, newline bool) (phase PhaseFlag, tmpl templateString, err error) {
if newline && !strings.HasSuffix(tmplStr, "\n") {
tmplStr += "\n"
}
if !NeedExpandVars(tmplStr) {
return templateString{tmplStr, false}, nil
return phase, templateString{tmplStr, false}, nil
}
err := ValidateVars(tmplStr)
phase, err = ValidateVars(tmplStr)
if err != nil {
return templateString{}, err
return phase, templateString{}, gperr.Wrap(err)
}
return templateString{tmplStr, true}, nil
return phase, templateString{tmplStr, true}, nil
}
func validateLevel(level string) (zerolog.Level, error) {
func validatePreRequestKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 2 {
return phase, nil, ErrExpectTwoArgs
}
phase = PhasePre
tmplReq, tmpl, err := validateTemplate(args[1], false)
if err != nil {
return phase, nil, err
}
phase |= tmplReq
return phase, &keyValueTemplate{args[0], tmpl}, nil
}
func validatePostResponseKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
if len(args) != 2 {
return phase, nil, ErrExpectTwoArgs
}
phase = PhasePost
tmplReq, tmpl, err := validateTemplate(args[1], false)
if err != nil {
return phase, nil, err
}
phase |= tmplReq
return phase, &keyValueTemplate{args[0], tmpl}, nil
}
func validateLevel(level string) (zerolog.Level, gperr.Error) {
l, err := zerolog.ParseLevel(level)
if err != nil {
return zerolog.NoLevel, ErrInvalidArguments.With(err)

View File

@@ -23,7 +23,7 @@ func BenchmarkExpandVars(b *testing.B) {
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
for b.Loop() {
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
if err != nil {
b.Fatal(err)
}

View File

@@ -1,15 +1,16 @@
package rules
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"unsafe"
httputils "github.com/yusing/goutils/http"
ioutils "github.com/yusing/goutils/io"
)
// TODO: remove middleware/vars.go and use this instead
@@ -45,41 +46,84 @@ var (
}
)
type bytesBufferLike interface {
io.Writer
WriteByte(c byte) error
WriteString(s string) (int, error)
}
type bytesBufferAdapter struct {
io.Writer
}
func (b bytesBufferAdapter) WriteByte(c byte) error {
buf := [1]byte{c}
_, err := b.Write(buf[:])
return err
}
func (b bytesBufferAdapter) WriteString(s string) (int, error) {
return b.Write(unsafe.Slice(unsafe.StringData(s), len(s))) // avoid copy
}
func asBytesBufferLike(w io.Writer) bytesBufferLike {
switch w := w.(type) {
case *bytes.Buffer:
return w
case bytesBufferLike:
return w
default:
return bytesBufferAdapter{w}
}
}
// ValidateVars validates the variables in the given string.
// It returns ErrUnexpectedVar if any invalid variable is found.
func ValidateVars(s string) error {
// It returns the phase that the variables require and an error if any error occurs.
//
// Possible errors:
// - ErrUnexpectedVar: if any invalid variable is found
// - ErrUnterminatedEnvVar: missing closing }
// - ErrUnterminatedQuotes: missing closing " or ' or `
// - ErrUnterminatedParenthesis: missing closing )
func ValidateVars(s string) (phase PhaseFlag, err error) {
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
}
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
dst := ioutils.NewBufferedWriter(dstW, 1024)
defer dst.Close()
// ExpandVars expands the variables in the given string and writes the result to the given writer.
// It returns the phase that the variables require and an error if any error occurs.
//
// Possible errors:
// - ErrUnexpectedVar: if any invalid variable is found
// - ErrUnterminatedEnvVar: missing closing }
// - ErrUnterminatedQuotes: missing closing " or ' or `
// - ErrUnterminatedParenthesis: missing closing )
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) (phase PhaseFlag, err error) {
dst := asBytesBufferLike(dstW)
for i := 0; i < len(src); i++ {
ch := src[i]
if ch != '$' {
if err := dst.WriteByte(ch); err != nil {
return err
if err = dst.WriteByte(ch); err != nil {
return phase, err
}
continue
}
// Look ahead
if i+1 >= len(src) {
return ErrUnterminatedEnvVar
return phase, ErrUnterminatedEnvVar
}
j := i + 1
switch src[j] {
case '$': // $$ -> literal '$'
if err := dst.WriteByte('$'); err != nil {
return err
return phase, err
}
i = j
continue
case '{': // ${...} pass through as-is
if _, err := dst.WriteString("${"); err != nil {
return err
return phase, err
}
i = j // we've consumed the '{' too
continue
@@ -102,24 +146,26 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
if getter, ok := dynamicVarSubsMap[name]; ok {
// Function-like variables
isStatic = false
phase |= getter.phase
args, nextIdx, err := extractArgs(src, j, name)
if err != nil {
return err
return phase, err
}
i = nextIdx
actual, err = getter(args, w, req)
actual, err = getter.get(args, w, req)
if err != nil {
return err
return phase, err
}
} else if getter, ok := staticReqVarSubsMap[name]; ok {
} else if getter, ok := staticReqVarSubsMap[name]; ok { // always available
actual = getter(req)
} else if getter, ok := staticRespVarSubsMap[name]; ok {
} else if getter, ok := staticRespVarSubsMap[name]; ok { // post response
actual = getter(w)
phase |= PhasePost
} else {
return ErrUnexpectedVar.Subject(name)
return phase, ErrUnexpectedVar.Subject(name)
}
if _, err := dst.WriteString(actual); err != nil {
return err
return phase, err
}
if isStatic {
i = k - 1
@@ -128,10 +174,10 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
}
// No valid construct after '$'
return ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
return phase, ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
}
return nil
return phase, nil
}
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {

View File

@@ -11,58 +11,88 @@ import (
var (
VarHeader = "header"
VarResponseHeader = "resp_header"
VarCookie = "cookie"
VarQuery = "arg"
VarForm = "form"
VarPostForm = "postform"
)
type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
type dynamicVarGetter struct {
phase PhaseFlag
get func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
}
var dynamicVarSubsMap = map[string]dynamicVarGetter{
VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(req.Header, key, index)
},
VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(w.Header(), key, index)
},
VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
},
VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
if req.Form == nil {
if err := req.ParseForm(); err != nil {
VarHeader: {
phase: PhaseNone,
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
}
return getValueByKeyAtIndex(req.Form, key, index)
return getValueByKeyAtIndex(req.Header, key, index)
},
},
VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
if req.Form == nil {
if err := req.ParseForm(); err != nil {
VarResponseHeader: {
phase: PhasePost,
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
}
return getValueByKeyAtIndex(req.PostForm, key, index)
return getValueByKeyAtIndex(w.Header(), key, index)
},
},
VarCookie: {
phase: PhaseNone,
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
sharedData := httputils.GetSharedData(w)
return getValueByKeyAtIndex(sharedData.GetCookiesMap(req), key, index)
},
},
VarQuery: {
phase: PhaseNone,
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
},
},
VarForm: {
phase: PhaseNone,
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
if req.Form == nil {
if err := req.ParseForm(); err != nil {
return "", err
}
}
return getValueByKeyAtIndex(req.Form, key, index)
},
},
VarPostForm: {
phase: PhaseNone,
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
key, index, err := getKeyAndIndex(args)
if err != nil {
return "", err
}
if req.Form == nil {
if err := req.ParseForm(); err != nil {
return "", err
}
}
return getValueByKeyAtIndex(req.PostForm, key, index)
},
},
}

View File

@@ -232,7 +232,7 @@ func TestExpandVars(t *testing.T) {
{
name: "req_method",
input: "$req_method",
want: "POST",
want: http.MethodPost,
},
{
name: "req_path",
@@ -484,7 +484,7 @@ func TestExpandVars(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
_, err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
if tt.wantErr {
require.Error(t, err)
@@ -506,7 +506,7 @@ func TestExpandVars_Integration(t *testing.T) {
testResponseModifier.WriteHeader(http.StatusOK)
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest,
_, err := ExpandVars(testResponseModifier, testRequest,
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
&out)
@@ -520,7 +520,7 @@ func TestExpandVars_Integration(t *testing.T) {
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest,
_, err := ExpandVars(testResponseModifier, testRequest,
"Query: $arg(q), Page: $arg(page)",
&out)
@@ -537,7 +537,7 @@ func TestExpandVars_Integration(t *testing.T) {
testResponseModifier.WriteHeader(http.StatusOK)
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest,
_, err := ExpandVars(testResponseModifier, testRequest,
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
&out)
@@ -560,7 +560,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
{
name: "https scheme",
request: &http.Request{
Method: "GET",
Method: http.MethodGet,
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
TLS: &tls.ConnectionState{}, // Simulate TLS connection
},
@@ -572,7 +572,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
var out strings.Builder
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
_, err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
require.NoError(t, err)
require.Equal(t, tt.expected, out.String())
})
@@ -598,7 +598,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
for _, varExpr := range upstreamVars {
t.Run(varExpr, func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
_, err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
// Should not error, may return empty string
require.NoError(t, err)
})
@@ -614,16 +614,16 @@ func TestExpandVars_NoHostPort(t *testing.T) {
t.Run("req_host without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
_, err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
require.NoError(t, err)
require.Equal(t, "example.com", out.String())
})
t.Run("req_port without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
_, err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
require.NoError(t, err)
require.Empty(t, out.String())
require.Equal(t, "", out.String())
})
}
@@ -636,16 +636,16 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
t.Run("remote_host without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
require.NoError(t, err)
require.Empty(t, out.String())
require.Equal(t, "", out.String())
})
t.Run("remote_port without port", func(t *testing.T) {
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
require.NoError(t, err)
require.Empty(t, out.String())
require.Equal(t, "", out.String())
})
}
@@ -654,7 +654,7 @@ func TestExpandVars_WhitespaceHandling(t *testing.T) {
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
var out strings.Builder
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
require.NoError(t, err)
require.Equal(t, "GET /test", out.String())
}
@@ -699,7 +699,7 @@ func TestValidateVars(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateVars(tt.input)
_, err := ValidateVars(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {