mirror of
https://github.com/yusing/godoxy.git
synced 2026-02-23 17:24:58 +01:00
Compare commits
4 Commits
main
...
feat/rules
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
08de9086c3 | ||
|
|
1a17f3943a | ||
|
|
9bb5c54e7c | ||
|
|
faecbab2cb |
2
goutils
2
goutils
Submodule goutils updated: 482b5bca9f...3be815cb6e
@@ -298,7 +298,7 @@ func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
|
||||
On: onStr,
|
||||
Do: doStr,
|
||||
ValidationError: validationErr,
|
||||
IsResponseRule: rule.IsResponseRule(),
|
||||
// IsResponseRule: rule.Requirement()&rules.RequirementFlagResponse != 0,
|
||||
})
|
||||
|
||||
// Only add valid rules to execution list
|
||||
|
||||
@@ -36,15 +36,15 @@ type Rule struct {
|
||||
}
|
||||
|
||||
type RuleOn struct {
|
||||
raw string
|
||||
checker Checker
|
||||
isResponseChecker bool
|
||||
raw string
|
||||
checker Checker
|
||||
phase PhaseFlag
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
raw string
|
||||
exec CommandHandler
|
||||
isResponseHandler bool
|
||||
raw string
|
||||
pre Commands
|
||||
post Commands
|
||||
}
|
||||
```
|
||||
|
||||
@@ -59,6 +59,9 @@ func ParseRules(config string) (Rules, error)
|
||||
|
||||
// ValidateRules validates rule syntax
|
||||
func ValidateRules(config string) error
|
||||
|
||||
// Validate validates rule semantics (e.g., prevents multiple default rules)
|
||||
func (rules Rules) Validate() gperr.Error
|
||||
```
|
||||
|
||||
## Architecture
|
||||
@@ -122,16 +125,52 @@ sequenceDiagram
|
||||
Pre->>Pre: Execute handler
|
||||
alt Terminating action
|
||||
Pre-->>Req: Response
|
||||
Return-->>Req: Return immediately
|
||||
Note right of Pre: Stop remaining pre commands
|
||||
end
|
||||
end
|
||||
Req->>Proxy: Forward request
|
||||
Proxy-->>Req: Response
|
||||
Req->>Post: Check post-rules
|
||||
Post->>Post: Execute handlers
|
||||
Post-->>Req: Modified response
|
||||
opt No pre termination
|
||||
Req->>Proxy: Forward request
|
||||
Proxy-->>Req: Response
|
||||
end
|
||||
Req->>Post: Run scheduled post commands
|
||||
Req->>Post: Evaluate response matchers
|
||||
Post->>Post: Execute matched post handlers
|
||||
Post-->>Req: Final response
|
||||
```
|
||||
|
||||
### Execution Model (Authoritative)
|
||||
|
||||
Rules run in two phases:
|
||||
|
||||
1. **Pre phase**
|
||||
- Evaluate only request-based matchers (`path`, `method`, `header`, `remote`, etc.) in declaration order.
|
||||
- Execute matched rule `do` pre-commands in order.
|
||||
- If a default rule exists (`name: default` or `on: default`), it is a fallback and runs only when no non-default pre rule matches.
|
||||
- If a terminating action runs, stop:
|
||||
- remaining commands in that rule
|
||||
- all later pre-phase commands.
|
||||
- Exception: rules that only contain post commands (no pre commands) are still scheduled for post phase.
|
||||
|
||||
2. **Upstream phase**
|
||||
- Upstream is called only if pre phase did not terminate.
|
||||
|
||||
3. **Post phase**
|
||||
- Run post-commands for rules whose pre phase executed, except rules that terminated in pre.
|
||||
- Then evaluate response-based matchers (`status`, `resp_header`) and execute their `do` commands.
|
||||
- Response-based rules run even when the response was produced in pre phase.
|
||||
|
||||
**Important:** termination is explicit by command semantics, not inferred from status-code mutation.
|
||||
|
||||
### Phase Flags
|
||||
|
||||
Rule and command parsing tracks phase requirements via `PhaseFlag`:
|
||||
|
||||
- `PhasePre`
|
||||
- `PhasePost`
|
||||
- `PhasePre | PhasePost` (combined)
|
||||
|
||||
Combined flags are expected for nested/compound commands and variable templates that may need both request and response context.
|
||||
|
||||
### Condition Matchers
|
||||
|
||||
| Matcher | Type | Description |
|
||||
@@ -166,22 +205,22 @@ path regex("/api/v[0-9]+/.*") // regex pattern
|
||||
|
||||
**Terminating Actions** (stop processing):
|
||||
|
||||
| Command | Description |
|
||||
| ------------------------ | ---------------------- |
|
||||
| `error <code> <message>` | Return HTTP error |
|
||||
| `redirect <url>` | Redirect to URL |
|
||||
| `serve <path>` | Serve local files |
|
||||
| `route <name>` | Route to another route |
|
||||
| `proxy <url>` | Proxy to upstream |
|
||||
| Command | Description |
|
||||
| ------------------------------ | ------------------------------------- |
|
||||
| `upstream` / `bypass` / `pass` | Call upstream and terminate pre-phase |
|
||||
| `error <code> <message>` | Return HTTP error |
|
||||
| `redirect <url>` | Redirect to URL |
|
||||
| `serve <path>` | Serve local files |
|
||||
| `route <name>` | Route to another route |
|
||||
| `proxy <url>` | Proxy to upstream |
|
||||
| `require_basic_auth <realm>` | Return 401 challenge |
|
||||
|
||||
**Non-Terminating Actions** (modify and continue):
|
||||
|
||||
| Command | Description |
|
||||
| ------------------------------ | ---------------------- |
|
||||
| `pass` / `bypass` | Pass through unchanged |
|
||||
| `rewrite <from> <to>` | Rewrite request path |
|
||||
| `require_auth` | Require authentication |
|
||||
| `require_basic_auth <realm>` | Basic auth challenge |
|
||||
| `set <target> <field> <value>` | Set header/variable |
|
||||
| `add <target> <field> <value>` | Add header/variable |
|
||||
| `remove <target> <field>` | Remove header/variable |
|
||||
@@ -208,6 +247,166 @@ rules:
|
||||
action2
|
||||
```
|
||||
|
||||
### Rule Configuration (Block Syntax)
|
||||
|
||||
This is an alternative (and will eventually be the primary) syntax for rules that avoids YAML.
|
||||
It keeps the **inner** `on` and `do` DSLs exactly the same (same matchers, same commands, same optional quotes), but wraps each rule in a `{ ... }` block.
|
||||
|
||||
#### Key ideas
|
||||
|
||||
- A rule is:
|
||||
- `default { <do...> }`,
|
||||
- `{ <do...> }`, or
|
||||
- `<on-expr> { <do...> }`
|
||||
- Comments are supported:
|
||||
- line comment: `// ...` (to end of line)
|
||||
- line comment: `# ...` (to end of line, for YAML familiarity)
|
||||
- block comment: `/* ... */` (may span multiple lines)
|
||||
- Comments are ignored **only when outside quotes** (`"`, `'` or backticks).
|
||||
- Environment variable syntax: `${NAME}` is supported by the inner DSL parser in [`parse()`](internal/route/rules/parser.go:34).
|
||||
Block-syntax rule:
|
||||
- In `on` (rule header): `${...}` must be inside quotes/backticks.
|
||||
- In `do` (rule body): `${...}` may be unquoted; the outer parser must treat `${...}` as an opaque token so braces inside it are not structural.
|
||||
|
||||
#### Grammar sketch (EBNF-ish)
|
||||
|
||||
```text
|
||||
file := { ws | comment | rule }
|
||||
rule := default_rule | unconditional_rule | conditional_rule
|
||||
|
||||
default_rule := 'default' ws* block
|
||||
unconditional_rule := ws* block
|
||||
conditional_rule := on_expr ws* block
|
||||
|
||||
block := '{' do_body '}'
|
||||
|
||||
// on_expr and do_body are raw text regions.
|
||||
// The outer parser only needs to:
|
||||
// - find the top-level '{' to start a rule block
|
||||
// - find the matching top-level '}' to end it
|
||||
// while respecting quotes and comments.
|
||||
```
|
||||
|
||||
#### Elif/Else Chain Grammar
|
||||
|
||||
```text
|
||||
// Elif/Else chains can appear in do_body
|
||||
do_stmt := command_line | nested_block | elif_else_chain
|
||||
elif_else_chain := nested_block { elif_clause } [else_clause]
|
||||
elif_clause := 'elif' ws* on_expr ws* '{' do_body '}'
|
||||
else_clause := 'else' ws* '{' do_body '}'
|
||||
```
|
||||
|
||||
#### Nested blocks (inline conditionals inside `do`)
|
||||
|
||||
Inside a rule body (`do_body`), you can write **nested blocks** that start with `@`:
|
||||
|
||||
```text
|
||||
do_stmt := command_line | nested_block | elif_else_chain
|
||||
|
||||
nested_block := '@' on_expr ws* '{' do_body '}'
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- A nested block is only recognized when `@` is the **first non-space character on a line**.
|
||||
- `on_expr` uses the same syntax as rule `on` (supports `|`, `&`, quoting/backticks, matcher functions, etc.).
|
||||
- The nested block executes **in sequence**, at the point where it appears in the parent `do` list.
|
||||
- Nested blocks are evaluated in the same phase the parent rule runs (no special phase promotion).
|
||||
- Nested blocks can be chained with `elif`/`else` for conditional execution (see Elif/Else Chains section).
|
||||
|
||||
Example:
|
||||
|
||||
```go
|
||||
default {
|
||||
remove resp_header X-Secret
|
||||
add resp_header X-Custom-Header custom-value
|
||||
}
|
||||
|
||||
header X-Test-Header {
|
||||
set header X-Remote-Type public
|
||||
@remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||
set header X-Remote-Type private
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Elif/Else Chains
|
||||
|
||||
You can chain multiple conditions using `elif` and provide a fallback with `else`.
|
||||
The `elif`/`else` keywords must appear on the same line as the preceding closing brace (`}`).
|
||||
|
||||
```go
|
||||
header X-Test-Header {
|
||||
@method GET {
|
||||
set header X-Mode get
|
||||
} elif method POST {
|
||||
set header X-Mode post
|
||||
} else {
|
||||
set header X-Mode other
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `elif` and `else` must be on the same line as the preceding `}`.
|
||||
- Multiple `elif` branches are allowed; only one `else` is allowed.
|
||||
- The entire chain is evaluated in sequence; the first matching branch executes.
|
||||
- Elif/else chains can only be used within nested blocks (starting with `@`).
|
||||
- Each `elif` clause must have its own condition expression and block.
|
||||
- The `else` clause is optional and provides a default action when no conditions match.
|
||||
|
||||
#### Examples
|
||||
|
||||
Basic default rule:
|
||||
|
||||
```go
|
||||
default {
|
||||
bypass
|
||||
}
|
||||
```
|
||||
|
||||
WebSocket upgrade routing:
|
||||
|
||||
```bash
|
||||
# WebSocket requests
|
||||
header Connection Upgrade &
|
||||
header Upgrade websocket {
|
||||
route ws-api
|
||||
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
|
||||
}
|
||||
```
|
||||
|
||||
Block comments:
|
||||
|
||||
```go
|
||||
/* protect admin area */
|
||||
path glob("/admin/*") {
|
||||
require_auth
|
||||
}
|
||||
```
|
||||
|
||||
Always log the request
|
||||
|
||||
```bash
|
||||
{
|
||||
log info /dev/stdout "Request $req_method $req_path"
|
||||
}
|
||||
```
|
||||
|
||||
#### Notes and constraints
|
||||
|
||||
- The block syntax uses `{` and `}` as structure delimiters at **top-level** (outside quotes/comments).
|
||||
- Braces inside quoted strings (including backticks) are not structural.
|
||||
- `${...}` handling:
|
||||
- `on`: must be quoted/backticked
|
||||
- `do`: may be unquoted
|
||||
Preferred style: always write env vars as `${NAME}` rather than a bare `$NAME`.
|
||||
- If you need literal `{` or `}` outside quotes/backticks (for example unquoted templates like `{{ ... }}`), wrap that argument in quotes/backticks so the outer parser does not treat it as structure.
|
||||
- Rule naming remains minimal: if no explicit name is provided by the syntax, it will behave like the current YAML behavior (empty name becomes `rule[index]` in [`Rules.BuildHandler()`](internal/route/rules/rules.go:75)).
|
||||
- YAML remains supported as a fallback for backward compatibility.
|
||||
|
||||
### Condition Syntax
|
||||
|
||||
```yaml
|
||||
@@ -215,12 +414,13 @@ rules:
|
||||
on: path /api/users
|
||||
|
||||
# Multiple conditions (AND)
|
||||
on: |
|
||||
header Authorization Bearer
|
||||
& path /api/admin/*
|
||||
on: header Authorization Bearer & path glob("/api/admin/*")
|
||||
|
||||
# Negation
|
||||
on: !path /public/*
|
||||
on: !path glob("/public/*")
|
||||
|
||||
# Negation on matcher
|
||||
on: path !glob("/public/*")
|
||||
|
||||
# OR within a line
|
||||
on: method GET | method POST
|
||||
@@ -228,21 +428,21 @@ on: method GET | method POST
|
||||
|
||||
### Variable Substitution
|
||||
|
||||
```go
|
||||
// Static variables
|
||||
$req_method // Request method
|
||||
$req_host // Request host
|
||||
$req_path // Request path
|
||||
$status_code // Response status
|
||||
$remote_host // Client IP
|
||||
```bash
|
||||
# Static variables
|
||||
$req_method # Request method
|
||||
$req_host # Request host
|
||||
$req_path # Request path
|
||||
$status_code # Response status
|
||||
$remote_host # Client IP
|
||||
|
||||
// Dynamic variables
|
||||
$header(Name) // Request header
|
||||
$header(Name, index) // Header at index
|
||||
$arg(Name) // Query argument
|
||||
$form(Name) // Form field
|
||||
# Dynamic variables
|
||||
$header(Name) # Request header
|
||||
$header(Name, index) # Header at index
|
||||
$arg(Name) # Query argument
|
||||
$form(Name) # Form field
|
||||
|
||||
// Environment variables
|
||||
# Environment variables
|
||||
${ENV_VAR}
|
||||
```
|
||||
|
||||
@@ -277,12 +477,13 @@ Log context includes: `rule`, `alias`, `match_result`
|
||||
|
||||
## Failure Modes and Recovery
|
||||
|
||||
| Failure | Behavior | Recovery |
|
||||
| ------------------- | ------------------------- | ---------------------------------- |
|
||||
| Invalid rule syntax | Route validation fails | Fix YAML syntax |
|
||||
| Missing variables | Variable renders as empty | Check variable sources |
|
||||
| Rule timeout | Request times out | Increase timeout or simplify rules |
|
||||
| Auth failure | Returns 401/403 | Fix credentials |
|
||||
| Failure | Behavior | Recovery |
|
||||
| ---------------------- | ------------------------- | ---------------------------------- |
|
||||
| Invalid rule syntax | Route validation fails | Fix YAML syntax |
|
||||
| Multiple default rules | Route validation fails | Remove duplicate default rules |
|
||||
| Missing variables | Variable renders as empty | Check variable sources |
|
||||
| Rule timeout | Request times out | Increase timeout or simplify rules |
|
||||
| Auth failure | Returns 401/403 | Fix credentials |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -297,11 +498,11 @@ Log context includes: `rule`, `alias`, `match_result`
|
||||
|
||||
```yaml
|
||||
- name: api proxy
|
||||
on: path /api/*
|
||||
on: path glob("/api/*")
|
||||
do: proxy http://api-backend:8080
|
||||
|
||||
- name: static files
|
||||
on: path /static/*
|
||||
on: path glob("/static/*")
|
||||
do: serve /var/www/static
|
||||
```
|
||||
|
||||
@@ -309,11 +510,11 @@ Log context includes: `rule`, `alias`, `match_result`
|
||||
|
||||
```yaml
|
||||
- name: admin protection
|
||||
on: path /admin/*
|
||||
on: path glob("/admin/*")
|
||||
do: require_auth
|
||||
|
||||
- name: basic auth for API
|
||||
on: path /api/*
|
||||
on: path glob("/api/*")
|
||||
do: require_basic_auth "API Access"
|
||||
```
|
||||
|
||||
@@ -321,7 +522,7 @@ Log context includes: `rule`, `alias`, `match_result`
|
||||
|
||||
```yaml
|
||||
- name: rewrite API v1
|
||||
on: path /v1/*
|
||||
on: path glob("/v1/*")
|
||||
do: |
|
||||
rewrite /v1 /api/v1
|
||||
proxy http://backend:8080
|
||||
@@ -351,6 +552,27 @@ Log context includes: `rule`, `alias`, `match_result`
|
||||
do: bypass
|
||||
```
|
||||
|
||||
### Default Rule (Fallback)
|
||||
|
||||
```yaml
|
||||
# Default runs only if no non-default pre rule matches
|
||||
- name: default
|
||||
do: |
|
||||
remove resp_header X-Internal
|
||||
add resp_header X-Powered-By godoxy
|
||||
|
||||
# Matching rules suppress default
|
||||
- name: api routes
|
||||
on: path glob("/api/*")
|
||||
do: proxy http://api:8080
|
||||
|
||||
- name: api marker
|
||||
on: path glob("/api/*")
|
||||
do: set resp_header X-API true
|
||||
```
|
||||
|
||||
Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors and both behave as fallback-only.
|
||||
|
||||
## Testing Notes
|
||||
|
||||
- Unit tests for all matchers and actions
|
||||
|
||||
409
internal/route/rules/block_parser.go
Normal file
409
internal/route/rules/block_parser.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/yusing/goutils/env"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
func getStringBuffer(size int) *strings.Builder {
|
||||
var buf strings.Builder
|
||||
if size > 0 {
|
||||
buf.Grow(size)
|
||||
}
|
||||
return &buf
|
||||
}
|
||||
|
||||
// expandEnvVarsRaw expands ${NAME} in-place using env.LookupEnv (prefix-aware).
|
||||
func expandEnvVarsRaw(v string) (string, gperr.Error) {
|
||||
buf := getStringBuffer(len(v))
|
||||
envVar := getStringBuffer(0)
|
||||
|
||||
var missingEnvVars []string
|
||||
inEnvVar := false
|
||||
expectingBrace := false
|
||||
|
||||
for _, r := range v {
|
||||
if expectingBrace && r != '{' && r != '$' {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
}
|
||||
switch r {
|
||||
case '$':
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
} else {
|
||||
expectingBrace = true
|
||||
}
|
||||
case '{':
|
||||
if expectingBrace {
|
||||
inEnvVar = true
|
||||
expectingBrace = false
|
||||
envVar.Reset()
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '}':
|
||||
if inEnvVar {
|
||||
envValue, ok := env.LookupEnv(envVar.String())
|
||||
if !ok {
|
||||
missingEnvVars = append(missingEnvVars, envVar.String())
|
||||
} else {
|
||||
buf.WriteString(envValue)
|
||||
}
|
||||
inEnvVar = false
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
default:
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
}
|
||||
if inEnvVar {
|
||||
envVar.WriteRune(r)
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
}
|
||||
|
||||
var err gperr.Error
|
||||
if inEnvVar {
|
||||
// Write back the unterminated ${...} so the output matches the input.
|
||||
buf.WriteString("${")
|
||||
buf.WriteString(envVar.String())
|
||||
err = ErrUnterminatedEnvVar
|
||||
}
|
||||
if len(missingEnvVars) > 0 {
|
||||
err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...)))
|
||||
}
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
// parseBlockRules parses the block-syntax rule format.
|
||||
// Grammar:
|
||||
//
|
||||
// file := { ws | comment | rule }
|
||||
// rule := default_rule | conditional_rule
|
||||
// default_rule := 'default' ws* block
|
||||
// conditional_rule := on_expr ws* block
|
||||
// block := '{' do_body '}'
|
||||
//
|
||||
// Where:
|
||||
// - on_expr is passed verbatim to RuleOn.Parse()
|
||||
// - do_body is passed verbatim to Command.Parse()
|
||||
//
|
||||
// Comments (ignored outside quotes/backticks):
|
||||
// - line comment: // ... or # ...
|
||||
// - block comment: /* ... */
|
||||
//
|
||||
// Brace handling:
|
||||
// - Braces inside quotes/backticks are ignored
|
||||
// - Braces inside ${...} (env vars) are ignored in do_body
|
||||
// - Braces in on_expr are not ignored (env vars must be quoted in on_expr)
|
||||
//
|
||||
//nolint:dupword
|
||||
func parseBlockRules(src string) (Rules, gperr.Error) {
|
||||
var rules Rules
|
||||
var errs gperr.Builder
|
||||
|
||||
pos := 0
|
||||
length := len(src)
|
||||
t := newTokenizer(src)
|
||||
|
||||
for pos < length {
|
||||
// Skip whitespace/comments between rules.
|
||||
newPos, skipErr := t.skipComments(pos, true, true)
|
||||
if skipErr != nil {
|
||||
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
|
||||
}
|
||||
pos = newPos
|
||||
if pos >= length {
|
||||
break
|
||||
}
|
||||
|
||||
// Stray closing brace at top-level: keep parsing but mark invalid so Rules.Validate() fails.
|
||||
if src[pos] == '}' {
|
||||
return nil, ErrInvalidBlockSyntax.Withf("unmatched '}' at position %d", pos)
|
||||
}
|
||||
|
||||
// Parse rule header (default, unconditional, or on_expr)
|
||||
headerStart := pos
|
||||
header := parseRuleHeader(&t, src, &pos, length)
|
||||
headerStr := src[headerStart:pos]
|
||||
|
||||
// Skip whitespace/comments before '{' (default header may end before '{').
|
||||
newPos, skipErr = t.skipComments(pos, false, true)
|
||||
if skipErr != nil {
|
||||
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
|
||||
}
|
||||
pos = newPos
|
||||
|
||||
if pos >= length || src[pos] != '{' {
|
||||
errs.AddSubjectf(ErrInvalidBlockSyntax, "expected '{' after rule header %q", headerStr)
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
// Find matching '}' (respecting quotes and env vars in do_body)
|
||||
bodyStart := pos + 1
|
||||
bodyEnd, err := t.findMatchingBrace(bodyStart)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "rule header %q", headerStr)
|
||||
return nil, errs.Error()
|
||||
}
|
||||
pos = bodyEnd + 1
|
||||
|
||||
onExpr := header
|
||||
|
||||
doBody := ""
|
||||
if bodyStart < bodyEnd {
|
||||
doBody = src[bodyStart:bodyEnd]
|
||||
}
|
||||
// Normalize do body for the inner DSL parser:
|
||||
// - strip comments (outside quotes/backticks)
|
||||
// - trim block whitespace/indentation
|
||||
// - expand ${ENV} in-place so cmd.raw is usable/debuggable
|
||||
doBody, err = preprocessDoBody(doBody)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "rule header %q", headerStr)
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Name: "", // auto-generate if empty
|
||||
On: RuleOn{},
|
||||
Do: Command{},
|
||||
}
|
||||
|
||||
// Header semantics:
|
||||
// - "default" => default rule (matched when no other rules are matched)
|
||||
// - "" => unconditional rule (always matches)
|
||||
// - otherwise => conditional rule (on expression)
|
||||
switch onExpr {
|
||||
case "default":
|
||||
rule.On.raw = OnDefault
|
||||
case "":
|
||||
// leave rule.On as zero value => checker=nil => always matches
|
||||
default:
|
||||
if parseErr := rule.On.Parse(onExpr); parseErr != nil {
|
||||
errs.AddSubjectf(parseErr, "on")
|
||||
}
|
||||
}
|
||||
|
||||
if doBody != "" {
|
||||
if parseErr := rule.Do.Parse(doBody); parseErr != nil {
|
||||
errs.AddSubjectf(parseErr, "do")
|
||||
}
|
||||
}
|
||||
|
||||
if errs.HasError() {
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func preprocessDoBody(doBody string) (string, gperr.Error) {
|
||||
doBody = strings.TrimSpace(doBody)
|
||||
if doBody == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
normalized := doBody
|
||||
// If comments are possible, strip them first while preserving line breaks.
|
||||
if strings.ContainsAny(normalized, "#/") {
|
||||
stripped, err := stripCommentsPreserveNewlines(normalized)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
normalized = stripped
|
||||
}
|
||||
|
||||
// Drop lines that are empty after trimming, while preserving indentation of non-empty lines.
|
||||
out := getStringBuffer(len(normalized))
|
||||
|
||||
lineStart := 0
|
||||
wroteLine := false
|
||||
for i := 0; i <= len(normalized); i++ {
|
||||
if i < len(normalized) && normalized[i] != '\n' {
|
||||
continue
|
||||
}
|
||||
line := normalized[lineStart:i]
|
||||
if strings.TrimSpace(line) != "" {
|
||||
if wroteLine {
|
||||
out.WriteByte('\n')
|
||||
}
|
||||
out.WriteString(line)
|
||||
wroteLine = true
|
||||
}
|
||||
lineStart = i + 1
|
||||
}
|
||||
|
||||
if !wroteLine {
|
||||
return "", nil
|
||||
}
|
||||
normalized = out.String()
|
||||
|
||||
// Expand env vars to keep Command.raw consistent with parsed semantics.
|
||||
if !strings.Contains(normalized, "${") {
|
||||
return normalized, nil
|
||||
}
|
||||
expanded, err := expandEnvVarsRaw(normalized)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return expanded, nil
|
||||
}
|
||||
|
||||
// stripCommentsPreserveNewlines removes //, #, and /* */ comments outside quotes/backticks.
|
||||
// It preserves newlines so command line boundaries remain intact.
|
||||
func stripCommentsPreserveNewlines(src string) (string, gperr.Error) {
|
||||
if !strings.ContainsAny(src, "#/") {
|
||||
return src, nil
|
||||
}
|
||||
|
||||
out := getStringBuffer(len(src))
|
||||
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
atLineStart := true
|
||||
prevIsSpace := true
|
||||
|
||||
for i := 0; i < len(src); {
|
||||
c := src[i]
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
out.WriteByte('\n')
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '\n' {
|
||||
out.WriteByte('\n')
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if c == '*' && i+1 < len(src) && src[i+1] == '/' {
|
||||
inBlock = false
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
out.WriteByte(c)
|
||||
if c == '\\' && i+1 < len(src) {
|
||||
// Write next char and skip it (escape sequence)
|
||||
i++
|
||||
out.WriteByte(src[i])
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if rune(c) == quote {
|
||||
quote = 0
|
||||
}
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(rune(c))
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Not in quote/comment.
|
||||
switch c {
|
||||
case '\'', '"', '`':
|
||||
quote = rune(c)
|
||||
out.WriteByte(c)
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
i++
|
||||
continue
|
||||
case '#':
|
||||
if atLineStart || prevIsSpace {
|
||||
inLine = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
case '/':
|
||||
if i+1 < len(src) {
|
||||
n := src[i+1]
|
||||
if (atLineStart || prevIsSpace) && n == '/' {
|
||||
inLine = true
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if (atLineStart || prevIsSpace) && n == '*' {
|
||||
inBlock = true
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out.WriteByte(c)
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(rune(c))
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
if inBlock {
|
||||
return "", ErrInvalidBlockSyntax.Withf("unterminated block comment")
|
||||
}
|
||||
return out.String(), nil
|
||||
}
|
||||
|
||||
// parseRuleHeader parses the rule header (default or on expression).
|
||||
// Returns the header string, or "" if parsing failed.
|
||||
func parseRuleHeader(t *Tokenizer, src string, pos *int, length int) string {
|
||||
start := *pos
|
||||
|
||||
// Check for 'default' keyword
|
||||
if *pos+7 <= length && src[*pos:*pos+7] == "default" {
|
||||
next := *pos + 7
|
||||
if next >= length || unicode.IsSpace(rune(src[next])) {
|
||||
*pos = next
|
||||
return "default"
|
||||
}
|
||||
}
|
||||
|
||||
// Parse on expression until we hit '{' outside quotes.
|
||||
bracePos, err := t.scanToBrace(*pos)
|
||||
if err != nil {
|
||||
*pos = length
|
||||
return strings.TrimSpace(src[start:*pos])
|
||||
}
|
||||
*pos = bracePos
|
||||
return strings.TrimSpace(src[start:*pos])
|
||||
}
|
||||
48
internal/route/rules/block_parser_bench_test.go
Normal file
48
internal/route/rules/block_parser_bench_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rules
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkParseBlockRules(b *testing.B) {
|
||||
const rulesString = `
|
||||
default {
|
||||
remove resp_header X-Secret
|
||||
add resp_header X-Custom-Header custom-value
|
||||
}
|
||||
|
||||
header X-Test-Header {
|
||||
set header X-Remote-Type public
|
||||
@remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||
set header X-Remote-Type private
|
||||
}
|
||||
}
|
||||
|
||||
path glob(/api/admin/*) {
|
||||
@cookie session-id {
|
||||
set header X-Session-ID $cookie(session-id)
|
||||
}
|
||||
}
|
||||
|
||||
!remote 192.168.0.0/16 {
|
||||
@!header X-User-Role admin & !header X-User-Role user {
|
||||
error 403 "Access denied"
|
||||
} elif remote 127.0.0.1 {
|
||||
@header X-User-Role staff {
|
||||
set header X-User-Role staff
|
||||
}
|
||||
} else {
|
||||
error 403 "Access denied"
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
var rules Rules
|
||||
err := rules.Parse(rulesString)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
var rules Rules
|
||||
_ = rules.Parse(rulesString)
|
||||
}
|
||||
}
|
||||
339
internal/route/rules/block_parser_test.go
Normal file
339
internal/route/rules/block_parser_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
func testParseRules(t *testing.T, data string) Rules {
|
||||
t.Helper()
|
||||
|
||||
var rules Rules
|
||||
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
|
||||
require.True(t, convertible)
|
||||
require.NoError(t, err)
|
||||
return rules
|
||||
}
|
||||
|
||||
func testParseRulesError(t *testing.T, data string) error {
|
||||
t.Helper()
|
||||
|
||||
var rules Rules
|
||||
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
|
||||
require.True(t, convertible)
|
||||
return err
|
||||
}
|
||||
|
||||
func TestParseBlockRules_DefaultRule(t *testing.T) {
|
||||
rules := testParseRules(t, `default {
|
||||
upstream
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "upstream", rules[0].Do.raw)
|
||||
assert.True(t, rules[0].Do.raw == CommandUpstream)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_ConditionalRule(t *testing.T) {
|
||||
rules := testParseRules(t, `path glob(/api/*) {
|
||||
proxy http://localhost:8080
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, "path glob(/api/*)", rules[0].On.raw)
|
||||
assert.Equal(t, "proxy http://localhost:8080", rules[0].Do.raw)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[0].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_MultipleRules(t *testing.T) {
|
||||
rules := testParseRules(t, `default {
|
||||
bypass
|
||||
}
|
||||
|
||||
path /api/* {
|
||||
proxy http://localhost:8080
|
||||
}
|
||||
|
||||
header Connection Upgrade &
|
||||
header Upgrade websocket {
|
||||
route ws-api
|
||||
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
|
||||
}`)
|
||||
require.Len(t, rules, 3)
|
||||
|
||||
// Default rule
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
|
||||
// API rule
|
||||
assert.Equal(t, "path /api/*", rules[1].On.raw)
|
||||
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
|
||||
|
||||
// WebSocket rule
|
||||
assert.Equal(t, "header Connection Upgrade &\nheader Upgrade websocket", rules[2].On.raw)
|
||||
assert.Equal(t, `route ws-api
|
||||
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"`, rules[2].Do.raw)
|
||||
require.Len(t, rules[2].Do.pre, 2)
|
||||
_, ok := rules[2].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
_, ok = rules[2].Do.pre[1].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[2].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_Comments(t *testing.T) {
|
||||
rules := testParseRules(t, `// This is a comment
|
||||
default {
|
||||
bypass // inline comment
|
||||
}
|
||||
|
||||
/* Block comment
|
||||
spanning multiple lines */
|
||||
path /admin/* {
|
||||
require_auth
|
||||
}`)
|
||||
require.Len(t, rules, 2)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "path /admin/*", rules[1].On.raw)
|
||||
assert.Equal(t, "require_auth", rules[1].Do.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_HashComment(t *testing.T) {
|
||||
rules := testParseRules(t, `# YAML-style comment
|
||||
default {
|
||||
bypass
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_EnvVars(t *testing.T) {
|
||||
t.Setenv("CUSTOM_HEADER", "test-header")
|
||||
|
||||
rules := testParseRules(t, `path /api/* {
|
||||
set header X-Custom "${CUSTOM_HEADER}"
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, "path /api/*", rules[0].On.raw)
|
||||
assert.Equal(t, `set header X-Custom "test-header"`, rules[0].Do.raw)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[0].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_YAMLFallback(t *testing.T) {
|
||||
rules := testParseRules(t, `- name: default
|
||||
do: bypass
|
||||
- name: api
|
||||
on: path glob(/api/*)
|
||||
do: proxy http://localhost:8080`)
|
||||
require.Len(t, rules, 2)
|
||||
assert.Equal(t, "default", rules[0].Name)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
assert.Equal(t, "api", rules[1].Name)
|
||||
assert.Equal(t, "path glob(/api/*)", rules[1].On.raw)
|
||||
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
|
||||
require.Len(t, rules[1].Do.pre, 1)
|
||||
_, ok := rules[1].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[1].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_UnmatchedBrace(t *testing.T) {
|
||||
t.Run("unquoted", func(t *testing.T) {
|
||||
err := testParseRulesError(t, `path /api/* {
|
||||
proxy http://localhost:8080}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("quoted", func(t *testing.T) {
|
||||
_ = testParseRules(t, `path /api/* {
|
||||
error 403 "some message}"
|
||||
}`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseBlockRules_UnterminatedBlockComment(t *testing.T) {
|
||||
err := testParseRulesError(t, `/* unterminated block comment
|
||||
default {
|
||||
bypass
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks(t *testing.T) {
|
||||
rules := testParseRules(t, `
|
||||
header X-Test-Header {
|
||||
set header X-Remote-Type public
|
||||
@remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||
set header X-Remote-Type private
|
||||
}
|
||||
}`)
|
||||
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, "header X-Test-Header", rules[0].On.raw)
|
||||
|
||||
require.Len(t, rules[0].Do.pre, 2)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[0].Do.post, 0)
|
||||
|
||||
ifCmd, ok := rules[0].Do.pre[1].(IfBlockCommand)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "remote 127.0.0.1 | remote 192.168.0.0/16", ifCmd.On.raw)
|
||||
|
||||
require.Len(t, ifCmd.Do, 1)
|
||||
|
||||
upstream := func(http.ResponseWriter, *http.Request) {}
|
||||
|
||||
t.Run("condition matched executes nested content", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Test-Header", "1")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "private", req.Header.Get("X-Remote-Type"))
|
||||
})
|
||||
|
||||
t.Run("condition not matched skips nested content", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Test-Header", "1")
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "public", req.Header.Get("X-Remote-Type"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElifElse(t *testing.T) {
|
||||
rules := testParseRules(t, `
|
||||
header X-Test-Header {
|
||||
set header X-Mode outer
|
||||
@method GET {
|
||||
set header X-Mode get
|
||||
} elif method POST {
|
||||
set header X-Mode post
|
||||
} else {
|
||||
set header X-Mode other
|
||||
}
|
||||
}`)
|
||||
|
||||
require.Len(t, rules, 1)
|
||||
|
||||
require.Len(t, rules[0].Do.pre, 2)
|
||||
|
||||
ifCmd, ok := rules[0].Do.pre[1].(IfElseBlockCommand)
|
||||
require.True(t, ok)
|
||||
require.Len(t, ifCmd.Ifs, 2)
|
||||
assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw)
|
||||
assert.Equal(t, "method POST", ifCmd.Ifs[1].On.raw)
|
||||
require.NotNil(t, ifCmd.Else)
|
||||
|
||||
upstream := func(http.ResponseWriter, *http.Request) {}
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
want string
|
||||
}{
|
||||
{name: "get branch", method: http.MethodGet, want: "get"},
|
||||
{name: "post branch", method: http.MethodPost, want: "post"},
|
||||
{name: "else branch", method: http.MethodPut, want: "other"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, "/", nil)
|
||||
req.Header.Set("X-Test-Header", "1")
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.want, req.Header.Get("X-Mode"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBlockRules_DefaultRule_CommentBeforeBrace(t *testing.T) {
|
||||
rules := testParseRules(t, `default /* comment between header and brace */ {
|
||||
bypass
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_StrayClosingBraceAtTopLevel(t *testing.T) {
|
||||
err := testParseRulesError(t, `}
|
||||
default {
|
||||
bypass
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElifMustBeSameLine(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
@method GET {
|
||||
set header X-Mode get
|
||||
}
|
||||
elif method POST {
|
||||
set header X-Mode post
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElseMustBeLastOnLine(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
@method GET {
|
||||
set header X-Mode get
|
||||
} else {
|
||||
set header X-Mode other
|
||||
} set header X-After else
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unexpected token after else block")
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_MultipleElse(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
@method GET {
|
||||
set header X-Mode get
|
||||
} else {
|
||||
set header X-Mode other
|
||||
} else {
|
||||
set header X-Mode other2
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
// assert.Contains(t, err.Error(), "multiple 'else' branches")
|
||||
assert.Contains(t, err.Error(), "unexpected token after else block")
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElifMissingOnExpr(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
@method GET {
|
||||
set header X-Mode get
|
||||
} elif {
|
||||
set header X-Mode post
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expected on-expr after 'elif'")
|
||||
}
|
||||
@@ -1,21 +1,25 @@
|
||||
package rules
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
type (
|
||||
CheckFunc func(w http.ResponseWriter, r *http.Request) bool
|
||||
CheckFunc func(w *httputils.ResponseModifier, r *http.Request) bool
|
||||
Checker interface {
|
||||
Check(w http.ResponseWriter, r *http.Request) bool
|
||||
Check(w *httputils.ResponseModifier, r *http.Request) bool
|
||||
}
|
||||
CheckMatchSingle []Checker
|
||||
CheckMatchAll []Checker
|
||||
)
|
||||
|
||||
func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (checker CheckFunc) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return checker(w, r)
|
||||
}
|
||||
|
||||
func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (checkers CheckMatchSingle) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
for _, check := range checkers {
|
||||
if check.Check(w, r) {
|
||||
return true
|
||||
@@ -24,7 +28,7 @@ func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) b
|
||||
return false
|
||||
}
|
||||
|
||||
func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (checkers CheckMatchAll) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
for _, check := range checkers {
|
||||
if !check.Check(w, r) {
|
||||
return false
|
||||
|
||||
@@ -1,79 +1,62 @@
|
||||
package rules
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
var errTerminateRule = errors.New("terminate rule")
|
||||
|
||||
type (
|
||||
handlerFunc func(w http.ResponseWriter, r *http.Request) error
|
||||
HandlerFunc func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
|
||||
Handler struct {
|
||||
fn HandlerFunc
|
||||
phase PhaseFlag
|
||||
terminate bool
|
||||
}
|
||||
|
||||
CommandHandler interface {
|
||||
// CommandHandler can read and modify the values
|
||||
// then handle the request
|
||||
// finally proceed to next command (or return) base on situation
|
||||
Handle(w http.ResponseWriter, r *http.Request) error
|
||||
IsResponseHandler() bool
|
||||
ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
|
||||
Phase() PhaseFlag
|
||||
}
|
||||
// NonTerminatingCommand will run then proceed to next command or reverse proxy.
|
||||
NonTerminatingCommand handlerFunc
|
||||
// TerminatingCommand will run then return immediately.
|
||||
TerminatingCommand handlerFunc
|
||||
// OnResponseCommand will run then return based on the response.
|
||||
OnResponseCommand handlerFunc
|
||||
// BypassCommand will skip all the following commands
|
||||
// and directly return to reverse proxy.
|
||||
BypassCommand struct{}
|
||||
|
||||
// Commands is a slice of CommandHandler.
|
||||
Commands []CommandHandler
|
||||
)
|
||||
|
||||
func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return c(w, r)
|
||||
func (h Handler) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
return h.fn(w, r, upstream)
|
||||
}
|
||||
|
||||
func (c NonTerminatingCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
func (h Handler) Phase() PhaseFlag {
|
||||
return h.phase
|
||||
}
|
||||
|
||||
func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
if err := c(w, r); err != nil {
|
||||
return err
|
||||
}
|
||||
return errTerminated
|
||||
func (h Handler) Terminates() bool {
|
||||
return h.terminate
|
||||
}
|
||||
|
||||
func (c TerminatingCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return c(w, r)
|
||||
}
|
||||
|
||||
func (c OnResponseCommand) IsResponseHandler() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return errTerminated
|
||||
}
|
||||
|
||||
func (c BypassCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
func (c Commands) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
for _, cmd := range c {
|
||||
if err := cmd.Handle(w, r); err != nil {
|
||||
err := cmd.ServeHTTP(w, r, upstream)
|
||||
if err != nil {
|
||||
// Terminating actions stop the command chain immediately.
|
||||
// Will be handled by the caller.
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Commands) IsResponseHandler() bool {
|
||||
func (c Commands) Phase() PhaseFlag {
|
||||
req := PhaseNone
|
||||
for _, cmd := range c {
|
||||
if cmd.IsResponseHandler() {
|
||||
return true
|
||||
}
|
||||
req |= cmd.Phase()
|
||||
}
|
||||
return false
|
||||
return req
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -24,17 +23,17 @@ import (
|
||||
|
||||
type (
|
||||
Command struct {
|
||||
raw string
|
||||
exec CommandHandler
|
||||
isResponseHandler bool
|
||||
raw string
|
||||
pre Commands // runs before w.WriteHeader
|
||||
post Commands
|
||||
}
|
||||
)
|
||||
|
||||
func (cmd *Command) IsResponseHandler() bool {
|
||||
return cmd.isResponseHandler
|
||||
}
|
||||
|
||||
const (
|
||||
CommandUpstream = "upstream"
|
||||
CommandUpstreamOld = "bypass"
|
||||
CommandUpstreamOld2 = "pass"
|
||||
|
||||
CommandRequireAuth = "require_auth"
|
||||
CommandRewrite = "rewrite"
|
||||
CommandServe = "serve"
|
||||
@@ -48,8 +47,6 @@ const (
|
||||
CommandRemove = "remove"
|
||||
CommandLog = "log"
|
||||
CommandNotify = "notify"
|
||||
CommandPass = "pass"
|
||||
CommandPassAlt = "bypass"
|
||||
)
|
||||
|
||||
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
@@ -60,36 +57,60 @@ func InitAuthHandler(handler AuthHandler) {
|
||||
authHandler = handler
|
||||
}
|
||||
|
||||
func init() {
|
||||
commands[CommandUpstreamOld] = commands[CommandUpstream]
|
||||
commands[CommandUpstreamOld2] = commands[CommandUpstream]
|
||||
}
|
||||
|
||||
var commands = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
build func(args any) CommandHandler
|
||||
isResponseHandler bool
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
build func(args any) HandlerFunc
|
||||
terminate bool
|
||||
}{
|
||||
CommandUpstream: {
|
||||
help: Help{
|
||||
command: CommandUpstream,
|
||||
description: makeLines("Pass the request to the upstream"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 0 {
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
return phase, nil, nil
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
upstream(w, r)
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRequireAuth: {
|
||||
help: Help{
|
||||
command: CommandRequireAuth,
|
||||
description: makeLines("Require HTTP authentication for incoming requests"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
return phase, nil, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if authHandler == nil {
|
||||
http.Error(w, "Auth handler not initialized", http.StatusInternalServerError)
|
||||
return errTerminated
|
||||
build: func(args any) HandlerFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if authHandler == nil { // no auth handler configured, allow request to proceed
|
||||
return nil
|
||||
}
|
||||
if !authHandler(w, r) {
|
||||
return errTerminated
|
||||
if proceed := authHandler(w, r); !proceed {
|
||||
return errTerminateRule
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
CommandRewrite: {
|
||||
@@ -104,26 +125,27 @@ var commands = map[string]struct {
|
||||
"to": "the path to rewrite to, must start with /",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
path1, err1 := validateURLPath(args[:1])
|
||||
path2, err2 := validateURLPath(args[1:])
|
||||
if err1 != nil {
|
||||
err1 = gperr.PrependSubject(err1, "from")
|
||||
err1 = gperr.Errorf("from: %w", err1)
|
||||
}
|
||||
if err2 != nil {
|
||||
err2 = gperr.PrependSubject(err2, "to")
|
||||
err2 = gperr.Errorf("to: %w", err2)
|
||||
}
|
||||
if err1 != nil || err2 != nil {
|
||||
return nil, gperr.Join(err1, err2)
|
||||
return phase, nil, gperr.Join(err1, err2)
|
||||
}
|
||||
return &StrTuple{path1.(string), path2.(string)}, nil
|
||||
return phase, &StrTuple{path1.(string), path2.(string)}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
orig, repl := args.(*StrTuple).Unpack()
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
path := r.URL.Path
|
||||
if len(path) > 0 && path[0] != '/' {
|
||||
path = "/" + path
|
||||
@@ -133,10 +155,10 @@ var commands = map[string]struct {
|
||||
}
|
||||
path = repl + path[len(orig):]
|
||||
r.URL.Path = path
|
||||
r.URL.RawPath = r.URL.EscapedPath()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.URL.RawPath = ""
|
||||
r.RequestURI = ""
|
||||
return nil
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
CommandServe: {
|
||||
@@ -150,14 +172,19 @@ var commands = map[string]struct {
|
||||
"root": "the file system path to serve, must be an existing directory",
|
||||
},
|
||||
},
|
||||
validate: validateFSPath,
|
||||
build: func(args any) CommandHandler {
|
||||
root := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||
return nil
|
||||
})
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
parsedArgs, err = validateFSPath(args)
|
||||
return
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
root := args.(string)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRedirect: {
|
||||
help: Help{
|
||||
@@ -170,14 +197,19 @@ var commands = map[string]struct {
|
||||
"to": "the url to redirect to, can be relative or absolute URL",
|
||||
},
|
||||
},
|
||||
validate: validateURL,
|
||||
build: func(args any) CommandHandler {
|
||||
target := args.(*nettypes.URL).String()
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||
return nil
|
||||
})
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
parsedArgs, err = validateURL(args)
|
||||
return
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
target := args.(*nettypes.URL).String()
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRoute: {
|
||||
help: Help{
|
||||
@@ -190,15 +222,16 @@ var commands = map[string]struct {
|
||||
"route": "the route to route to",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
return args[0], nil
|
||||
return phase, args[0], nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
route := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, req *http.Request, upstream http.HandlerFunc) error {
|
||||
ep := entrypoint.FromCtx(req.Context())
|
||||
r, ok := ep.HTTPRoutes().Get(route)
|
||||
if !ok {
|
||||
@@ -212,9 +245,10 @@ var commands = map[string]struct {
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandError: {
|
||||
help: Help{
|
||||
@@ -228,34 +262,40 @@ var commands = map[string]struct {
|
||||
"text": "the error message to return",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
codeStr, text := args[0], args[1]
|
||||
code, err := strconv.Atoi(codeStr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return phase, nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if !httputils.IsStatusCodeValid(code) {
|
||||
return nil, ErrInvalidArguments.Subject(codeStr)
|
||||
return phase, nil, ErrInvalidArguments.Subject(codeStr)
|
||||
}
|
||||
textTmpl, err := validateTemplate(text, true)
|
||||
tmplReq, textTmpl, err := validateTemplate(text, true)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return phase, nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return &Tuple[int, templateString]{code, textTmpl}, nil
|
||||
phase |= tmplReq
|
||||
return phase, &Tuple[int, templateString]{code, textTmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
// error command should overwrite the response body
|
||||
httputils.GetInitResponseModifier(w).ResetBody()
|
||||
w.ResetBody()
|
||||
w.WriteHeader(code)
|
||||
err := textTmpl.ExpandVars(w, r, w)
|
||||
return err
|
||||
})
|
||||
_, err := textTmpl.ExpandVars(w, r, w.BodyBuffer())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRequireBasicAuth: {
|
||||
help: Help{
|
||||
@@ -268,20 +308,22 @@ var commands = map[string]struct {
|
||||
"realm": "the authentication realm",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) == 1 {
|
||||
return args[0], nil
|
||||
return phase, args[0], nil
|
||||
}
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
realm := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, realm))
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandProxy: {
|
||||
help: Help{
|
||||
@@ -294,14 +336,19 @@ var commands = map[string]struct {
|
||||
"to": "the url to proxy to, must be an absolute URL",
|
||||
},
|
||||
},
|
||||
validate: validateURL,
|
||||
build: func(args any) CommandHandler {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
parsedArgs, err = validateURL(args)
|
||||
return
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
target := args.(*nettypes.URL)
|
||||
if target.Scheme == "" {
|
||||
target.Scheme = "http"
|
||||
}
|
||||
if target.Host == "" {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rawPath := target.EscapedPath()
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
url := target.URL
|
||||
url.Host = routes.TryGetUpstreamHostPort(r)
|
||||
if url.Host == "" {
|
||||
@@ -309,18 +356,19 @@ var commands = map[string]struct {
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
|
||||
r.URL.Path = target.Path
|
||||
r.URL.RawPath = r.URL.EscapedPath()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.URL.RawPath = rawPath
|
||||
r.RequestURI = ""
|
||||
rp.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
rp.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandSet: {
|
||||
help: Help{
|
||||
@@ -335,11 +383,11 @@ var commands = map[string]struct {
|
||||
"value": "the value to set",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
return validateModField(ModFieldSet, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return args.(CommandHandler)
|
||||
build: func(args any) HandlerFunc {
|
||||
return args.(HandlerFunc)
|
||||
},
|
||||
},
|
||||
CommandAdd: {
|
||||
@@ -355,11 +403,11 @@ var commands = map[string]struct {
|
||||
"value": "the value to add",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
return validateModField(ModFieldAdd, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return args.(CommandHandler)
|
||||
build: func(args any) HandlerFunc {
|
||||
return args.(HandlerFunc)
|
||||
},
|
||||
},
|
||||
CommandRemove: {
|
||||
@@ -374,15 +422,14 @@ var commands = map[string]struct {
|
||||
"field": "the field to remove",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
return validateModField(ModFieldRemove, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return args.(CommandHandler)
|
||||
build: func(args any) HandlerFunc {
|
||||
return args.(HandlerFunc)
|
||||
},
|
||||
},
|
||||
CommandLog: {
|
||||
isResponseHandler: true,
|
||||
help: Help{
|
||||
command: CommandLog,
|
||||
description: makeLines(
|
||||
@@ -399,46 +446,57 @@ var commands = map[string]struct {
|
||||
"template": "the template to log",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrExpectThreeArgs
|
||||
return phase, nil, ErrExpectThreeArgs
|
||||
}
|
||||
tmpl, err := validateTemplate(args[2], true)
|
||||
phase, tmpl, err := validateTemplate(args[2], true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
level, err := validateLevel(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
// NOTE: file will stay opened forever
|
||||
// it leverages accesslog.NewFileIO so
|
||||
// it will be opened only once for the same path
|
||||
f, err := openFile(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
return &onLogArgs{level, f, tmpl}, nil
|
||||
return phase, &onLogArgs{level, f, tmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
level, f, tmpl := args.(*onLogArgs).Unpack()
|
||||
var logger io.Writer
|
||||
if f == stdout || f == stderr {
|
||||
isStdLogger := f == stdout || f == stderr
|
||||
if isStdLogger {
|
||||
logger = logging.NewLoggerWithFixedLevel(level, f)
|
||||
} else {
|
||||
logger = f
|
||||
}
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
err := tmpl.ExpandVars(w, r, logger)
|
||||
if err != nil {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if isStdLogger {
|
||||
bufPool := w.BufPool()
|
||||
buf := bufPool.GetBuffer()
|
||||
defer bufPool.PutBuffer(buf)
|
||||
|
||||
if _, err := tmpl.ExpandVars(w, r, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if buf.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := logger.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
_, err := tmpl.ExpandVars(w, r, logger)
|
||||
return err
|
||||
}
|
||||
},
|
||||
},
|
||||
CommandNotify: {
|
||||
isResponseHandler: true,
|
||||
help: Help{
|
||||
command: CommandNotify,
|
||||
description: makeLines(
|
||||
@@ -456,22 +514,24 @@ var commands = map[string]struct {
|
||||
"body": "the body of the notification",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 4 {
|
||||
return nil, ErrExpectFourArgs
|
||||
return phase, nil, ErrExpectFourArgs
|
||||
}
|
||||
titleTmpl, err := validateTemplate(args[2], false)
|
||||
req1, titleTmpl, err := validateTemplate(args[2], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
bodyTmpl, err := validateTemplate(args[3], false)
|
||||
req2, bodyTmpl, err := validateTemplate(args[3], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
level, err := validateLevel(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
|
||||
phase |= req1 | req2
|
||||
// TODO: validate provider
|
||||
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
|
||||
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
|
||||
@@ -480,34 +540,34 @@ var commands = map[string]struct {
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
||||
return phase, &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
|
||||
to := []string{provider}
|
||||
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
var respBuf strings.Builder
|
||||
|
||||
err := titleTmpl.ExpandVars(w, r, respBuf)
|
||||
_, err := titleTmpl.ExpandVars(w, r, &respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
titleLen := respBuf.Len()
|
||||
err = bodyTmpl.ExpandVars(w, r, respBuf)
|
||||
_, err = bodyTmpl.ExpandVars(w, r, &respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b := respBuf.Bytes()
|
||||
s := respBuf.String()
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: level,
|
||||
Title: string(b[:titleLen]),
|
||||
Body: notif.MessageBodyBytes(b[titleLen:]),
|
||||
Title: s[:titleLen],
|
||||
Body: notif.MessageBodyBytes(s[titleLen:]),
|
||||
To: to,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -519,121 +579,29 @@ type (
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (cmd *Command) Parse(v string) error {
|
||||
executors := make([]CommandHandler, 0)
|
||||
isResponseHandler := false
|
||||
for line := range strings.SplitSeq(v, "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
directive, args, err := parse(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if directive == CommandPass || directive == CommandPassAlt {
|
||||
if len(args) != 0 {
|
||||
return ErrExpectNoArg
|
||||
}
|
||||
executors = append(executors, BypassCommand{})
|
||||
continue
|
||||
}
|
||||
|
||||
builder, ok := commands[directive]
|
||||
if !ok {
|
||||
return ErrUnknownDirective.Subject(directive)
|
||||
}
|
||||
validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
|
||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
handler := builder.build(validArgs)
|
||||
executors = append(executors, handler)
|
||||
if builder.isResponseHandler || handler.IsResponseHandler() {
|
||||
isResponseHandler = true
|
||||
}
|
||||
executors, parseErr := parseDoWithBlocks(v)
|
||||
if parseErr != nil {
|
||||
return parseErr
|
||||
}
|
||||
|
||||
if len(executors) == 0 {
|
||||
cmd.raw = v
|
||||
cmd.exec = nil
|
||||
cmd.isResponseHandler = false
|
||||
cmd.pre = nil
|
||||
cmd.post = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
exec, err := buildCmd(executors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.raw = v
|
||||
cmd.exec = exec
|
||||
if exec.IsResponseHandler() {
|
||||
isResponseHandler = true
|
||||
for _, executor := range executors {
|
||||
if executor.Phase().IsPostRule() {
|
||||
cmd.post = append(cmd.post, executor)
|
||||
} else {
|
||||
cmd.pre = append(cmd.pre, executor)
|
||||
}
|
||||
}
|
||||
cmd.isResponseHandler = isResponseHandler
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
|
||||
// Validate the execution order.
|
||||
//
|
||||
// This allows sequences like:
|
||||
// route ws-api
|
||||
// log info /dev/stdout "..."
|
||||
// where the first command is request-phase and the last is response-phase.
|
||||
lastNonResp := -1
|
||||
seenResp := false
|
||||
for i, exec := range executors {
|
||||
if exec.IsResponseHandler() {
|
||||
seenResp = true
|
||||
continue
|
||||
}
|
||||
if seenResp {
|
||||
return nil, ErrInvalidCommandSequence.Withf("response handlers must be the last commands")
|
||||
}
|
||||
lastNonResp = i
|
||||
}
|
||||
|
||||
for i, exec := range executors {
|
||||
if i > lastNonResp {
|
||||
break // response-handler tail
|
||||
}
|
||||
switch exec.(type) {
|
||||
case TerminatingCommand, BypassCommand:
|
||||
if i != lastNonResp {
|
||||
return nil, ErrInvalidCommandSequence.
|
||||
Withf("a response handler or terminating/bypass command must be the last command")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Commands(executors), nil
|
||||
}
|
||||
|
||||
// Command is purely "bypass" or empty.
|
||||
func (cmd *Command) isBypass() bool {
|
||||
if cmd == nil {
|
||||
return true
|
||||
}
|
||||
switch cmd := cmd.exec.(type) {
|
||||
case BypassCommand:
|
||||
return true
|
||||
case Commands:
|
||||
// bypass command is always the last one
|
||||
_, ok := cmd[len(cmd)-1].(BypassCommand)
|
||||
return ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
return cmd.exec.Handle(w, r)
|
||||
}
|
||||
|
||||
func (cmd *Command) String() string {
|
||||
return cmd.raw
|
||||
}
|
||||
|
||||
386
internal/route/rules/do_blocks.go
Normal file
386
internal/route/rules/do_blocks.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
// IfBlockCommand is an inline conditional block inside a do-body.
|
||||
//
|
||||
// Syntax (within a rule do block):
|
||||
//
|
||||
// @<on-expr> { <do...> }
|
||||
//
|
||||
// Semantics:
|
||||
// - Evaluated in the same phase the parent rule runs.
|
||||
// - If <on-expr> matches, run the nested commands in-order.
|
||||
// - Otherwise do nothing.
|
||||
//
|
||||
// NOTE: Per current design decision, we keep this permissive:
|
||||
// nested blocks may use response matchers and response commands; no extra phase validation is performed.
|
||||
type IfBlockCommand struct {
|
||||
On RuleOn
|
||||
Do []CommandHandler
|
||||
}
|
||||
|
||||
func (c IfBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if c.Do == nil {
|
||||
return nil
|
||||
}
|
||||
// If On.checker is nil, treat as unconditional (should not happen if parsed).
|
||||
if c.On.checker == nil {
|
||||
return Commands(c.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
if c.On.checker.Check(w, r) {
|
||||
return Commands(c.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c IfBlockCommand) Phase() PhaseFlag {
|
||||
phase := c.On.phase
|
||||
for _, cmd := range c.Do {
|
||||
phase |= cmd.Phase()
|
||||
}
|
||||
return phase
|
||||
}
|
||||
|
||||
// IfElseBlockCommand is a chained conditional block inside a do-body.
|
||||
//
|
||||
// Syntax (within a rule do block):
|
||||
//
|
||||
// @<on-expr> { <do...> } elif <on-expr> { <do...> } ... else { <do...> }
|
||||
//
|
||||
// NOTE: `elif`/`else` must appear on the same line as the preceding closing brace (`}`),
|
||||
// e.g. `} elif ... {` and `} else {`.
|
||||
type IfElseBlockCommand struct {
|
||||
Ifs []IfBlockCommand
|
||||
Else []CommandHandler
|
||||
}
|
||||
|
||||
func (c IfElseBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
for _, br := range c.Ifs {
|
||||
// If On.checker is nil, treat as unconditional.
|
||||
if br.On.checker == nil {
|
||||
if br.Do == nil {
|
||||
return nil
|
||||
}
|
||||
return Commands(br.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
if br.On.checker.Check(w, r) {
|
||||
if br.Do == nil {
|
||||
return nil
|
||||
}
|
||||
return Commands(br.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
}
|
||||
if len(c.Else) > 0 {
|
||||
return Commands(c.Else).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c IfElseBlockCommand) Phase() PhaseFlag {
|
||||
phase := PhaseNone
|
||||
for _, br := range c.Ifs {
|
||||
phase |= br.Phase()
|
||||
}
|
||||
if len(c.Else) > 0 {
|
||||
phase |= Commands(c.Else).Phase()
|
||||
}
|
||||
return phase
|
||||
}
|
||||
|
||||
func skipSameLineSpace(src string, pos int) int {
|
||||
for pos < len(src) {
|
||||
switch src[pos] {
|
||||
case '\n':
|
||||
return pos
|
||||
case '\r':
|
||||
pos++
|
||||
continue
|
||||
case ' ', '\t':
|
||||
pos++
|
||||
continue
|
||||
default:
|
||||
return pos
|
||||
}
|
||||
}
|
||||
return pos
|
||||
}
|
||||
|
||||
func parseAtBlockChain(src string, atPos int) (CommandHandler, int, error) {
|
||||
length := len(src)
|
||||
headerStart := atPos + 1
|
||||
|
||||
parseBranch := func(onExpr string, bodyStart int, bodyEnd int) (RuleOn, []CommandHandler, error) {
|
||||
var on RuleOn
|
||||
if err := on.Parse(onExpr); err != nil {
|
||||
return RuleOn{}, nil, err
|
||||
}
|
||||
innerSrc := ""
|
||||
if bodyStart < bodyEnd {
|
||||
innerSrc = src[bodyStart:bodyEnd]
|
||||
}
|
||||
inner, err := parseDoWithBlocks(innerSrc)
|
||||
if err != nil {
|
||||
return RuleOn{}, nil, err
|
||||
}
|
||||
if len(inner) == 0 {
|
||||
return on, nil, nil
|
||||
}
|
||||
return on, inner, nil
|
||||
}
|
||||
|
||||
onExpr, bracePos, herr := parseHeaderToBrace(src, headerStart)
|
||||
if herr != nil {
|
||||
return nil, 0, herr
|
||||
}
|
||||
if bracePos >= length || src[bracePos] != '{' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after nested block header")
|
||||
}
|
||||
|
||||
// Parse first @<on-expr> { ... }
|
||||
p := bracePos
|
||||
bodyStart := p + 1
|
||||
bodyEnd, ferr := findMatchingBrace(src, &p, bodyStart)
|
||||
if ferr != nil {
|
||||
return nil, 0, ferr
|
||||
}
|
||||
firstOn, firstDo, berr := parseBranch(onExpr, bodyStart, bodyEnd)
|
||||
if berr != nil {
|
||||
return nil, 0, berr
|
||||
}
|
||||
|
||||
ifs := []IfBlockCommand{{On: firstOn, Do: firstDo}}
|
||||
var elseDo []CommandHandler
|
||||
hasChain := false
|
||||
hasElse := false
|
||||
|
||||
for {
|
||||
q := skipSameLineSpace(src, p)
|
||||
if q >= length || src[q] == '\n' {
|
||||
break
|
||||
}
|
||||
|
||||
// elif <on-expr> { ... }
|
||||
if strings.HasPrefix(src[q:], "elif") {
|
||||
next := q + len("elif")
|
||||
if next >= length {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
if src[next] == '\n' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
if !unicode.IsSpace(rune(src[next])) {
|
||||
if src[next] == '{' || src[next] == '}' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected whitespace after 'elif'")
|
||||
}
|
||||
next++
|
||||
for next < length {
|
||||
c := src[next]
|
||||
if c == '\n' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
|
||||
}
|
||||
if c == '\r' {
|
||||
next++
|
||||
continue
|
||||
}
|
||||
if !unicode.IsSpace(rune(c)) {
|
||||
break
|
||||
}
|
||||
next++
|
||||
}
|
||||
|
||||
p2 := next
|
||||
elifOnExpr, bracePos, herr := parseHeaderToBrace(src, p2)
|
||||
if herr != nil {
|
||||
return nil, 0, herr
|
||||
}
|
||||
if elifOnExpr == "" {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
if bracePos >= length || src[bracePos] != '{' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
|
||||
}
|
||||
p2 = bracePos
|
||||
elifBodyStart := p2 + 1
|
||||
elifBodyEnd, ferr := findMatchingBrace(src, &p2, elifBodyStart)
|
||||
if ferr != nil {
|
||||
return nil, 0, ferr
|
||||
}
|
||||
elifOn, elifDo, berr := parseBranch(elifOnExpr, elifBodyStart, elifBodyEnd)
|
||||
if berr != nil {
|
||||
return nil, 0, berr
|
||||
}
|
||||
ifs = append(ifs, IfBlockCommand{On: elifOn, Do: elifDo})
|
||||
hasChain = true
|
||||
p = p2
|
||||
continue
|
||||
}
|
||||
|
||||
// else { ... }
|
||||
if strings.HasPrefix(src[q:], "else") {
|
||||
if hasElse {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("multiple 'else' branches")
|
||||
}
|
||||
next := q + len("else")
|
||||
for next < length {
|
||||
c := src[next]
|
||||
if c == '\n' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
|
||||
}
|
||||
if c == '\r' {
|
||||
next++
|
||||
continue
|
||||
}
|
||||
if !unicode.IsSpace(rune(c)) {
|
||||
break
|
||||
}
|
||||
next++
|
||||
}
|
||||
if next >= length || src[next] != '{' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
|
||||
}
|
||||
|
||||
elseBodyStart := next + 1
|
||||
p2 := next
|
||||
elseBodyEnd, ferr := findMatchingBrace(src, &p2, elseBodyStart)
|
||||
if ferr != nil {
|
||||
return nil, 0, ferr
|
||||
}
|
||||
innerSrc := ""
|
||||
if elseBodyStart < elseBodyEnd {
|
||||
innerSrc = src[elseBodyStart:elseBodyEnd]
|
||||
}
|
||||
inner, ierr := parseDoWithBlocks(innerSrc)
|
||||
if ierr != nil {
|
||||
return nil, 0, ierr
|
||||
}
|
||||
if len(inner) == 0 {
|
||||
elseDo = nil
|
||||
} else {
|
||||
elseDo = inner
|
||||
}
|
||||
hasChain = true
|
||||
hasElse = true
|
||||
p = p2
|
||||
|
||||
// else must be the last branch on that line.
|
||||
for q2 := skipSameLineSpace(src, p); q2 < length && src[q2] != '\n'; q2 = skipSameLineSpace(src, q2) {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after else block")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after nested block; expected 'elif'/'else' or newline")
|
||||
}
|
||||
|
||||
if hasChain {
|
||||
return IfElseBlockCommand{Ifs: ifs, Else: elseDo}, p, nil
|
||||
}
|
||||
return IfBlockCommand{On: ifs[0].On, Do: ifs[0].Do}, p, nil
|
||||
}
|
||||
|
||||
// parseDoWithBlocks parses a do-body containing plain command lines and nested @-blocks.
|
||||
// It returns the outer command handlers and the require phase.
|
||||
//
|
||||
// A nested block is only recognized when '@' is the first non-space character on a line.
|
||||
func parseDoWithBlocks(src string) (handlers []CommandHandler, err error) {
|
||||
pos := 0
|
||||
length := len(src)
|
||||
lineStart := true
|
||||
handlers = make([]CommandHandler, 0, strings.Count(src, "\n")+1)
|
||||
|
||||
appendLineCommand := func(line string) error {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
directive, args, err := parse(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
builder, ok := commands[directive]
|
||||
if !ok {
|
||||
return ErrUnknownDirective.Subject(directive)
|
||||
}
|
||||
|
||||
phase, validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
h := builder.build(validArgs)
|
||||
handlers = append(handlers, Handler{fn: h, phase: phase, terminate: builder.terminate})
|
||||
return nil
|
||||
}
|
||||
|
||||
for pos < length {
|
||||
// Handle newlines
|
||||
switch src[pos] {
|
||||
case '\n':
|
||||
pos++
|
||||
lineStart = true
|
||||
continue
|
||||
case '\r':
|
||||
// tolerate CRLF
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if lineStart {
|
||||
// Find first non-space on the line.
|
||||
linePos := pos
|
||||
for linePos < length {
|
||||
c := rune(src[linePos])
|
||||
if c == '\n' {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(c) {
|
||||
break
|
||||
}
|
||||
linePos++
|
||||
}
|
||||
|
||||
if linePos < length && src[linePos] == '@' {
|
||||
h, next, err := parseAtBlockChain(src, linePos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handlers = append(handlers, h)
|
||||
pos = next
|
||||
lineStart = false
|
||||
continue
|
||||
}
|
||||
|
||||
// Not a nested block; parse the rest of this line as a command.
|
||||
lineEnd := pos
|
||||
for lineEnd < length && src[lineEnd] != '\n' {
|
||||
lineEnd++
|
||||
}
|
||||
if lerr := appendLineCommand(src[pos:lineEnd]); lerr != nil {
|
||||
return nil, lerr
|
||||
}
|
||||
pos = lineEnd
|
||||
lineStart = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Not at line start; advance to the next line boundary.
|
||||
for pos < length && src[pos] != '\n' {
|
||||
pos++
|
||||
}
|
||||
lineStart = true
|
||||
}
|
||||
|
||||
return handlers, nil
|
||||
}
|
||||
73
internal/route/rules/do_blocks_test.go
Normal file
73
internal/route/rules/do_blocks_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoNotFallsThrough(t *testing.T) {
|
||||
elseCalled := false
|
||||
cmd := IfElseBlockCommand{
|
||||
Ifs: []IfBlockCommand{
|
||||
{
|
||||
On: RuleOn{},
|
||||
Do: nil,
|
||||
},
|
||||
},
|
||||
Else: []CommandHandler{
|
||||
Handler{
|
||||
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
|
||||
elseCalled = true
|
||||
return nil
|
||||
},
|
||||
phase: PhaseNone,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := cmd.ServeHTTP(rm, req, nil)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, elseCalled)
|
||||
}
|
||||
|
||||
func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoNotFallsThrough(t *testing.T) {
|
||||
elseCalled := false
|
||||
cmd := IfElseBlockCommand{
|
||||
Ifs: []IfBlockCommand{
|
||||
{
|
||||
On: RuleOn{
|
||||
checker: CheckFunc(func(_ *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
return true
|
||||
}),
|
||||
},
|
||||
Do: nil,
|
||||
},
|
||||
},
|
||||
Else: []CommandHandler{
|
||||
Handler{
|
||||
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
|
||||
elseCalled = true
|
||||
return nil
|
||||
},
|
||||
phase: PhaseNone,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := cmd.ServeHTTP(rm, req, nil)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, elseCalled)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -37,7 +39,7 @@ func parseRules(data string, target *Rules) error {
|
||||
}
|
||||
|
||||
func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(200, "success response", http.Header{
|
||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success response", http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
})
|
||||
|
||||
@@ -45,10 +47,9 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-request-response
|
||||
do: |
|
||||
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
||||
`, logFile), &rules)
|
||||
default {
|
||||
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
||||
}`, logFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -59,7 +60,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "success response", w.Body.String())
|
||||
|
||||
// Read and verify log content
|
||||
@@ -70,16 +71,25 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
originalStdout := stdout
|
||||
originalStderr := stderr
|
||||
var stdoutBuf bytes.Buffer
|
||||
var stderrBuf bytes.Buffer
|
||||
stdout = noopWriteCloser{&stdoutBuf}
|
||||
stderr = noopWriteCloser{&stderrBuf}
|
||||
defer func() {
|
||||
stdout = originalStdout
|
||||
stderr = originalStderr
|
||||
}()
|
||||
|
||||
upstream := mockUpstream(http.StatusOK, "success")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: log-stdout
|
||||
do: |
|
||||
log info /dev/stdout "stdout: $req_method $status_code"
|
||||
- name: log-stderr
|
||||
do: |
|
||||
log error /dev/stderr "stderr: $req_path $status_code"
|
||||
default {
|
||||
log info /dev/stdout "stdout: $req_method $status_code"
|
||||
log error /dev/stderr "stderr: $req_path $status_code"
|
||||
}
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -90,9 +100,13 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
// Note: We can't easily capture stdout/stderr in unit tests,
|
||||
// but we can verify no errors occurred and the handler completed
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
require.Eventually(t, func() bool {
|
||||
return strings.Contains(stdoutBuf.String(), "stdout: GET 200") &&
|
||||
strings.Contains(stderrBuf.String(), "stderr: /test 200")
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
assert.Equal(t, 1, strings.Count(stdoutBuf.String(), "stdout: GET 200"))
|
||||
assert.Equal(t, 1, strings.Count(stderrBuf.String(), "stderr: /test 200"))
|
||||
}
|
||||
|
||||
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
@@ -104,26 +118,22 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-info
|
||||
do: |
|
||||
log info %s "INFO: $req_method $status_code"
|
||||
- name: log-warn
|
||||
do: |
|
||||
log warn %s "WARN: $req_path $status_code"
|
||||
- name: log-error
|
||||
do: |
|
||||
log error %s "ERROR: $req_method $req_path $status_code"
|
||||
default {
|
||||
log info %s "INFO: $req_method $status_code"
|
||||
log warn %s "WARN: $req_path $status_code"
|
||||
log error %s "ERROR: $req_method $req_path $status_code"
|
||||
}
|
||||
`, infoFile, warnFile, errorFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/resource/123", nil)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/resource/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
|
||||
// Verify each log file
|
||||
infoContent := TestFileContent(infoFile)
|
||||
@@ -148,22 +158,22 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-with-templates
|
||||
do: |
|
||||
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
||||
default {
|
||||
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
||||
}
|
||||
`, tempFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/api/resource", nil)
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/resource", nil)
|
||||
req.Header.Set("User-Agent", "test-client/1.0")
|
||||
req.Host = "example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 201, w.Code)
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
// Verify log content
|
||||
content := TestFileContent(tempFile)
|
||||
@@ -192,14 +202,12 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-success
|
||||
on: status 2xx
|
||||
do: |
|
||||
log info %q "SUCCESS: $req_method $req_path $status_code"
|
||||
- name: log-error
|
||||
on: status 4xx | status 5xx
|
||||
do: |
|
||||
log error %q "ERROR: $req_method $req_path $status_code"
|
||||
status 2xx {
|
||||
log info %q "SUCCESS: $req_method $req_path $status_code"
|
||||
}
|
||||
status 4xx | status 5xx {
|
||||
log error %q "ERROR: $req_method $req_path $status_code"
|
||||
}
|
||||
`, successFile, errorFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -244,9 +252,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-multiple
|
||||
do: |
|
||||
log info %q "$req_method $req_path $status_code"`, tempFile), &rules)
|
||||
default {
|
||||
log info %q "$req_method $req_path $status_code"
|
||||
}
|
||||
`, tempFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -256,10 +265,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"GET", "/users"},
|
||||
{"POST", "/users"},
|
||||
{"PUT", "/users/1"},
|
||||
{"DELETE", "/users/1"},
|
||||
{http.MethodGet, "/users"},
|
||||
{http.MethodPost, "/users"},
|
||||
{http.MethodPost, "/users/1"},
|
||||
{http.MethodDelete, "/users/1"},
|
||||
}
|
||||
|
||||
for _, reqInfo := range requests {
|
||||
@@ -287,8 +296,9 @@ func TestLogCommand_InvalidTemplate(t *testing.T) {
|
||||
|
||||
// Test with invalid template syntax
|
||||
err := parseRules(`
|
||||
- name: log-invalid
|
||||
do: |
|
||||
log info /dev/stdout "$invalid_var"`, &rules)
|
||||
assert.ErrorIs(t, err, ErrUnexpectedVar)
|
||||
default {
|
||||
log info /dev/stdout "$invalid_var"
|
||||
}
|
||||
`, &rules)
|
||||
require.ErrorIs(t, err, ErrUnexpectedVar)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type (
|
||||
FieldHandler struct {
|
||||
set, add, remove CommandHandler
|
||||
set, add, remove HandlerFunc
|
||||
}
|
||||
FieldModifier string
|
||||
)
|
||||
@@ -49,30 +49,30 @@ var modFields = map[string]struct {
|
||||
"value": "the header template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePreRequestKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Header[k] = []string{v}
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Header[k] = append(r.Header[k], v)
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
delete(r.Header, k)
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -84,30 +84,30 @@ var modFields = map[string]struct {
|
||||
"value": "the response header template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePostResponseKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header()[k] = []string{v}
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header()[k] = append(w.Header()[k], v)
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
delete(w.Header(), k)
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -119,36 +119,36 @@ var modFields = map[string]struct {
|
||||
"value": "the query template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePreRequestKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Set(k, v)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Add(k, v)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Del(k)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -160,16 +160,16 @@ var modFields = map[string]struct {
|
||||
"value": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePreRequestKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
cookies[i].Value = v
|
||||
@@ -179,19 +179,19 @@ var modFields = map[string]struct {
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
index := -1
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
@@ -208,7 +208,7 @@ var modFields = map[string]struct {
|
||||
return cookies
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -227,24 +227,27 @@ var modFields = map[string]struct {
|
||||
"template": "the body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return 0, nil, ErrExpectOneArg
|
||||
}
|
||||
return validateTemplate(args[0], true)
|
||||
phase = PhasePre
|
||||
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
|
||||
phase |= tmplReq
|
||||
return
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateString)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if r.Body != nil {
|
||||
r.Body.Close()
|
||||
r.Body = nil
|
||||
}
|
||||
|
||||
bufPool := httputils.GetInitResponseModifier(w).BufPool()
|
||||
bufPool := w.BufPool()
|
||||
b := bufPool.GetBuffer()
|
||||
err := tmpl.ExpandVars(w, r, b)
|
||||
_, err := tmpl.ExpandVars(w, r, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -252,7 +255,7 @@ var modFields = map[string]struct {
|
||||
bufPool.PutBuffer(b)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -272,20 +275,26 @@ var modFields = map[string]struct {
|
||||
"template": "the response body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return 0, nil, ErrExpectOneArg
|
||||
}
|
||||
return validateTemplate(args[0], true)
|
||||
phase = PhasePost
|
||||
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
|
||||
phase |= tmplReq
|
||||
return
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateString)
|
||||
return &FieldHandler{
|
||||
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rm := httputils.GetInitResponseModifier(w)
|
||||
rm.ResetBody()
|
||||
return tmpl.ExpandVars(w, r, rm)
|
||||
}),
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.ResetBody()
|
||||
_, err := tmpl.ExpandVars(w, r, w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -300,26 +309,27 @@ var modFields = map[string]struct {
|
||||
"code": "the status code",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
phase = PhasePost
|
||||
status, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return phase, nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if status < 100 || status > 599 {
|
||||
return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
||||
return phase, nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
||||
}
|
||||
return status, nil
|
||||
return phase, status, nil
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
status := args.(int)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
httputils.GetInitResponseModifier(w).WriteHeader(status)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.WriteHeader(status)
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -72,12 +71,12 @@ func TestFieldHandler_Header(t *testing.T) {
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -87,7 +86,7 @@ func TestFieldHandler_Header(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -150,12 +149,12 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
tt.setup(w)
|
||||
}
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -165,7 +164,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -237,12 +236,12 @@ func TestFieldHandler_Query(t *testing.T) {
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -252,7 +251,7 @@ func TestFieldHandler_Query(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -335,12 +334,12 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -350,7 +349,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -371,7 +370,7 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
name: "set body with template",
|
||||
template: "Hello $req_method $req_path",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "POST"
|
||||
r.Method = http.MethodPost
|
||||
r.URL.Path = "/test"
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
@@ -399,15 +398,15 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.template, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.template, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to parse template: %v", tErr)
|
||||
}
|
||||
|
||||
handler := modFields[FieldBody].builder(tmpl)
|
||||
err := handler.set.Handle(w, req)
|
||||
err := handler.set(w, req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -428,7 +427,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
name: "set response body with template",
|
||||
template: "Response: $req_method $req_path",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "GET"
|
||||
r.Method = http.MethodGet
|
||||
r.URL.Path = "/api/test"
|
||||
},
|
||||
verify: func(rm *httputils.ResponseModifier) {
|
||||
@@ -443,23 +442,20 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
// Create ResponseModifier wrapper
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.template, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.template, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to parse template: %v", tErr)
|
||||
}
|
||||
|
||||
handler := modFields[FieldResponseBody].builder(tmpl)
|
||||
err := handler.set.Handle(rm, req)
|
||||
err := handler.set(w, req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(rm)
|
||||
tt.verify(w)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -472,23 +468,23 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "set status code 200",
|
||||
status: 200,
|
||||
status: http.StatusOK,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 200, w.Code, "Expected status code 200")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set status code 404",
|
||||
status: 404,
|
||||
status: http.StatusNotFound,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 404, w.Code, "Expected status code 404")
|
||||
assert.Equal(t, http.StatusNotFound, w.Code, "Expected status code 404")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set status code 500",
|
||||
status: 500,
|
||||
status: http.StatusInternalServerError,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 500, w.Code, "Expected status code 500")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code, "Expected status code 500")
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -503,12 +499,11 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
err = cmd.ServeHTTP(rm, req)
|
||||
err = cmd.post.ServeHTTP(rm, req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
rm.FlushRelease()
|
||||
|
||||
tt.verify(w)
|
||||
})
|
||||
}
|
||||
@@ -600,7 +595,7 @@ func TestFieldValidation(t *testing.T) {
|
||||
field, exists := modFields[tt.field]
|
||||
assert.True(t, exists, "Field %s does not exist", tt.field)
|
||||
|
||||
_, err := field.validate(tt.args)
|
||||
_, _, err := field.validate(tt.args)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err, "Expected error but got none")
|
||||
} else {
|
||||
@@ -610,25 +605,6 @@ func TestFieldValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllFields(t *testing.T) {
|
||||
expectedFields := []string{
|
||||
FieldHeader,
|
||||
FieldResponseHeader,
|
||||
FieldQuery,
|
||||
FieldCookie,
|
||||
FieldBody,
|
||||
FieldResponseBody,
|
||||
FieldStatusCode,
|
||||
}
|
||||
|
||||
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
|
||||
|
||||
for _, expected := range expectedFields {
|
||||
found := slices.Contains(AllFields, expected)
|
||||
assert.True(t, found, "Expected field %s not found in AllFields", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModFields(t *testing.T) {
|
||||
for fieldName, field := range modFields {
|
||||
// Test that each field has required components
|
||||
|
||||
@@ -14,8 +14,9 @@ var (
|
||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
||||
ErrMultipleDefaultRules = gperr.New("multiple default rules")
|
||||
|
||||
ErrMultipleDefaultRules = gperr.New("multiple default rules")
|
||||
ErrDeadRule = gperr.New("dead rule")
|
||||
|
||||
// vars errors
|
||||
ErrNoArgProvided = gperr.New("no argument provided")
|
||||
@@ -31,5 +32,5 @@ var (
|
||||
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
|
||||
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
|
||||
|
||||
errTerminated = gperr.New("terminated")
|
||||
ErrInvalidBlockSyntax = gperr.New("invalid block syntax") // TODO: struct this error
|
||||
)
|
||||
|
||||
@@ -131,12 +131,13 @@ Error generates help string as error, e.g.
|
||||
from: the path to rewrite, must start with /
|
||||
to: the path to rewrite to, must start with /
|
||||
*/
|
||||
func (h *Help) Error() error {
|
||||
var lines gperr.MultilineError
|
||||
func (h *Help) Error() gperr.Error {
|
||||
help := gperr.New(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
for _, line := range h.description {
|
||||
help = help.Withf("%s", line)
|
||||
}
|
||||
|
||||
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
lines.AddStrings(h.description...)
|
||||
lines.Adds(" args:")
|
||||
args := gperr.New("args")
|
||||
|
||||
argKeys := make([]string, 0, len(h.args))
|
||||
longestArg := 0
|
||||
@@ -151,7 +152,9 @@ func (h *Help) Error() error {
|
||||
slices.Sort(argKeys)
|
||||
for _, arg := range argKeys {
|
||||
desc := h.args[arg]
|
||||
lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc)
|
||||
paddedArg := fmt.Sprintf("%-"+strconv.Itoa(longestArg)+"s", arg)
|
||||
args = args.Withf("%s%s", ansi.WithANSI(paddedArg, ansi.HighlightCyan)+": ", desc)
|
||||
}
|
||||
return &lines
|
||||
|
||||
return help.With(args)
|
||||
}
|
||||
|
||||
1357
internal/route/rules/http_flow_block_test.go
Normal file
1357
internal/route/rules/http_flow_block_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -23,8 +23,9 @@ import (
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
func mockUpstream(body string) http.HandlerFunc {
|
||||
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
@@ -47,7 +48,7 @@ func parseRules(data string, target *Rules) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
func TestHTTPFlow_BasicPreRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -74,8 +75,8 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
upstream := mockUpstream("upstream response")
|
||||
func TestHTTPFlow_BypassRuleYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -99,8 +100,8 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
upstream := mockUpstream("should not be called")
|
||||
func TestHTTPFlow_TerminatingCommandYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -120,13 +121,13 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, "Forbidden\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
upstream := mockUpstream("should not be called")
|
||||
func TestHTTPFlow_RedirectFlowYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -143,11 +144,11 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, 307, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
func TestHTTPFlow_RewriteFlowYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("path: " + r.URL.Path))
|
||||
@@ -172,7 +173,7 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
assert.Equal(t, "path: /v1/users", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
func TestHTTPFlow_MultiplePreRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
||||
@@ -201,7 +202,7 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
func TestHTTPFlow_PostResponseRuleYAML(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
|
||||
"X-Upstream": []string{"upstream-value"},
|
||||
})
|
||||
@@ -229,11 +230,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
|
||||
// Check log file
|
||||
content := TestFileContent(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "GET 200\n", string(content))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseRuleWithStatusConditionYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/success" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -246,14 +246,17 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
|
||||
// Create a temporary file for logging
|
||||
tempFile := TestRandomFileName()
|
||||
errorLog := TestRandomFileName()
|
||||
infoLog := TestRandomFileName()
|
||||
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-errors
|
||||
on: status 4xx
|
||||
do: log error %s "$req_url returned $status_code"
|
||||
`, tempFile), &rules)
|
||||
status 4xx {
|
||||
log error %s "$req_url returned $status_code"
|
||||
}
|
||||
status 200 {
|
||||
log info %s "$req_url returned $status_code"
|
||||
}
|
||||
`, errorLog, infoLog), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -273,14 +276,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
|
||||
// Check log file
|
||||
content := TestFileContent(tempFile)
|
||||
require.NoError(t, err)
|
||||
content := TestFileContent(errorLog)
|
||||
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
|
||||
require.Len(t, lines, 1, "only 4xx requests should be logged")
|
||||
assert.Equal(t, "/notfound returned 404", lines[0])
|
||||
|
||||
infoContent := TestFileContent(infoLog)
|
||||
lines = strings.Split(strings.TrimSpace(string(infoContent)), "\n")
|
||||
require.Len(t, lines, 1, "only 200 requests should be logged")
|
||||
assert.Equal(t, "/success returned 200", lines[0])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
func TestHTTPFlow_ConditionalRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
||||
@@ -320,22 +327,21 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
func TestHTTPFlow_ComplexFlowWithPreAndPostRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate different responses based on path
|
||||
if r.URL.Path == "/protected" {
|
||||
if r.Header.Get("X-Auth") != "valid" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("unauthorized"))
|
||||
fmt.Fprint(w, "unauthorized")
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("X-Response-Time", "100ms")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
fmt.Fprint(w, "success")
|
||||
})
|
||||
|
||||
// Create temporary files for logging
|
||||
logFile := TestRandomFileName()
|
||||
errorLogFile := TestRandomFileName()
|
||||
|
||||
@@ -402,8 +408,8 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
upstream := mockUpstream("upstream response")
|
||||
func TestHTTPFlow_DefaultRuleYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -426,21 +432,57 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule + default rule
|
||||
// Test special rule (default should not run)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
func TestHTTPFlow_DefaultRuleWithOnDefaultYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: default-on-rule
|
||||
on: default
|
||||
do: set resp_header X-Default-Applied true
|
||||
- name: special-rule
|
||||
on: path /special
|
||||
do: set resp_header X-Special-Handled true
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test default rule on regular request
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/regular", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule on matching request (default should not run)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Empty(t, w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_HeaderManipulationYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Echo back a header
|
||||
headerValue := r.Header.Get("X-Test-Header")
|
||||
w.Header().Set("X-Echoed-Header", headerValue)
|
||||
w.Header().Set("X-Secret", "sensitive-data")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("header echoed"))
|
||||
})
|
||||
@@ -460,7 +502,6 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Secret", "secret-value")
|
||||
req.Header.Set("X-Test-Header", "original-value")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -469,11 +510,10 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
||||
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
||||
// Ensure the secret header was removed and not passed to upstream
|
||||
// (we can't directly test this, but the upstream shouldn't see it)
|
||||
assert.Empty(t, w.Header().Get("X-Secret"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
func TestHTTPFlow_QueryParameterHandlingYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -500,13 +540,15 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
assert.Equal(t, "query: added-value", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
func TestHTTPFlow_ServeCommandYAML(t *testing.T) {
|
||||
// Create a temporary directory with test files
|
||||
tempDir := t.TempDir()
|
||||
tempDir, err := os.MkdirTemp("", "test-serve-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create test files directly in the temp directory
|
||||
testFile := filepath.Join(tempDir, "index.html")
|
||||
err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var rules Rules
|
||||
@@ -517,7 +559,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
`, tempDir), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
|
||||
|
||||
// Test serving a file - serve command serves files relative to the root directory
|
||||
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
||||
@@ -546,7 +588,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
func TestHTTPFlow_ProxyCommandYAML(t *testing.T) {
|
||||
// Create a mock upstream server
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Header", "upstream-value")
|
||||
@@ -563,7 +605,7 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
`, upstreamServer.URL), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -576,11 +618,28 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_NotifyCommand(t *testing.T) {
|
||||
// TODO:
|
||||
func TestHTTPFlow_NotifyCommandYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "ok")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: notify-rule
|
||||
on: path /notify
|
||||
do: notify info test-provider "title $req_method" "body $req_url $status_code"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/notify", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "ok", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
func TestHTTPFlow_FormConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("form processed"))
|
||||
@@ -620,7 +679,7 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
func TestHTTPFlow_RemoteConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("remote processed"))
|
||||
@@ -654,11 +713,11 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code)
|
||||
assert.Equal(t, 403, w2.Code)
|
||||
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
func TestHTTPFlow_BasicAuthConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("auth processed"))
|
||||
@@ -702,7 +761,7 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
func TestHTTPFlow_RouteConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("route processed"))
|
||||
@@ -742,10 +801,10 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseStatusConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
w.Write([]byte("method not allowed"))
|
||||
fmt.Fprint(w, "method not allowed")
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
@@ -767,11 +826,11 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseHeaderConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Response-Header", "response header")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("processed"))
|
||||
fmt.Fprint(w, "processed")
|
||||
})
|
||||
|
||||
t.Run("any_value", func(t *testing.T) {
|
||||
@@ -831,7 +890,65 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMatchersYAML(t *testing.T) {
|
||||
upstreamCalled := false
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: path /
|
||||
do: error 403 blocked
|
||||
- on: path /
|
||||
do: set resp_header X-Late should-not-run
|
||||
- on: status 4xx
|
||||
do: set resp_header X-Post true
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.False(t, upstreamCalled)
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, "blocked\n", w.Body.String())
|
||||
assert.Equal(t, "should-not-run", w.Header().Get("X-Late"))
|
||||
assert.Equal(t, "true", w.Header().Get("X-Post"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostRuleTermination_StopsRemainingCommandsInRuleYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: status 200
|
||||
do: |
|
||||
error 500 failed
|
||||
set resp_header X-After should-not-run
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, "failed\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-After"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinationsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("complex processed"))
|
||||
@@ -887,12 +1004,12 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
assert.Equal(t, 200, w3.Code)
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
|
||||
assert.Empty(t, w3.Header()["X-API-Version"])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseModifierYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("original response"))
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type noopWriteCloser struct {
|
||||
@@ -30,7 +31,7 @@ var (
|
||||
testFilesLock sync.Mutex
|
||||
)
|
||||
|
||||
func openFile(path string) (io.WriteCloser, error) {
|
||||
func openFile(path string) (io.WriteCloser, gperr.Error) {
|
||||
switch path {
|
||||
case "/dev/stdout":
|
||||
return stdout, nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,8 @@ type (
|
||||
MatcherType string
|
||||
)
|
||||
|
||||
var matcherCache = xsync.NewMap[string, Matcher]() // map[string]Matcher
|
||||
|
||||
const (
|
||||
MatcherTypeString MatcherType = "string"
|
||||
MatcherTypeGlob MatcherType = "glob"
|
||||
@@ -59,7 +62,12 @@ func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Erro
|
||||
}
|
||||
|
||||
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
if cached, ok := matcherCache.Load(expr); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
negate := false
|
||||
origExpr := expr
|
||||
if strings.HasPrefix(expr, "!") {
|
||||
negate = true
|
||||
expr = expr[1:]
|
||||
@@ -72,11 +80,23 @@ func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
|
||||
switch t {
|
||||
case MatcherTypeString:
|
||||
return StringMatcher(expr, negate)
|
||||
m, err := StringMatcher(expr, negate)
|
||||
if err == nil {
|
||||
matcherCache.Store(origExpr, m)
|
||||
}
|
||||
return m, err
|
||||
case MatcherTypeGlob:
|
||||
return GlobMatcher(expr, negate)
|
||||
m, err := GlobMatcher(expr, negate)
|
||||
if err == nil {
|
||||
matcherCache.Store(origExpr, m)
|
||||
}
|
||||
return m, err
|
||||
case MatcherTypeRegex:
|
||||
return RegexMatcher(expr, negate)
|
||||
m, err := RegexMatcher(expr, negate)
|
||||
if err == nil {
|
||||
matcherCache.Store(origExpr, m)
|
||||
}
|
||||
return m, err
|
||||
}
|
||||
// won't reach here
|
||||
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
||||
|
||||
@@ -12,19 +12,19 @@ import (
|
||||
)
|
||||
|
||||
type RuleOn struct {
|
||||
raw string
|
||||
checker Checker
|
||||
isResponseChecker bool
|
||||
}
|
||||
|
||||
func (on *RuleOn) IsResponseChecker() bool {
|
||||
return on.isResponseChecker
|
||||
raw string
|
||||
checker Checker
|
||||
phase PhaseFlag
|
||||
}
|
||||
|
||||
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
return on.checker.Check(w, r)
|
||||
if on.checker == nil {
|
||||
return true
|
||||
}
|
||||
return on.checker.Check(httputils.GetInitResponseModifier(w), r)
|
||||
}
|
||||
|
||||
// on request
|
||||
const (
|
||||
OnDefault = "default"
|
||||
OnHeader = "header"
|
||||
@@ -39,35 +39,36 @@ const (
|
||||
OnRemote = "remote"
|
||||
OnBasicAuth = "basic_auth"
|
||||
OnRoute = "route"
|
||||
)
|
||||
|
||||
// on response
|
||||
|
||||
// on response
|
||||
const (
|
||||
OnResponseHeader = "resp_header"
|
||||
OnStatus = "status"
|
||||
)
|
||||
|
||||
var checkers = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
isResponseChecker bool
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
}{
|
||||
OnDefault: {
|
||||
help: Help{
|
||||
command: OnDefault,
|
||||
description: makeLines(
|
||||
"The default rule is matched when no other rules are matched.",
|
||||
"Select the default (fallback) rule.",
|
||||
),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
return phase, nil, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool { return true }
|
||||
},
|
||||
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
|
||||
},
|
||||
OnHeader: {
|
||||
help: Help{
|
||||
@@ -83,21 +84,23 @@ var checkers = map[string]struct {
|
||||
"[value]": "the header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(r.Header[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(r.Header[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnResponseHeader: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnResponseHeader,
|
||||
description: makeLines(
|
||||
@@ -111,16 +114,20 @@ var checkers = map[string]struct {
|
||||
"[value]": "the response header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePost
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(w.Header()[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(w.Header()[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -138,16 +145,19 @@ var checkers = map[string]struct {
|
||||
"[value]": "the query value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(w.SharedData().GetQueries(r)[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -165,12 +175,15 @@ var checkers = map[string]struct {
|
||||
"[value]": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
cookies := w.SharedData().GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
return true
|
||||
@@ -179,8 +192,8 @@ var checkers = map[string]struct {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
cookies := w.SharedData().GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
if matcher(cookie.Value) {
|
||||
@@ -192,6 +205,7 @@ var checkers = map[string]struct {
|
||||
}
|
||||
},
|
||||
},
|
||||
//nolint:dupl
|
||||
OnForm: {
|
||||
help: Help{
|
||||
command: OnForm,
|
||||
@@ -206,15 +220,18 @@ var checkers = map[string]struct {
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.FormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.FormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -233,15 +250,18 @@ var checkers = map[string]struct {
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.PostFormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.PostFormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -250,32 +270,46 @@ var checkers = map[string]struct {
|
||||
help: Help{
|
||||
command: OnProto,
|
||||
args: map[string]string{
|
||||
"proto": "the http protocol (http, https, h3)",
|
||||
"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
proto := args[0]
|
||||
if proto != "http" && proto != "https" && proto != "h3" {
|
||||
return nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||
switch proto {
|
||||
case "http", "https", "h1", "h2", "h2c", "h3":
|
||||
return phase, proto, nil
|
||||
default:
|
||||
return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||
}
|
||||
return proto, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
proto := args.(string)
|
||||
switch proto {
|
||||
case "http":
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil
|
||||
}
|
||||
case "https":
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil
|
||||
}
|
||||
case "h1":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil && r.ProtoMajor == 1
|
||||
}
|
||||
case "h2":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil && r.ProtoMajor == 2
|
||||
}
|
||||
case "h2c":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil && r.ProtoMajor == 2
|
||||
}
|
||||
default: // h3
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil && r.ProtoMajor == 3
|
||||
}
|
||||
}
|
||||
@@ -288,10 +322,13 @@ var checkers = map[string]struct {
|
||||
"method": "the http method",
|
||||
},
|
||||
},
|
||||
validate: validateMethod,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateMethod(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
method := args.(string)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.Method == method
|
||||
}
|
||||
},
|
||||
@@ -310,10 +347,13 @@ var checkers = map[string]struct {
|
||||
"host": "the host name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateSingleMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.Host)
|
||||
}
|
||||
},
|
||||
@@ -332,10 +372,13 @@ var checkers = map[string]struct {
|
||||
"path": "the request path",
|
||||
},
|
||||
},
|
||||
validate: validateURLPathMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateURLPathMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
reqPath := r.URL.Path
|
||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||
reqPath = "/" + reqPath
|
||||
@@ -351,22 +394,25 @@ var checkers = map[string]struct {
|
||||
"ip|cidr": "the remote ip or cidr",
|
||||
},
|
||||
},
|
||||
validate: validateCIDR,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateCIDR(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
ipnet := args.(*net.IPNet)
|
||||
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||
wantIP := ipnet.IP
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
ip := w.SharedData().GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(wantIP)
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
ip := w.SharedData().GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
@@ -382,11 +428,14 @@ var checkers = map[string]struct {
|
||||
"password": "the password encrypted with bcrypt",
|
||||
},
|
||||
},
|
||||
validate: validateUserBCryptPassword,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateUserBCryptPassword(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
cred := args.(*HashedCrendentials)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return cred.Match(w.SharedData().GetBasicAuth(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -403,16 +452,18 @@ var checkers = map[string]struct {
|
||||
"route": "the route name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateSingleMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(_ http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(routes.TryGetUpstreamName(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
OnStatus: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnStatus,
|
||||
description: makeLines(
|
||||
@@ -429,16 +480,20 @@ var checkers = map[string]struct {
|
||||
"status": "the status code range",
|
||||
},
|
||||
},
|
||||
validate: validateStatusRange,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePost
|
||||
parsedArgs, err = validateStatusRange(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
beg, end := args.(*IntTuple).Unpack()
|
||||
if beg == end {
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
return httputils.GetInitResponseModifier(w).StatusCode() == beg
|
||||
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
return w.StatusCode() == beg
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
|
||||
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
statusCode := w.StatusCode()
|
||||
return statusCode >= beg && statusCode <= end
|
||||
}
|
||||
},
|
||||
@@ -515,85 +570,119 @@ func splitPipe(s string) []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
result := make([]string, 0, 2)
|
||||
quote := byte(0)
|
||||
brackets := 0
|
||||
start := 0
|
||||
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '\\':
|
||||
escaped = true
|
||||
current.WriteRune(r)
|
||||
// Skip escaped character.
|
||||
if i+1 < len(s) {
|
||||
i++
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
if quote == 0 && brackets == 0 {
|
||||
quote = r
|
||||
} else if r == quote {
|
||||
quote = s[i]
|
||||
} else if s[i] == quote {
|
||||
quote = 0
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '(':
|
||||
brackets++
|
||||
current.WriteRune(r)
|
||||
case ')':
|
||||
if brackets > 0 {
|
||||
brackets--
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '|':
|
||||
if quote == 0 && brackets == 0 {
|
||||
// Found a pipe outside quotes/brackets, split here
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
result = append(result, strings.TrimSpace(s[start:i]))
|
||||
start = i + 1
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last part
|
||||
if current.Len() > 0 {
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
// drop trailing empty part.
|
||||
if start < len(s) {
|
||||
result = append(result, strings.TrimSpace(s[start:]))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func forEachAndPart(s string, fn func(part string)) {
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i < len(s) && andSeps[s[i]] == 0 {
|
||||
continue
|
||||
}
|
||||
part := strings.TrimSpace(s[start:i])
|
||||
if part != "" {
|
||||
fn(part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
func forEachPipePart(s string, fn func(part string)) {
|
||||
quote := byte(0)
|
||||
brackets := 0
|
||||
start := 0
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '\\':
|
||||
if i+1 < len(s) {
|
||||
i++
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
if quote == 0 && brackets == 0 {
|
||||
quote = s[i]
|
||||
} else if s[i] == quote {
|
||||
quote = 0
|
||||
}
|
||||
case '(':
|
||||
brackets++
|
||||
case ')':
|
||||
if brackets > 0 {
|
||||
brackets--
|
||||
}
|
||||
case '|':
|
||||
if quote == 0 && brackets == 0 {
|
||||
fn(strings.TrimSpace(s[start:i]))
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
if start < len(s) {
|
||||
fn(strings.TrimSpace(s[start:]))
|
||||
}
|
||||
}
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (on *RuleOn) Parse(v string) error {
|
||||
on.raw = v
|
||||
|
||||
rules := splitAnd(v)
|
||||
checkAnd := make(CheckMatchAll, 0, len(rules))
|
||||
ruleCount := 0
|
||||
forEachAndPart(v, func(_ string) {
|
||||
ruleCount++
|
||||
})
|
||||
checkAnd := make(CheckMatchAll, 0, ruleCount)
|
||||
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
isResponseChecker := false
|
||||
for i, rule := range rules {
|
||||
if rule == "" {
|
||||
continue
|
||||
}
|
||||
parsed, isResp, err := parseOn(rule)
|
||||
i := 0
|
||||
forEachAndPart(v, func(rule string) {
|
||||
i++
|
||||
parsed, phase, err := parseOn(rule)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "line %d", i+1)
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
errs.AddSubjectf(err, "line %d", i)
|
||||
return
|
||||
}
|
||||
on.phase |= phase
|
||||
checkAnd = append(checkAnd, parsed)
|
||||
}
|
||||
})
|
||||
|
||||
on.checker = checkAnd
|
||||
on.isResponseChecker = isResponseChecker
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
@@ -605,33 +694,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||
return []byte(on.String()), nil
|
||||
}
|
||||
|
||||
func parseOn(line string) (Checker, bool, error) {
|
||||
ors := splitPipe(line)
|
||||
|
||||
if len(ors) > 1 {
|
||||
func parseOn(line string) (Checker, PhaseFlag, error) {
|
||||
orCount := 0
|
||||
forEachPipePart(line, func(_ string) {
|
||||
orCount++
|
||||
})
|
||||
if orCount > 1 {
|
||||
var phase PhaseFlag
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
checkOr := make(CheckMatchSingle, len(ors))
|
||||
isResponseChecker := false
|
||||
for i, or := range ors {
|
||||
curCheckers, isResp, err := parseOn(or)
|
||||
checkOr := make(CheckMatchSingle, orCount)
|
||||
i := 0
|
||||
forEachPipePart(line, func(or string) {
|
||||
i++
|
||||
checkFunc, req, err := parseOnAtom(or)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
errs.AddSubjectf(err, "or[%d]", i)
|
||||
return
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
}
|
||||
checkOr[i] = curCheckers.(CheckFunc)
|
||||
}
|
||||
checkOr[i-1] = checkFunc
|
||||
phase |= req
|
||||
})
|
||||
if err := errs.Error(); err != nil {
|
||||
return nil, false, err
|
||||
return nil, phase, err
|
||||
}
|
||||
return checkOr, isResponseChecker, nil
|
||||
return checkOr, phase, nil
|
||||
}
|
||||
|
||||
return parseOnAtom(line)
|
||||
}
|
||||
|
||||
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
|
||||
var phase PhaseFlag
|
||||
subject, args, err := parse(line)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, phase, err
|
||||
}
|
||||
|
||||
negate := false
|
||||
@@ -642,20 +738,21 @@ func parseOn(line string) (Checker, bool, error) {
|
||||
|
||||
checker, ok := checkers[subject]
|
||||
if !ok {
|
||||
return nil, false, ErrInvalidOnTarget.Subject(subject)
|
||||
return nil, phase, ErrInvalidOnTarget.Subject(subject)
|
||||
}
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
req, validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, false, gperr.Wrap(err).With(checker.help.Error())
|
||||
return nil, phase, gperr.Wrap(err).With(checker.help.Error())
|
||||
}
|
||||
phase |= req
|
||||
|
||||
checkFunc := checker.builder(validArgs)
|
||||
if negate {
|
||||
origCheckFunc := checkFunc
|
||||
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
|
||||
checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return !origCheckFunc(w, r)
|
||||
}
|
||||
}
|
||||
return checkFunc, checker.isResponseChecker, nil
|
||||
return checkFunc, phase, nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
. "github.com/yusing/godoxy/internal/route/rules"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
@@ -386,7 +387,7 @@ func TestOnCorrectness(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
var on RuleOn
|
||||
err := on.Parse(tt.checker)
|
||||
expect.NoError(t, err)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/yusing/goutils/env"
|
||||
@@ -25,6 +24,76 @@ var quoteChars = [256]bool{
|
||||
'`': true,
|
||||
}
|
||||
|
||||
func parseSimple(v string) (subject string, args []string, err error, ok bool) {
|
||||
brackets := 0
|
||||
for i := range len(v) {
|
||||
switch v[i] {
|
||||
case '\\', '$', '"', '\'', '`', '\t', '\r', '\n':
|
||||
return "", nil, nil, false
|
||||
case '(':
|
||||
brackets++
|
||||
case ')':
|
||||
if brackets == 0 {
|
||||
return "", nil, ErrUnterminatedBrackets, true
|
||||
}
|
||||
brackets--
|
||||
}
|
||||
}
|
||||
if brackets != 0 {
|
||||
return "", nil, ErrUnterminatedBrackets, true
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(v) && v[i] == ' ' {
|
||||
i++
|
||||
}
|
||||
if i >= len(v) {
|
||||
return "", nil, nil, true
|
||||
}
|
||||
|
||||
start := i
|
||||
for i < len(v) && v[i] != ' ' {
|
||||
i++
|
||||
}
|
||||
subject = v[start:i]
|
||||
|
||||
if i >= len(v) {
|
||||
return subject, nil, nil, true
|
||||
}
|
||||
|
||||
argCount := 0
|
||||
for j := i; j < len(v); {
|
||||
for j < len(v) && v[j] == ' ' {
|
||||
j++
|
||||
}
|
||||
if j >= len(v) {
|
||||
break
|
||||
}
|
||||
argCount++
|
||||
for j < len(v) && v[j] != ' ' {
|
||||
j++
|
||||
}
|
||||
}
|
||||
if argCount == 0 {
|
||||
return subject, nil, nil, true
|
||||
}
|
||||
args = make([]string, 0, argCount)
|
||||
for i < len(v) {
|
||||
for i < len(v) && v[i] == ' ' {
|
||||
i++
|
||||
}
|
||||
if i >= len(v) {
|
||||
break
|
||||
}
|
||||
start = i
|
||||
for i < len(v) && v[i] != ' ' {
|
||||
i++
|
||||
}
|
||||
args = append(args, v[start:i])
|
||||
}
|
||||
return subject, args, nil, true
|
||||
}
|
||||
|
||||
// parse expression to subject and args
|
||||
// with support for quotes, escaped chars, and env substitution, e.g.
|
||||
//
|
||||
@@ -32,14 +101,21 @@ var quoteChars = [256]bool{
|
||||
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
||||
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
||||
func parse(v string) (subject string, args []string, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
|
||||
if subject, args, err, ok := parseSimple(v); ok {
|
||||
return subject, args, err
|
||||
}
|
||||
|
||||
buf := getStringBuffer(len(v))
|
||||
args = make([]string, 0, 4)
|
||||
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
brackets := 0
|
||||
|
||||
var envVar bytes.Buffer
|
||||
var missingEnvVars []string
|
||||
var (
|
||||
envVar strings.Builder
|
||||
missingEnvVars []string
|
||||
)
|
||||
inEnvVar := false
|
||||
expectingBrace := false
|
||||
|
||||
@@ -71,7 +147,8 @@ func parse(v string) (subject string, args []string, err error) {
|
||||
if ch, ok := escapedChars[r]; ok {
|
||||
buf.WriteRune(ch)
|
||||
} else {
|
||||
fmt.Fprintf(buf, `\%c`, r)
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
escaped = false
|
||||
continue
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,7 @@ func TestParser(t *testing.T) {
|
||||
input string
|
||||
subject string
|
||||
args []string
|
||||
wantErr gperr.Error
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
@@ -90,6 +92,10 @@ func TestParser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subject, args, err := parse(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
expect.ErrorIs(t, tt.wantErr, err)
|
||||
return
|
||||
}
|
||||
// t.Log(subject, args, err)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, tt.subject)
|
||||
|
||||
29
internal/route/rules/phase.go
Normal file
29
internal/route/rules/phase.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package rules
|
||||
|
||||
import "strings"
|
||||
|
||||
type PhaseFlag uint8
|
||||
|
||||
const (
|
||||
PhaseNone PhaseFlag = 0
|
||||
PhasePre PhaseFlag = 1 << (iota - 1)
|
||||
PhasePost
|
||||
)
|
||||
|
||||
func (phase PhaseFlag) IsPostRule() bool {
|
||||
return phase&PhasePost != 0
|
||||
}
|
||||
|
||||
func (phase PhaseFlag) String() string {
|
||||
if phase == PhaseNone {
|
||||
return "none"
|
||||
}
|
||||
var flags []string
|
||||
if phase&PhasePre != 0 {
|
||||
flags = append(flags, "PhasePre")
|
||||
}
|
||||
if phase&PhasePost != 0 {
|
||||
flags = append(flags, "PhasePost")
|
||||
}
|
||||
return strings.Join(flags, ",")
|
||||
}
|
||||
@@ -4,9 +4,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
@@ -15,37 +22,36 @@ import (
|
||||
|
||||
type (
|
||||
/*
|
||||
Example:
|
||||
Rules is a list of rules.
|
||||
|
||||
proxy.app1.rules: |
|
||||
- name: default
|
||||
do: |
|
||||
rewrite / /index.html
|
||||
serve /var/www/goaccess
|
||||
- name: ws
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
Example:
|
||||
|
||||
proxy.app2.rules: |
|
||||
- name: default
|
||||
do: bypass
|
||||
- name: block POST and PUT
|
||||
on: method POST | method PUT
|
||||
do: error 403 Forbidden
|
||||
proxy.app1.rules: |
|
||||
- name: default
|
||||
do: |
|
||||
rewrite / /index.html
|
||||
serve /var/www/goaccess
|
||||
- name: ws
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
|
||||
proxy.app2.rules: |
|
||||
- name: default
|
||||
do: bypass
|
||||
- name: block POST and PUT
|
||||
on: method POST | method PUT
|
||||
do: error 403 Forbidden
|
||||
*/
|
||||
//nolint:recvcheck
|
||||
Rules []Rule
|
||||
/*
|
||||
Rule is a rule for a reverse proxy.
|
||||
It do `Do` when `On` matches.
|
||||
|
||||
A rule can have multiple lines of on.
|
||||
|
||||
All lines of on must match,
|
||||
but each line can have multiple checks that
|
||||
one match means this line is matched.
|
||||
*/
|
||||
// Rule represents a reverse proxy rule.
|
||||
// The `Do` field is executed when `On` matches.
|
||||
//
|
||||
// - A rule may have multiple lines in the `On` section.
|
||||
// - All `On` lines must match for the rule to trigger.
|
||||
// - Each line can have several checks—one match per line is enough for that line.
|
||||
Rule struct {
|
||||
Name string `json:"name"`
|
||||
On RuleOn `json:"on" swaggertype:"string"`
|
||||
@@ -53,210 +59,351 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (rule *Rule) IsResponseRule() bool {
|
||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
||||
func isDefaultRule(rule Rule) bool {
|
||||
return rule.Name == "default" || rule.On.raw == OnDefault
|
||||
}
|
||||
|
||||
func (rules Rules) Validate() error {
|
||||
func (rules Rules) Validate() gperr.Error {
|
||||
var defaultRulesFound []int
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
if isDefaultRule(rule) {
|
||||
defaultRulesFound = append(defaultRulesFound, i)
|
||||
}
|
||||
if rules[i].Name == "" {
|
||||
// set name to index if name is empty
|
||||
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
}
|
||||
if len(defaultRulesFound) > 1 {
|
||||
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
||||
}
|
||||
for i := range rules {
|
||||
r1 := rules[i]
|
||||
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
||||
continue
|
||||
}
|
||||
sig1, ok := matcherSignature(r1.On.raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(rules); j++ {
|
||||
r2 := rules[j]
|
||||
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
||||
continue
|
||||
}
|
||||
sig2, ok := matcherSignature(r2.On.raw)
|
||||
if !ok || sig1 != sig2 {
|
||||
continue
|
||||
}
|
||||
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rule Rule) doesTerminateInPre() bool {
|
||||
for _, cmd := range rule.Do.pre {
|
||||
handler, ok := cmd.(Handler)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if handler.Terminates() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matcherSignature(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
andParts := splitAnd(raw)
|
||||
if len(andParts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
canonAnd := make([]string, 0, len(andParts))
|
||||
for _, andPart := range andParts {
|
||||
orParts := splitPipe(andPart)
|
||||
if len(orParts) == 0 {
|
||||
continue
|
||||
}
|
||||
canonOr := make([]string, 0, len(orParts))
|
||||
for _, atom := range orParts {
|
||||
subject, args, err := parse(strings.TrimSpace(atom))
|
||||
if err != nil || subject == "" {
|
||||
return "", false
|
||||
}
|
||||
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
||||
}
|
||||
slices.Sort(canonOr)
|
||||
canonOr = slices.Compact(canonOr)
|
||||
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
||||
}
|
||||
|
||||
slices.Sort(canonAnd)
|
||||
canonAnd = slices.Compact(canonAnd)
|
||||
if len(canonAnd) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(canonAnd, "&"), true
|
||||
}
|
||||
|
||||
// Parse parses a rule configuration string.
|
||||
// It first tries the block syntax (if the string contains a top-level '{'),
|
||||
// then falls back to YAML syntax.
|
||||
func (rules *Rules) Parse(config string) error {
|
||||
config = strings.TrimSpace(config)
|
||||
if config == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prefer block syntax if it looks like block syntax.
|
||||
if hasTopLevelLBrace(config) {
|
||||
blockRules, err := parseBlockRules(config)
|
||||
if err == nil {
|
||||
*rules = blockRules
|
||||
return nil
|
||||
}
|
||||
// Fall through to YAML (backward compatibility).
|
||||
}
|
||||
|
||||
// YAML fallback
|
||||
var anySlice []any
|
||||
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
||||
if yamlErr == nil {
|
||||
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
||||
}
|
||||
|
||||
// If YAML fails and we didn't try block syntax yet, try it now.
|
||||
blockRules, err := parseBlockRules(config)
|
||||
if err == nil {
|
||||
*rules = blockRules
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments.
|
||||
// Used to decide whether to prioritize the block syntax.
|
||||
func hasTopLevelLBrace(s string) bool {
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '*' && i+1 < len(s) && s[i+1] == '/' {
|
||||
inBlock = false
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
if quote != '`' && c == '\\' && i+1 < len(s) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if rune(c) == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch c {
|
||||
case '\'', '"', '`':
|
||||
quote = rune(c)
|
||||
continue
|
||||
case '{':
|
||||
return true
|
||||
case '#':
|
||||
inLine = true
|
||||
continue
|
||||
case '/':
|
||||
if i+1 < len(s) && s[i+1] == '/' {
|
||||
inLine = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if i+1 < len(s) && s[i+1] == '*' {
|
||||
inBlock = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
default:
|
||||
if unicode.IsSpace(rune(c)) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
||||
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
||||
if len(rules) == 0 {
|
||||
return up
|
||||
}
|
||||
|
||||
defaultRule := Rule{
|
||||
Name: "default",
|
||||
Do: Command{
|
||||
raw: "pass",
|
||||
exec: BypassCommand{},
|
||||
},
|
||||
}
|
||||
var defaultRule *Rule
|
||||
|
||||
var nonDefaultRules Rules
|
||||
hasDefaultRule := false
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
defaultRule = rule
|
||||
hasDefaultRule = true
|
||||
for _, rule := range rules {
|
||||
if isDefaultRule(rule) {
|
||||
r := rule
|
||||
defaultRule = &r
|
||||
} else {
|
||||
// set name to index if name is empty
|
||||
if rule.Name == "" {
|
||||
rule.Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
nonDefaultRules = append(nonDefaultRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonDefaultRules) == 0 {
|
||||
if defaultRule.Do.isBypass() {
|
||||
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
||||
return up
|
||||
}
|
||||
if defaultRule.IsResponseRule() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
}
|
||||
}()
|
||||
w = rm
|
||||
up(w, r)
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
}
|
||||
}()
|
||||
w = rm
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err == nil {
|
||||
up(w, r)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preRules := make(Rules, 0, len(nonDefaultRules))
|
||||
postRules := make(Rules, 0, len(nonDefaultRules))
|
||||
for _, rule := range nonDefaultRules {
|
||||
if rule.IsResponseRule() {
|
||||
postRules = append(postRules, rule)
|
||||
} else {
|
||||
preRules = append(preRules, rule)
|
||||
}
|
||||
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||
return cmd.pre.ServeHTTP(w, r, up)
|
||||
}
|
||||
|
||||
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
|
||||
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
||||
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||
return cmd.post.ServeHTTP(w, r, up)
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
logFlushError(err, r)
|
||||
}
|
||||
}()
|
||||
|
||||
w = rm
|
||||
var hasError bool
|
||||
|
||||
shouldCallUpstream := true
|
||||
preMatched := false
|
||||
executedPre := make([]bool, len(nonDefaultRules))
|
||||
terminatedInPre := make([]bool, len(nonDefaultRules))
|
||||
matchedNonDefaultPre := false
|
||||
preTerminated := false
|
||||
for i, rule := range nonDefaultRules {
|
||||
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||
continue
|
||||
}
|
||||
matchedNonDefaultPre = true
|
||||
if preTerminated {
|
||||
// Preserve post-only commands (e.g. logging) even after
|
||||
// pre-phase termination.
|
||||
if len(rule.Do.pre) == 0 {
|
||||
executedPre[i] = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
executedPre[i] = true
|
||||
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
terminatedInPre[i] = true
|
||||
preTerminated = true
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
hasError = true
|
||||
}
|
||||
}
|
||||
|
||||
// Default rule is a fallback: run only when no non-default pre rule matched.
|
||||
defaultExecutedPre := false
|
||||
defaultTerminatedInPre := false
|
||||
if defaultRule != nil && !matchedNonDefaultPre && !defaultRule.On.phase.IsPostRule() && defaultRule.On.Check(rm, r) {
|
||||
defaultExecutedPre = true
|
||||
if err := execPreCommand(defaultRule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
defaultTerminatedInPre = true
|
||||
} else {
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing pre rule (%s): %w", defaultRule.Do.raw, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
hasError = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
for _, rule := range preRules {
|
||||
if rule.Check(w, r) {
|
||||
preMatched = true
|
||||
if rule.Do.isBypass() {
|
||||
break // post rules should still execute
|
||||
}
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &rule, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
break
|
||||
}
|
||||
if !rm.HasStatus() {
|
||||
if hasError {
|
||||
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
||||
up(rm, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Run post commands for rules that actually executed in pre phase,
|
||||
// unless that same rule terminated in pre phase.
|
||||
for i, rule := range nonDefaultRules {
|
||||
if !executedPre[i] || terminatedInPre[i] {
|
||||
continue
|
||||
}
|
||||
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if defaultExecutedPre && !defaultTerminatedInPre {
|
||||
if err := execPostCommand(defaultRule.Do, rm, r); err != nil {
|
||||
if !errors.Is(err, errTerminateRule) && isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing post rule (%s): %w", defaultRule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
return
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
// Run true post-matcher rules after response is available.
|
||||
for _, rule := range nonDefaultRules {
|
||||
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||
continue
|
||||
}
|
||||
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
||||
// as "pre" for requirement purposes still need to run in this phase.
|
||||
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
up(w, r)
|
||||
}
|
||||
|
||||
// if no post rules, we are done here
|
||||
if len(postRules) == 0 && !isDefaultRulePost {
|
||||
return
|
||||
}
|
||||
|
||||
for _, rule := range postRules {
|
||||
if rule.Check(w, r) {
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &rule, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isDefaultRulePost {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) {
|
||||
// rm.AppendError("rule: %s, error: %w", rule.Name, err)
|
||||
}
|
||||
|
||||
func isTerminatingHandler(handler CommandHandler) bool {
|
||||
switch h := handler.(type) {
|
||||
case TerminatingCommand:
|
||||
return true
|
||||
case Commands:
|
||||
if len(h) == 0 {
|
||||
return false
|
||||
}
|
||||
return isTerminatingHandler(h[len(h)-1])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,41 +411,41 @@ func (rule *Rule) String() string {
|
||||
return rule.Name
|
||||
}
|
||||
|
||||
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
if rule.On.checker == nil {
|
||||
return true
|
||||
}
|
||||
v := rule.On.checker.Check(w, r)
|
||||
return v
|
||||
}
|
||||
|
||||
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return rule.Do.exec.Handle(w, r)
|
||||
return rule.On.Check(w, r)
|
||||
}
|
||||
|
||||
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
|
||||
var errStreamClosed error
|
||||
|
||||
func logError(err error, r *http.Request) {
|
||||
if errors.Is(err, errStreamClosed) {
|
||||
return
|
||||
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
|
||||
var errClientDisconnected error
|
||||
|
||||
func isUnexpectedError(err error) bool {
|
||||
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
||||
return false
|
||||
}
|
||||
var h2Err http2.StreamError
|
||||
if errors.As(err, &h2Err) {
|
||||
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
||||
// ignore these errors
|
||||
if h2Err.Code == http2.ErrCodeStreamClosed {
|
||||
return
|
||||
return false
|
||||
}
|
||||
}
|
||||
var h3Err *http3.Error
|
||||
if errors.As(err, &h3Err) {
|
||||
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
||||
// ignore these errors
|
||||
switch h3Err.ErrorCode {
|
||||
case
|
||||
http3.ErrCodeNoError,
|
||||
http3.ErrCodeRequestCanceled:
|
||||
return
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func logFlushError(err error, r *http.Request) {
|
||||
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
||||
}
|
||||
|
||||
@@ -19,28 +19,71 @@ func TestRulesValidate(t *testing.T) {
|
||||
{
|
||||
name: "no default rule",
|
||||
rules: `
|
||||
- name: rule1
|
||||
on: header Host example.com
|
||||
do: pass
|
||||
`,
|
||||
header Host example.com {
|
||||
pass
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "multiple default rules",
|
||||
rules: `
|
||||
- name: default
|
||||
do: pass
|
||||
- name: rule1
|
||||
on: default
|
||||
do: pass
|
||||
`,
|
||||
default {
|
||||
pass
|
||||
}
|
||||
|
||||
default {
|
||||
pass
|
||||
}`,
|
||||
want: ErrMultipleDefaultRules,
|
||||
},
|
||||
{
|
||||
name: "multiple responses on same condition",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
error 404 "not found"
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
{
|
||||
name: "same condition different formatting error then proxy",
|
||||
rules: `
|
||||
header Host example.com & method GET {
|
||||
error 404 "not found"
|
||||
}
|
||||
|
||||
method GET
|
||||
header Host example.com {
|
||||
proxy http://127.0.0.1:8080
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
{
|
||||
name: "same condition with non terminating first rule",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
set resp_header X-Test first
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var rules Rules
|
||||
convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules))
|
||||
require.True(t, convertible)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = rules.Validate()
|
||||
|
||||
if tt.want == nil {
|
||||
assert.NoError(t, err)
|
||||
@@ -50,3 +93,38 @@ func TestRulesValidate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTopLevelLBrace(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "escaped quote inside double quoted string",
|
||||
in: `"test\"more{"`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "escaped quote inside single quoted string",
|
||||
in: "'test\\'more{'",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "top-level brace outside quoted string",
|
||||
in: `"test\"more" {`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "backtick keeps existing behavior",
|
||||
in: "`test\\`more{`",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, hasTopLevelLBrace(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
273
internal/route/rules/scanner.go
Normal file
273
internal/route/rules/scanner.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
// Tokenizer provides utilities for parsing rule syntax with proper handling
|
||||
// of quotes, comments, and env vars.
|
||||
//
|
||||
// This is intentionally reusable by both the top-level rule block parser and
|
||||
// the nested do-block parser.
|
||||
type Tokenizer struct {
|
||||
src string
|
||||
length int
|
||||
}
|
||||
|
||||
// newTokenizer creates a tokenizer for the given source.
|
||||
func newTokenizer(src string) Tokenizer {
|
||||
return Tokenizer{src: src, length: len(src)}
|
||||
}
|
||||
|
||||
// skipComments skips whitespace, line comments, and block comments.
|
||||
// It returns the new position and an error if a block comment is unterminated.
|
||||
func (t *Tokenizer) skipComments(pos int, atLineStart bool, prevIsSpace bool) (int, gperr.Error) {
|
||||
for pos < t.length {
|
||||
c := t.src[pos]
|
||||
|
||||
// Skip whitespace
|
||||
if unicode.IsSpace(rune(c)) {
|
||||
pos++
|
||||
atLineStart = false
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for line comment: // or #
|
||||
if c == '/' {
|
||||
if pos+1 < t.length && t.src[pos+1] == '/' {
|
||||
// Skip to end of line
|
||||
for pos < t.length && t.src[pos] != '\n' {
|
||||
pos++
|
||||
}
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c == '#' && (atLineStart || prevIsSpace) {
|
||||
// Skip to end of line
|
||||
for pos < t.length && t.src[pos] != '\n' {
|
||||
pos++
|
||||
}
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for block comment: /*
|
||||
if c == '/' && pos+1 < t.length && t.src[pos+1] == '*' {
|
||||
pos += 2
|
||||
closed := false
|
||||
for pos+1 < t.length {
|
||||
if t.src[pos] == '*' && t.src[pos+1] == '/' {
|
||||
pos += 2
|
||||
closed = true
|
||||
break
|
||||
}
|
||||
pos++
|
||||
}
|
||||
if !closed {
|
||||
return 0, ErrInvalidBlockSyntax.Withf("unterminated block comment")
|
||||
}
|
||||
atLineStart = false
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return pos, nil
|
||||
}
|
||||
|
||||
// scanToBrace scans from pos until it finds '{' outside quotes, or returns an error.
|
||||
func (t *Tokenizer) scanToBrace(pos int) (int, gperr.Error) {
|
||||
quote := rune(0)
|
||||
for pos < t.length {
|
||||
c := rune(t.src[pos])
|
||||
if quote != 0 {
|
||||
if c == quote {
|
||||
quote = 0
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if c == '"' || c == '\'' || c == '`' {
|
||||
quote = c
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if c == '{' {
|
||||
return pos, nil
|
||||
}
|
||||
if c == '}' {
|
||||
return 0, ErrInvalidBlockSyntax.Withf("unmatched '}' in block header")
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return 0, ErrInvalidBlockSyntax.Withf("expected '{' after block header")
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the matching '}' for a '{' starting at startPos.
|
||||
// It respects quotes/backticks and ${...} env vars.
|
||||
func (t *Tokenizer) findMatchingBrace(startPos int) (int, gperr.Error) {
|
||||
pos := startPos
|
||||
braceDepth := 1
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
atLineStart := true
|
||||
prevIsSpace := true
|
||||
|
||||
for pos < t.length {
|
||||
c := rune(t.src[pos])
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '*' && pos+1 < t.length && t.src[pos+1] == '/' {
|
||||
pos += 2
|
||||
inBlock = false
|
||||
continue
|
||||
}
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
if c == quote {
|
||||
quote = 0
|
||||
}
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(c)
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '"' || c == '\'' || c == '`' {
|
||||
quote = c
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Comments (only outside quotes) at token boundary
|
||||
if c == '#' && (atLineStart || prevIsSpace) {
|
||||
inLine = true
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if c == '/' && pos+1 < t.length {
|
||||
n := rune(t.src[pos+1])
|
||||
if (atLineStart || prevIsSpace) && n == '/' {
|
||||
inLine = true
|
||||
pos += 2
|
||||
continue
|
||||
}
|
||||
if (atLineStart || prevIsSpace) && n == '*' {
|
||||
inBlock = true
|
||||
pos += 2
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if c == '$' && pos+1 < t.length && t.src[pos+1] == '{' {
|
||||
// Skip env var ${...}
|
||||
pos += 2
|
||||
envBraceDepth := 1
|
||||
envQuote := rune(0)
|
||||
for pos < t.length {
|
||||
ec := rune(t.src[pos])
|
||||
if envQuote != 0 {
|
||||
if ec == envQuote {
|
||||
envQuote = 0
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if ec == '"' || ec == '\'' || ec == '`' {
|
||||
envQuote = ec
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if ec == '{' {
|
||||
envBraceDepth++
|
||||
} else if ec == '}' {
|
||||
envBraceDepth--
|
||||
if envBraceDepth == 0 {
|
||||
pos++ // Move past the closing '}'
|
||||
break
|
||||
}
|
||||
}
|
||||
pos++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch c {
|
||||
case '{':
|
||||
braceDepth++
|
||||
case '}':
|
||||
braceDepth--
|
||||
if braceDepth == 0 {
|
||||
return pos, nil
|
||||
}
|
||||
}
|
||||
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(c)
|
||||
}
|
||||
pos++
|
||||
}
|
||||
|
||||
return 0, ErrInvalidBlockSyntax.Withf("unmatched '{' at position %d", startPos)
|
||||
}
|
||||
|
||||
// parseHeaderToBrace parses an expression/header starting at start and returns:
|
||||
// - header: trimmed src[start:bracePos]
|
||||
// - bracePos: position of '{' (outside quotes/backticks)
|
||||
func parseHeaderToBrace(src string, start int) (header string, bracePos int, err gperr.Error) {
|
||||
t := newTokenizer(src)
|
||||
bracePos, err = t.scanToBrace(start)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return strings.TrimSpace(src[start:bracePos]), bracePos, nil
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the matching '}' for a '{' at position startPos.
|
||||
// It respects quotes/backticks and ${...} env vars in do_body.
|
||||
func findMatchingBrace(src string, pos *int, startPos int) (int, gperr.Error) {
|
||||
t := newTokenizer(src)
|
||||
endPos, err := t.findMatchingBrace(startPos)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
*pos = endPos + 1
|
||||
return endPos, nil
|
||||
}
|
||||
39
internal/route/rules/scanner_test.go
Normal file
39
internal/route/rules/scanner_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTokenizer_skipComments_UnterminatedBlockComment(t *testing.T) {
|
||||
tok := newTokenizer("/* unterminated")
|
||||
pos, err := tok.skipComments(0, true, true)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, pos)
|
||||
}
|
||||
|
||||
func TestTokenizer_skipComments_SkipsLineAndBlockComments(t *testing.T) {
|
||||
src := " // line\n /* block */\n # hash\n default"
|
||||
tok := newTokenizer(src)
|
||||
pos, err := tok.skipComments(0, true, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.Index(src, "default"), pos)
|
||||
}
|
||||
|
||||
func TestTokenizer_scanToBrace_IgnoresQuotedBraces(t *testing.T) {
|
||||
src := "cond \"{\" {" // the brace inside quotes must be ignored
|
||||
tok := newTokenizer(src)
|
||||
bracePos, err := tok.scanToBrace(0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.LastIndex(src, "{"), bracePos)
|
||||
}
|
||||
|
||||
func TestTokenizer_findMatchingBrace_IgnoresQuotedClosingBrace(t *testing.T) {
|
||||
src := `{ "}" }`
|
||||
tok := newTokenizer(src)
|
||||
endPos, err := tok.findMatchingBrace(1) // body starts after the first '{'
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.LastIndex(src, "}"), endPos)
|
||||
}
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
type templateString struct {
|
||||
string
|
||||
|
||||
isTemplate bool
|
||||
}
|
||||
|
||||
@@ -23,32 +23,28 @@ func (tmpl *keyValueTemplate) Unpack() (string, templateString) {
|
||||
return tmpl.key, tmpl.tmpl
|
||||
}
|
||||
|
||||
func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error {
|
||||
func (tmpl *templateString) ExpandVars(w *httputils.ResponseModifier, req *http.Request, dst io.Writer) (phase PhaseFlag, err error) {
|
||||
if !tmpl.isTemplate {
|
||||
_, err := dstW.Write(strtobNoCopy(tmpl.string))
|
||||
return err
|
||||
_, err := asBytesBufferLike(dst).WriteString(tmpl.string)
|
||||
return PhaseNone, err
|
||||
}
|
||||
|
||||
return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW)
|
||||
return ExpandVars(w, req, tmpl.string, dst)
|
||||
}
|
||||
|
||||
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
|
||||
func (tmpl *templateString) ExpandVarsToString(w *httputils.ResponseModifier, r *http.Request) (string, PhaseFlag, error) {
|
||||
if !tmpl.isTemplate {
|
||||
return tmpl.string, nil
|
||||
return tmpl.string, PhaseNone, nil
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf)
|
||||
phase, err := tmpl.ExpandVars(w, r, &buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", PhaseNone, err
|
||||
}
|
||||
return buf.String(), nil
|
||||
return buf.String(), phase, nil
|
||||
}
|
||||
|
||||
func (tmpl *templateString) Len() int {
|
||||
return len(tmpl.string)
|
||||
}
|
||||
|
||||
func strtobNoCopy(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
ValidateFunc func(args []string) (any, error)
|
||||
ValidateFunc func(args []string) (phase PhaseFlag, parsedArgs any, err error)
|
||||
Tuple[T1, T2 any] struct {
|
||||
First T1
|
||||
Second T2
|
||||
@@ -37,6 +38,8 @@ type (
|
||||
MapValueMatcher = Tuple[string, Matcher]
|
||||
)
|
||||
|
||||
var cidrCache = xsync.NewMap[string, *net.IPNet]()
|
||||
|
||||
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
||||
return t.First, t.Second
|
||||
}
|
||||
@@ -62,7 +65,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string {
|
||||
}
|
||||
|
||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||
func validateSingleMatcher(args []string) (any, error) {
|
||||
func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -70,7 +73,7 @@ func validateSingleMatcher(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||
func toKVOptionalVMatcher(args []string) (any, error) {
|
||||
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
switch len(args) {
|
||||
case 1:
|
||||
return &MapValueMatcher{args[0], nil}, nil
|
||||
@@ -85,20 +88,8 @@ func toKVOptionalVMatcher(args []string) (any, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func toKeyValueTemplate(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
isTemplate, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &keyValueTemplate{args[0], isTemplate}, nil
|
||||
}
|
||||
|
||||
// validateURL returns types.URL with the URL validated.
|
||||
func validateURL(args []string) (any, error) {
|
||||
func validateURL(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -116,22 +107,27 @@ func validateURL(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateCIDR returns types.CIDR with the CIDR validated.
|
||||
func validateCIDR(args []string) (any, error) {
|
||||
func validateCIDR(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
if !strings.Contains(args[0], "/") {
|
||||
args[0] += "/32"
|
||||
cidr := args[0]
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/32"
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(args[0])
|
||||
if cached, ok := cidrCache.Load(cidr); ok {
|
||||
return cached, nil
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
cidrCache.Store(cidr, ipnet)
|
||||
return ipnet, nil
|
||||
}
|
||||
|
||||
// validateURLPath returns string with the path validated.
|
||||
func validateURLPath(args []string) (any, error) {
|
||||
func validateURLPath(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -148,7 +144,7 @@ func validateURLPath(args []string) (any, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func validateURLPathMatcher(args []string) (any, error) {
|
||||
func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
path, err := validateURLPath(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -157,7 +153,7 @@ func validateURLPathMatcher(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateFSPath returns string with the path validated.
|
||||
func validateFSPath(args []string) (any, error) {
|
||||
func validateFSPath(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -169,7 +165,7 @@ func validateFSPath(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateMethod returns string with the method validated.
|
||||
func validateMethod(args []string) (any, error) {
|
||||
func validateMethod(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -200,7 +196,7 @@ func validateStatusCode(status string) (int, error) {
|
||||
// - 3xx
|
||||
// - 4xx
|
||||
// - 5xx
|
||||
func validateStatusRange(args []string) (any, error) {
|
||||
func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -232,7 +228,7 @@ func validateStatusRange(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||
func validateUserBCryptPassword(args []string) (any, error) {
|
||||
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -240,64 +236,93 @@ func validateUserBCryptPassword(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateModField returns CommandHandler with the field validated.
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, error) {
|
||||
func validateModField(mod FieldModifier, args []string) (phase PhaseFlag, handler HandlerFunc, err error) {
|
||||
if len(args) == 0 {
|
||||
return nil, ErrExpectTwoOrThreeArgs
|
||||
return phase, nil, ErrExpectTwoOrThreeArgs
|
||||
}
|
||||
setField, ok := modFields[args[0]]
|
||||
if !ok {
|
||||
return nil, ErrUnknownModField.Subject(args[0])
|
||||
return phase, nil, ErrUnknownModField.Subject(args[0])
|
||||
}
|
||||
if mod == ModFieldRemove {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
// setField expect validateStrTuple
|
||||
args = append(args, "")
|
||||
}
|
||||
validArgs, err := setField.validate(args[1:])
|
||||
phase, validArgs, err := setField.validate(args[1:])
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err).With(setField.help.Error())
|
||||
return phase, nil, gperr.Wrap(err).With(setField.help.Error())
|
||||
}
|
||||
|
||||
modder := setField.builder(validArgs)
|
||||
switch mod {
|
||||
case ModFieldAdd:
|
||||
add := modder.add
|
||||
if add == nil {
|
||||
return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod)
|
||||
return phase, nil, ErrInvalidArguments.Withf("add is not supported for field %s", args[0])
|
||||
}
|
||||
return add, nil
|
||||
return phase, add, nil
|
||||
case ModFieldRemove:
|
||||
remove := modder.remove
|
||||
if remove == nil {
|
||||
return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod)
|
||||
return phase, nil, ErrInvalidArguments.Withf("remove is not supported for field %s", args[0])
|
||||
}
|
||||
return remove, nil
|
||||
return phase, remove, nil
|
||||
}
|
||||
set := modder.set
|
||||
if set == nil {
|
||||
return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod)
|
||||
return phase, nil, ErrInvalidArguments.Withf("set is not supported for field %s", args[0])
|
||||
}
|
||||
return set, nil
|
||||
return phase, set, nil
|
||||
}
|
||||
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, error) {
|
||||
func validateTemplate(tmplStr string, newline bool) (phase PhaseFlag, tmpl templateString, err error) {
|
||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||
tmplStr += "\n"
|
||||
}
|
||||
|
||||
if !NeedExpandVars(tmplStr) {
|
||||
return templateString{tmplStr, false}, nil
|
||||
return phase, templateString{tmplStr, false}, nil
|
||||
}
|
||||
|
||||
err := ValidateVars(tmplStr)
|
||||
phase, err = ValidateVars(tmplStr)
|
||||
if err != nil {
|
||||
return templateString{}, err
|
||||
return phase, templateString{}, gperr.Wrap(err)
|
||||
}
|
||||
return templateString{tmplStr, true}, nil
|
||||
return phase, templateString{tmplStr, true}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, error) {
|
||||
func validatePreRequestKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 2 {
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
phase = PhasePre
|
||||
tmplReq, tmpl, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return phase, nil, err
|
||||
}
|
||||
phase |= tmplReq
|
||||
return phase, &keyValueTemplate{args[0], tmpl}, nil
|
||||
}
|
||||
|
||||
func validatePostResponseKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 2 {
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
phase = PhasePost
|
||||
tmplReq, tmpl, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return phase, nil, err
|
||||
}
|
||||
phase |= tmplReq
|
||||
return phase, &keyValueTemplate{args[0], tmpl}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||
l, err := zerolog.ParseLevel(level)
|
||||
if err != nil {
|
||||
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
||||
|
||||
@@ -23,7 +23,7 @@ func BenchmarkExpandVars(b *testing.B) {
|
||||
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
|
||||
|
||||
for b.Loop() {
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
|
||||
// TODO: remove middleware/vars.go and use this instead
|
||||
@@ -45,41 +46,84 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type bytesBufferLike interface {
|
||||
io.Writer
|
||||
WriteByte(c byte) error
|
||||
WriteString(s string) (int, error)
|
||||
}
|
||||
|
||||
type bytesBufferAdapter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (b bytesBufferAdapter) WriteByte(c byte) error {
|
||||
buf := [1]byte{c}
|
||||
_, err := b.Write(buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (b bytesBufferAdapter) WriteString(s string) (int, error) {
|
||||
return b.Write(unsafe.Slice(unsafe.StringData(s), len(s))) // avoid copy
|
||||
}
|
||||
|
||||
func asBytesBufferLike(w io.Writer) bytesBufferLike {
|
||||
switch w := w.(type) {
|
||||
case *bytes.Buffer:
|
||||
return w
|
||||
case bytesBufferLike:
|
||||
return w
|
||||
default:
|
||||
return bytesBufferAdapter{w}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateVars validates the variables in the given string.
|
||||
// It returns ErrUnexpectedVar if any invalid variable is found.
|
||||
func ValidateVars(s string) error {
|
||||
// It returns the phase that the variables require and an error if any error occurs.
|
||||
//
|
||||
// Possible errors:
|
||||
// - ErrUnexpectedVar: if any invalid variable is found
|
||||
// - ErrUnterminatedEnvVar: missing closing }
|
||||
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
||||
// - ErrUnterminatedParenthesis: missing closing )
|
||||
func ValidateVars(s string) (phase PhaseFlag, err error) {
|
||||
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
||||
}
|
||||
|
||||
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
|
||||
dst := ioutils.NewBufferedWriter(dstW, 1024)
|
||||
defer dst.Close()
|
||||
|
||||
// ExpandVars expands the variables in the given string and writes the result to the given writer.
|
||||
// It returns the phase that the variables require and an error if any error occurs.
|
||||
//
|
||||
// Possible errors:
|
||||
// - ErrUnexpectedVar: if any invalid variable is found
|
||||
// - ErrUnterminatedEnvVar: missing closing }
|
||||
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
||||
// - ErrUnterminatedParenthesis: missing closing )
|
||||
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) (phase PhaseFlag, err error) {
|
||||
dst := asBytesBufferLike(dstW)
|
||||
for i := 0; i < len(src); i++ {
|
||||
ch := src[i]
|
||||
if ch != '$' {
|
||||
if err := dst.WriteByte(ch); err != nil {
|
||||
return err
|
||||
if err = dst.WriteByte(ch); err != nil {
|
||||
return phase, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Look ahead
|
||||
if i+1 >= len(src) {
|
||||
return ErrUnterminatedEnvVar
|
||||
return phase, ErrUnterminatedEnvVar
|
||||
}
|
||||
j := i + 1
|
||||
|
||||
switch src[j] {
|
||||
case '$': // $$ -> literal '$'
|
||||
if err := dst.WriteByte('$'); err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
i = j
|
||||
continue
|
||||
case '{': // ${...} pass through as-is
|
||||
if _, err := dst.WriteString("${"); err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
i = j // we've consumed the '{' too
|
||||
continue
|
||||
@@ -102,24 +146,26 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
|
||||
if getter, ok := dynamicVarSubsMap[name]; ok {
|
||||
// Function-like variables
|
||||
isStatic = false
|
||||
phase |= getter.phase
|
||||
args, nextIdx, err := extractArgs(src, j, name)
|
||||
if err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
i = nextIdx
|
||||
actual, err = getter(args, w, req)
|
||||
actual, err = getter.get(args, w, req)
|
||||
if err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
} else if getter, ok := staticReqVarSubsMap[name]; ok {
|
||||
} else if getter, ok := staticReqVarSubsMap[name]; ok { // always available
|
||||
actual = getter(req)
|
||||
} else if getter, ok := staticRespVarSubsMap[name]; ok {
|
||||
} else if getter, ok := staticRespVarSubsMap[name]; ok { // post response
|
||||
actual = getter(w)
|
||||
phase |= PhasePost
|
||||
} else {
|
||||
return ErrUnexpectedVar.Subject(name)
|
||||
return phase, ErrUnexpectedVar.Subject(name)
|
||||
}
|
||||
if _, err := dst.WriteString(actual); err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
if isStatic {
|
||||
i = k - 1
|
||||
@@ -128,10 +174,10 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
|
||||
}
|
||||
|
||||
// No valid construct after '$'
|
||||
return ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
||||
return phase, ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
||||
}
|
||||
|
||||
return nil
|
||||
return phase, nil
|
||||
}
|
||||
|
||||
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
|
||||
|
||||
@@ -11,58 +11,88 @@ import (
|
||||
var (
|
||||
VarHeader = "header"
|
||||
VarResponseHeader = "resp_header"
|
||||
VarCookie = "cookie"
|
||||
VarQuery = "arg"
|
||||
VarForm = "form"
|
||||
VarPostForm = "postform"
|
||||
)
|
||||
|
||||
type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
||||
type dynamicVarGetter struct {
|
||||
phase PhaseFlag
|
||||
get func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
||||
}
|
||||
|
||||
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
||||
VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Header, key, index)
|
||||
},
|
||||
VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||
},
|
||||
VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
||||
},
|
||||
VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
VarHeader: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Form, key, index)
|
||||
return getValueByKeyAtIndex(req.Header, key, index)
|
||||
},
|
||||
},
|
||||
VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
VarResponseHeader: {
|
||||
phase: PhasePost,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.PostForm, key, index)
|
||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||
},
|
||||
},
|
||||
VarCookie: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sharedData := httputils.GetSharedData(w)
|
||||
return getValueByKeyAtIndex(sharedData.GetCookiesMap(req), key, index)
|
||||
},
|
||||
},
|
||||
VarQuery: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
||||
},
|
||||
},
|
||||
VarForm: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Form, key, index)
|
||||
},
|
||||
},
|
||||
VarPostForm: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.PostForm, key, index)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ func TestExpandVars(t *testing.T) {
|
||||
{
|
||||
name: "req_method",
|
||||
input: "$req_method",
|
||||
want: "POST",
|
||||
want: http.MethodPost,
|
||||
},
|
||||
{
|
||||
name: "req_path",
|
||||
@@ -484,7 +484,7 @@ func TestExpandVars(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
@@ -506,7 +506,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
||||
&out)
|
||||
|
||||
@@ -520,7 +520,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||
"Query: $arg(q), Page: $arg(page)",
|
||||
&out)
|
||||
|
||||
@@ -537,7 +537,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
||||
&out)
|
||||
|
||||
@@ -560,7 +560,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
{
|
||||
name: "https scheme",
|
||||
request: &http.Request{
|
||||
Method: "GET",
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
|
||||
TLS: &tls.ConnectionState{}, // Simulate TLS connection
|
||||
},
|
||||
@@ -572,7 +572,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||
_, err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, out.String())
|
||||
})
|
||||
@@ -598,7 +598,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
for _, varExpr := range upstreamVars {
|
||||
t.Run(varExpr, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
||||
// Should not error, may return empty string
|
||||
require.NoError(t, err)
|
||||
})
|
||||
@@ -614,16 +614,16 @@ func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
|
||||
t.Run("req_host without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "example.com", out.String())
|
||||
})
|
||||
|
||||
t.Run("req_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, out.String())
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -636,16 +636,16 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
|
||||
t.Run("remote_host without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, out.String())
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
|
||||
t.Run("remote_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, out.String())
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -654,7 +654,7 @@ func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "GET /test", out.String())
|
||||
}
|
||||
@@ -699,7 +699,7 @@ func TestValidateVars(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateVars(tt.input)
|
||||
_, err := ValidateVars(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user