mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-11 03:06:51 +02:00
Compare commits
15 Commits
main
...
feat/rules
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95ffd35585 | ||
|
|
7b0d846576 | ||
|
|
458c7779d3 | ||
|
|
dc6c649f2c | ||
|
|
3c5c3ecac2 | ||
|
|
a94442b001 | ||
|
|
2a51c2ef52 | ||
|
|
6477c35b15 | ||
|
|
5b20bbeb6f | ||
|
|
5ba475c489 | ||
|
|
54be056530 | ||
|
|
08de9086c3 | ||
|
|
1a17f3943a | ||
|
|
9bb5c54e7c | ||
|
|
faecbab2cb |
2
goutils
2
goutils
Submodule goutils updated: 482b5bca9f...3be815cb6e
@@ -5093,11 +5093,6 @@
|
|||||||
"x-nullable": false,
|
"x-nullable": false,
|
||||||
"x-omitempty": false
|
"x-omitempty": false
|
||||||
},
|
},
|
||||||
"isResponseRule": {
|
|
||||||
"type": "boolean",
|
|
||||||
"x-nullable": false,
|
|
||||||
"x-omitempty": false
|
|
||||||
},
|
|
||||||
"name": {
|
"name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"x-nullable": false,
|
"x-nullable": false,
|
||||||
|
|||||||
@@ -891,8 +891,6 @@ definitions:
|
|||||||
properties:
|
properties:
|
||||||
do:
|
do:
|
||||||
type: string
|
type: string
|
||||||
isResponseRule:
|
|
||||||
type: boolean
|
|
||||||
name:
|
name:
|
||||||
type: string
|
type: string
|
||||||
"on":
|
"on":
|
||||||
@@ -1837,12 +1835,12 @@ definitions:
|
|||||||
type: string
|
type: string
|
||||||
kernel_version:
|
kernel_version:
|
||||||
type: string
|
type: string
|
||||||
load_avg_5m:
|
|
||||||
type: string
|
|
||||||
load_avg_15m:
|
load_avg_15m:
|
||||||
type: string
|
type: string
|
||||||
load_avg_1m:
|
load_avg_1m:
|
||||||
type: string
|
type: string
|
||||||
|
load_avg_5m:
|
||||||
|
type: string
|
||||||
mem_pct:
|
mem_pct:
|
||||||
type: string
|
type: string
|
||||||
mem_total:
|
mem_total:
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ type ParsedRule struct {
|
|||||||
On string `json:"on"`
|
On string `json:"on"`
|
||||||
Do string `json:"do"`
|
Do string `json:"do"`
|
||||||
ValidationError error `json:"validationError,omitempty"` // we need the structured error, not the plain string
|
ValidationError error `json:"validationError,omitempty"` // we need the structured error, not the plain string
|
||||||
IsResponseRule bool `json:"isResponseRule"`
|
|
||||||
} // @name ParsedRule
|
} // @name ParsedRule
|
||||||
|
|
||||||
type FinalRequest struct {
|
type FinalRequest struct {
|
||||||
@@ -298,7 +297,6 @@ func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
|
|||||||
On: onStr,
|
On: onStr,
|
||||||
Do: doStr,
|
Do: doStr,
|
||||||
ValidationError: validationErr,
|
ValidationError: validationErr,
|
||||||
IsResponseRule: rule.IsResponseRule(),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Only add valid rules to execution list
|
// Only add valid rules to execution list
|
||||||
|
|||||||
@@ -38,13 +38,13 @@ type Rule struct {
|
|||||||
type RuleOn struct {
|
type RuleOn struct {
|
||||||
raw string
|
raw string
|
||||||
checker Checker
|
checker Checker
|
||||||
isResponseChecker bool
|
phase PhaseFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
type Command struct {
|
type Command struct {
|
||||||
raw string
|
raw string
|
||||||
exec CommandHandler
|
pre Commands
|
||||||
isResponseHandler bool
|
post Commands
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -59,6 +59,9 @@ func ParseRules(config string) (Rules, error)
|
|||||||
|
|
||||||
// ValidateRules validates rule syntax
|
// ValidateRules validates rule syntax
|
||||||
func ValidateRules(config string) error
|
func ValidateRules(config string) error
|
||||||
|
|
||||||
|
// Validate validates rule semantics (e.g., prevents multiple default rules)
|
||||||
|
func (rules Rules) Validate() gperr.Error
|
||||||
```
|
```
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
@@ -122,16 +125,52 @@ sequenceDiagram
|
|||||||
Pre->>Pre: Execute handler
|
Pre->>Pre: Execute handler
|
||||||
alt Terminating action
|
alt Terminating action
|
||||||
Pre-->>Req: Response
|
Pre-->>Req: Response
|
||||||
Return-->>Req: Return immediately
|
Note right of Pre: Stop remaining pre commands
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
opt No pre termination
|
||||||
Req->>Proxy: Forward request
|
Req->>Proxy: Forward request
|
||||||
Proxy-->>Req: Response
|
Proxy-->>Req: Response
|
||||||
Req->>Post: Check post-rules
|
end
|
||||||
Post->>Post: Execute handlers
|
Req->>Post: Run scheduled post commands
|
||||||
Post-->>Req: Modified response
|
Req->>Post: Evaluate response matchers
|
||||||
|
Post->>Post: Execute matched post handlers
|
||||||
|
Post-->>Req: Final response
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Execution Model (Authoritative)
|
||||||
|
|
||||||
|
Rules run in two phases:
|
||||||
|
|
||||||
|
1. **Pre phase**
|
||||||
|
- Evaluate only request-based matchers (`path`, `method`, `header`, `remote`, etc.) in declaration order.
|
||||||
|
- Execute matched rule `do` pre-commands in order.
|
||||||
|
- If a default rule exists (`name: default` or `on: default`), it is a fallback and runs only when no non-default pre rule matches.
|
||||||
|
- If a terminating action runs, stop:
|
||||||
|
- remaining commands in that rule
|
||||||
|
- all later pre-phase commands.
|
||||||
|
- Exception: rules that only contain post commands (no pre commands) are still scheduled for post phase.
|
||||||
|
|
||||||
|
2. **Upstream phase**
|
||||||
|
- Upstream is called only if pre phase did not terminate.
|
||||||
|
|
||||||
|
3. **Post phase**
|
||||||
|
- Run post-commands for rules whose pre phase executed, except rules that terminated in pre.
|
||||||
|
- Then evaluate response-based matchers (`status`, `resp_header`) and execute their `do` commands.
|
||||||
|
- Response-based rules run even when the response was produced in pre phase.
|
||||||
|
|
||||||
|
**Important:** termination is explicit by command semantics, not inferred from status-code mutation.
|
||||||
|
|
||||||
|
### Phase Flags
|
||||||
|
|
||||||
|
Rule and command parsing tracks phase requirements via `PhaseFlag`:
|
||||||
|
|
||||||
|
- `PhasePre`
|
||||||
|
- `PhasePost`
|
||||||
|
- `PhasePre | PhasePost` (combined)
|
||||||
|
|
||||||
|
Combined flags are expected for nested/compound commands and variable templates that may need both request and response context.
|
||||||
|
|
||||||
### Condition Matchers
|
### Condition Matchers
|
||||||
|
|
||||||
| Matcher | Type | Description |
|
| Matcher | Type | Description |
|
||||||
@@ -167,21 +206,21 @@ path regex("/api/v[0-9]+/.*") // regex pattern
|
|||||||
**Terminating Actions** (stop processing):
|
**Terminating Actions** (stop processing):
|
||||||
|
|
||||||
| Command | Description |
|
| Command | Description |
|
||||||
| ------------------------ | ---------------------- |
|
| ------------------------------ | ------------------------------------- |
|
||||||
|
| `upstream` / `bypass` / `pass` | Call upstream and terminate pre-phase |
|
||||||
| `error <code> <message>` | Return HTTP error |
|
| `error <code> <message>` | Return HTTP error |
|
||||||
| `redirect <url>` | Redirect to URL |
|
| `redirect <url>` | Redirect to URL |
|
||||||
| `serve <path>` | Serve local files |
|
| `serve <path>` | Serve local files |
|
||||||
| `route <name>` | Route to another route |
|
| `route <name>` | Route to another route |
|
||||||
| `proxy <url>` | Proxy to upstream |
|
| `proxy <url>` | Proxy to upstream |
|
||||||
|
| `require_basic_auth <realm>` | Return 401 challenge |
|
||||||
|
|
||||||
**Non-Terminating Actions** (modify and continue):
|
**Non-Terminating Actions** (modify and continue):
|
||||||
|
|
||||||
| Command | Description |
|
| Command | Description |
|
||||||
| ------------------------------ | ---------------------- |
|
| ------------------------------ | ---------------------- |
|
||||||
| `pass` / `bypass` | Pass through unchanged |
|
|
||||||
| `rewrite <from> <to>` | Rewrite request path |
|
| `rewrite <from> <to>` | Rewrite request path |
|
||||||
| `require_auth` | Require authentication |
|
| `require_auth` | Require authentication |
|
||||||
| `require_basic_auth <realm>` | Basic auth challenge |
|
|
||||||
| `set <target> <field> <value>` | Set header/variable |
|
| `set <target> <field> <value>` | Set header/variable |
|
||||||
| `add <target> <field> <value>` | Add header/variable |
|
| `add <target> <field> <value>` | Add header/variable |
|
||||||
| `remove <target> <field>` | Remove header/variable |
|
| `remove <target> <field>` | Remove header/variable |
|
||||||
@@ -195,54 +234,215 @@ path regex("/api/v[0-9]+/.*") // regex pattern
|
|||||||
|
|
||||||
## Configuration Surface
|
## Configuration Surface
|
||||||
|
|
||||||
### Rule Configuration (YAML)
|
### Rule Configuration (Block Syntax)
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
rules:
|
default {
|
||||||
- name: rule name
|
|
||||||
on: |
|
|
||||||
condition1
|
|
||||||
& condition2
|
|
||||||
do: |
|
|
||||||
action1
|
action1
|
||||||
action2
|
action2
|
||||||
|
}
|
||||||
|
|
||||||
|
condition1 &
|
||||||
|
condition2 {
|
||||||
|
action1
|
||||||
|
action2
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
This is the primary syntax for rules and avoids YAML wrappers.
|
||||||
|
It keeps the **inner** `on` and `do` DSLs exactly the same (same matchers, same commands, same optional quotes), but wraps each rule in a `{ ... }` block.
|
||||||
|
|
||||||
|
#### Key ideas
|
||||||
|
|
||||||
|
- A rule is:
|
||||||
|
- `default { <do...> }`,
|
||||||
|
- `{ <do...> }`, or
|
||||||
|
- `<on-expr> { <do...> }`
|
||||||
|
- Comments are supported:
|
||||||
|
- line comment: `// ...` (to end of line)
|
||||||
|
- line comment: `# ...` (to end of line, for YAML familiarity)
|
||||||
|
- block comment: `/* ... */` (may span multiple lines)
|
||||||
|
- Comments are ignored **only when outside quotes** (`"`, `'` or backticks).
|
||||||
|
- Environment variable syntax: `${NAME}` is supported by the inner DSL parser in [`parse()`](internal/route/rules/parser.go:34).
|
||||||
|
Block-syntax rule:
|
||||||
|
- In `on` (rule header): `${...}` must be inside quotes/backticks.
|
||||||
|
- In `do` (rule body): `${...}` may be unquoted; the outer parser must treat `${...}` as an opaque token so braces inside it are not structural.
|
||||||
|
|
||||||
|
#### Grammar sketch (EBNF-ish)
|
||||||
|
|
||||||
|
```text
|
||||||
|
file := { ws | comment | rule }
|
||||||
|
rule := default_rule | unconditional_rule | conditional_rule
|
||||||
|
|
||||||
|
default_rule := 'default' ws* block
|
||||||
|
unconditional_rule := ws* block
|
||||||
|
conditional_rule := on_expr ws* block
|
||||||
|
|
||||||
|
block := '{' do_body '}'
|
||||||
|
|
||||||
|
// on_expr and do_body are raw text regions.
|
||||||
|
// The outer parser only needs to:
|
||||||
|
// - find the top-level '{' to start a rule block
|
||||||
|
// - find the matching top-level '}' to end it
|
||||||
|
// while respecting quotes and comments.
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Elif/Else Chain Grammar
|
||||||
|
|
||||||
|
```text
|
||||||
|
// Elif/Else chains can appear in do_body
|
||||||
|
do_stmt := command_line | nested_block | elif_else_chain
|
||||||
|
elif_else_chain := nested_block { elif_clause } [else_clause]
|
||||||
|
elif_clause := 'elif' ws* on_expr ws* '{' do_body '}'
|
||||||
|
else_clause := 'else' ws* '{' do_body '}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Nested blocks (inline conditionals inside `do`)
|
||||||
|
|
||||||
|
Inside a rule body (`do_body`), you can write **nested blocks**:
|
||||||
|
|
||||||
|
```text
|
||||||
|
do_stmt := command_line | nested_block | elif_else_chain
|
||||||
|
|
||||||
|
nested_block := on_expr ws* '{' do_body '}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- A nested block is recognized when a line ends with an unquoted `{` (ignoring trailing whitespace).
|
||||||
|
- `on_expr` uses the same syntax as rule `on` (supports `|`, `&`, quoting/backticks, matcher functions, etc.).
|
||||||
|
- The nested block executes **in sequence**, at the point where it appears in the parent `do` list.
|
||||||
|
- Nested blocks are evaluated in the same phase the parent rule runs (no special phase promotion).
|
||||||
|
- Nested blocks can be chained with `elif`/`else` for conditional execution (see Elif/Else Chains section).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
default {
|
||||||
|
remove resp_header X-Secret
|
||||||
|
add resp_header X-Custom-Header custom-value
|
||||||
|
}
|
||||||
|
|
||||||
|
header X-Test-Header {
|
||||||
|
set header X-Remote-Type public
|
||||||
|
remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||||
|
set header X-Remote-Type private
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Elif/Else Chains
|
||||||
|
|
||||||
|
You can chain multiple conditions using `elif` and provide a fallback with `else`.
|
||||||
|
The `elif`/`else` keywords must appear on the same line as the preceding closing brace (`}`).
|
||||||
|
|
||||||
|
```go
|
||||||
|
header X-Test-Header {
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
} elif method POST {
|
||||||
|
set header X-Mode post
|
||||||
|
} else {
|
||||||
|
set header X-Mode other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- `elif` and `else` must be on the same line as the preceding `}`.
|
||||||
|
- Multiple `elif` branches are allowed; only one `else` is allowed.
|
||||||
|
- The entire chain is evaluated in sequence; the first matching branch executes.
|
||||||
|
- Elif/else chains can only be used within nested blocks.
|
||||||
|
- Each `elif` clause must have its own condition expression and block.
|
||||||
|
- The `else` clause is optional and provides a default action when no conditions match.
|
||||||
|
|
||||||
|
#### Examples
|
||||||
|
|
||||||
|
Basic default rule:
|
||||||
|
|
||||||
|
```go
|
||||||
|
default {
|
||||||
|
bypass
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
WebSocket upgrade routing:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# WebSocket requests
|
||||||
|
header Connection Upgrade &
|
||||||
|
header Upgrade websocket {
|
||||||
|
route ws-api
|
||||||
|
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Block comments:
|
||||||
|
|
||||||
|
```go
|
||||||
|
/* protect admin area */
|
||||||
|
path glob("/admin/*") {
|
||||||
|
require_auth
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Always log the request
|
||||||
|
|
||||||
|
```bash
|
||||||
|
{
|
||||||
|
log info /dev/stdout "Request $req_method $req_path"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Notes and constraints
|
||||||
|
|
||||||
|
- The block syntax uses `{` and `}` as structure delimiters at **top-level** (outside quotes/comments).
|
||||||
|
- Braces inside quoted strings (including backticks) are not structural.
|
||||||
|
- `${...}` handling:
|
||||||
|
- `on`: must be quoted/backticked
|
||||||
|
- `do`: may be unquoted
|
||||||
|
Preferred style: always write env vars as `${NAME}` rather than a bare `$NAME`.
|
||||||
|
- If you need literal `{` or `}` outside quotes/backticks (for example unquoted templates like `{{ ... }}`), wrap that argument in quotes/backticks so the outer parser does not treat it as structure.
|
||||||
|
- Rule naming remains minimal: if no explicit name is provided by the syntax, it will behave like the current YAML behavior (empty name becomes `rule[index]` in [`Rules.BuildHandler()`](internal/route/rules/rules.go:75)).
|
||||||
|
- YAML remains supported as a fallback for backward compatibility.
|
||||||
|
|
||||||
### Condition Syntax
|
### Condition Syntax
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
# Simple condition
|
# Simple condition
|
||||||
on: path /api/users
|
path /api/users
|
||||||
|
|
||||||
# Multiple conditions (AND)
|
# Multiple conditions (AND)
|
||||||
on: |
|
header Authorization Bearer & path glob("/api/admin/*")
|
||||||
header Authorization Bearer
|
|
||||||
& path /api/admin/*
|
|
||||||
|
|
||||||
# Negation
|
# Negation
|
||||||
on: !path /public/*
|
!path glob("/public/*")
|
||||||
|
|
||||||
|
# Negation on matcher
|
||||||
|
path !glob("/public/*")
|
||||||
|
|
||||||
# OR within a line
|
# OR within a line
|
||||||
on: method GET | method POST
|
method GET | method POST
|
||||||
```
|
```
|
||||||
|
|
||||||
### Variable Substitution
|
### Variable Substitution
|
||||||
|
|
||||||
```go
|
```bash
|
||||||
// Static variables
|
# Static variables
|
||||||
$req_method // Request method
|
$req_method # Request method
|
||||||
$req_host // Request host
|
$req_host # Request host
|
||||||
$req_path // Request path
|
$req_path # Request path
|
||||||
$status_code // Response status
|
$status_code # Response status
|
||||||
$remote_host // Client IP
|
$remote_host # Client IP
|
||||||
|
|
||||||
// Dynamic variables
|
# Dynamic variables
|
||||||
$header(Name) // Request header
|
$header(Name) # Request header
|
||||||
$header(Name, index) // Header at index
|
$header(Name, index) # Header at index
|
||||||
$arg(Name) // Query argument
|
$arg(Name) # Query argument
|
||||||
$form(Name) // Form field
|
$form(Name) # Form field
|
||||||
|
|
||||||
// Environment variables
|
# Environment variables
|
||||||
${ENV_VAR}
|
${ENV_VAR}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -278,8 +478,9 @@ Log context includes: `rule`, `alias`, `match_result`
|
|||||||
## Failure Modes and Recovery
|
## Failure Modes and Recovery
|
||||||
|
|
||||||
| Failure | Behavior | Recovery |
|
| Failure | Behavior | Recovery |
|
||||||
| ------------------- | ------------------------- | ---------------------------------- |
|
| ---------------------- | ------------------------- | ---------------------------------- |
|
||||||
| Invalid rule syntax | Route validation fails | Fix YAML syntax |
|
| Invalid rule syntax | Route validation fails | Fix block rule syntax |
|
||||||
|
| Multiple default rules | Route validation fails | Remove duplicate default rules |
|
||||||
| Missing variables | Variable renders as empty | Check variable sources |
|
| Missing variables | Variable renders as empty | Check variable sources |
|
||||||
| Rule timeout | Request times out | Increase timeout or simplify rules |
|
| Rule timeout | Request times out | Increase timeout or simplify rules |
|
||||||
| Auth failure | Returns 401/403 | Fix credentials |
|
| Auth failure | Returns 401/403 | Fix credentials |
|
||||||
@@ -288,73 +489,92 @@ Log context includes: `rule`, `alias`, `match_result`
|
|||||||
|
|
||||||
### Basic Pass-Through
|
### Basic Pass-Through
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
- name: default
|
default {
|
||||||
do: pass
|
pass
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Path-Based Routing
|
### Path-Based Routing
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
- name: api proxy
|
path glob("/api/*") {
|
||||||
on: path /api/*
|
proxy http://api-backend:8080
|
||||||
do: proxy http://api-backend:8080
|
}
|
||||||
|
|
||||||
- name: static files
|
path glob("/static/*") {
|
||||||
on: path /static/*
|
serve /var/www/static
|
||||||
do: serve /var/www/static
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Authentication
|
### Authentication
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
- name: admin protection
|
path glob("/admin/*") {
|
||||||
on: path /admin/*
|
require_auth
|
||||||
do: require_auth
|
}
|
||||||
|
|
||||||
- name: basic auth for API
|
path glob("/api/*") {
|
||||||
on: path /api/*
|
require_basic_auth "API Access"
|
||||||
do: require_basic_auth "API Access"
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Path Rewriting
|
### Path Rewriting
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
- name: rewrite API v1
|
path glob("/v1/*") {
|
||||||
on: path /v1/*
|
|
||||||
do: |
|
|
||||||
rewrite /v1 /api/v1
|
rewrite /v1 /api/v1
|
||||||
proxy http://backend:8080
|
proxy http://backend:8080
|
||||||
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### IP-Based Access Control
|
### IP-Based Access Control
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
- name: allow internal
|
remote 10.0.0.0/8 {
|
||||||
on: remote 10.0.0.0/8
|
pass
|
||||||
do: pass
|
}
|
||||||
|
|
||||||
- name: block external
|
!remote 10.0.0.0/8 &
|
||||||
on: |
|
!remote 192.168.0.0/16 {
|
||||||
!remote 10.0.0.0/8
|
error 403 "Access Denied"
|
||||||
!remote 192.168.0.0/16
|
}
|
||||||
do: error 403 "Access Denied"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### WebSocket Support
|
### WebSocket Support
|
||||||
|
|
||||||
```yaml
|
```bash
|
||||||
- name: websocket upgrade
|
header Connection Upgrade &
|
||||||
on: |
|
header Upgrade websocket {
|
||||||
header Connection Upgrade
|
bypass
|
||||||
header Upgrade websocket
|
}
|
||||||
do: bypass
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Default Rule (Fallback)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default runs only if no non-default pre rule matches
|
||||||
|
default {
|
||||||
|
remove resp_header X-Internal
|
||||||
|
add resp_header X-Powered-By godoxy
|
||||||
|
}
|
||||||
|
|
||||||
|
# Matching rules suppress default
|
||||||
|
path glob("/api/*") {
|
||||||
|
proxy http://api:8080
|
||||||
|
}
|
||||||
|
|
||||||
|
path glob("/api/*") {
|
||||||
|
set resp_header X-API true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors and both behave as fallback-only.
|
||||||
|
|
||||||
## Testing Notes
|
## Testing Notes
|
||||||
|
|
||||||
- Unit tests for all matchers and actions
|
- Unit tests for all matchers and actions
|
||||||
- Integration tests with real HTTP requests
|
- Integration tests with real HTTP requests
|
||||||
- Parser tests for YAML syntax
|
- Parser tests for block syntax
|
||||||
- Variable substitution tests
|
- Variable substitution tests
|
||||||
- Performance benchmarks for hot paths
|
- Performance benchmarks for hot paths
|
||||||
|
|||||||
409
internal/route/rules/block_parser.go
Normal file
409
internal/route/rules/block_parser.go
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/yusing/goutils/env"
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getStringBuffer(size int) *strings.Builder {
|
||||||
|
var buf strings.Builder
|
||||||
|
if size > 0 {
|
||||||
|
buf.Grow(size)
|
||||||
|
}
|
||||||
|
return &buf
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandEnvVarsRaw expands ${NAME} in-place using env.LookupEnv (prefix-aware).
|
||||||
|
func expandEnvVarsRaw(v string) (string, gperr.Error) {
|
||||||
|
buf := getStringBuffer(len(v))
|
||||||
|
envVar := getStringBuffer(0)
|
||||||
|
|
||||||
|
var missingEnvVars []string
|
||||||
|
inEnvVar := false
|
||||||
|
expectingBrace := false
|
||||||
|
|
||||||
|
for _, r := range v {
|
||||||
|
if expectingBrace && r != '{' && r != '$' {
|
||||||
|
buf.WriteRune('$')
|
||||||
|
expectingBrace = false
|
||||||
|
}
|
||||||
|
switch r {
|
||||||
|
case '$':
|
||||||
|
if expectingBrace {
|
||||||
|
buf.WriteRune('$')
|
||||||
|
expectingBrace = false
|
||||||
|
} else {
|
||||||
|
expectingBrace = true
|
||||||
|
}
|
||||||
|
case '{':
|
||||||
|
if expectingBrace {
|
||||||
|
inEnvVar = true
|
||||||
|
expectingBrace = false
|
||||||
|
envVar.Reset()
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(r)
|
||||||
|
}
|
||||||
|
case '}':
|
||||||
|
if inEnvVar {
|
||||||
|
envValue, ok := env.LookupEnv(envVar.String())
|
||||||
|
if !ok {
|
||||||
|
missingEnvVars = append(missingEnvVars, envVar.String())
|
||||||
|
} else {
|
||||||
|
buf.WriteString(envValue)
|
||||||
|
}
|
||||||
|
inEnvVar = false
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(r)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if expectingBrace {
|
||||||
|
buf.WriteRune('$')
|
||||||
|
expectingBrace = false
|
||||||
|
}
|
||||||
|
if inEnvVar {
|
||||||
|
envVar.WriteRune(r)
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectingBrace {
|
||||||
|
buf.WriteRune('$')
|
||||||
|
}
|
||||||
|
|
||||||
|
var err gperr.Error
|
||||||
|
if inEnvVar {
|
||||||
|
// Write back the unterminated ${...} so the output matches the input.
|
||||||
|
buf.WriteString("${")
|
||||||
|
buf.WriteString(envVar.String())
|
||||||
|
err = ErrUnterminatedEnvVar
|
||||||
|
}
|
||||||
|
if len(missingEnvVars) > 0 {
|
||||||
|
err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...)))
|
||||||
|
}
|
||||||
|
return buf.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseBlockRules parses the block-syntax rule format.
|
||||||
|
// Grammar:
|
||||||
|
//
|
||||||
|
// file := { ws | comment | rule }
|
||||||
|
// rule := default_rule | conditional_rule
|
||||||
|
// default_rule := 'default' ws* block
|
||||||
|
// conditional_rule := on_expr ws* block
|
||||||
|
// block := '{' do_body '}'
|
||||||
|
//
|
||||||
|
// Where:
|
||||||
|
// - on_expr is passed verbatim to RuleOn.Parse()
|
||||||
|
// - do_body is passed verbatim to Command.Parse()
|
||||||
|
//
|
||||||
|
// Comments (ignored outside quotes/backticks):
|
||||||
|
// - line comment: // ... or # ...
|
||||||
|
// - block comment: /* ... */
|
||||||
|
//
|
||||||
|
// Brace handling:
|
||||||
|
// - Braces inside quotes/backticks are ignored
|
||||||
|
// - Braces inside ${...} (env vars) are ignored in do_body
|
||||||
|
// - Braces in on_expr are not ignored (env vars must be quoted in on_expr)
|
||||||
|
//
|
||||||
|
//nolint:dupword
|
||||||
|
func parseBlockRules(src string) (Rules, gperr.Error) {
|
||||||
|
var rules Rules
|
||||||
|
var errs gperr.Builder
|
||||||
|
|
||||||
|
pos := 0
|
||||||
|
length := len(src)
|
||||||
|
t := newTokenizer(src)
|
||||||
|
|
||||||
|
for pos < length {
|
||||||
|
// Skip whitespace/comments between rules.
|
||||||
|
newPos, skipErr := t.skipComments(pos, true, true)
|
||||||
|
if skipErr != nil {
|
||||||
|
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
|
||||||
|
}
|
||||||
|
pos = newPos
|
||||||
|
if pos >= length {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stray closing brace at top-level: keep parsing but mark invalid so Rules.Validate() fails.
|
||||||
|
if src[pos] == '}' {
|
||||||
|
return nil, ErrInvalidBlockSyntax.Withf("unmatched '}' at position %d", pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse rule header (default, unconditional, or on_expr)
|
||||||
|
headerStart := pos
|
||||||
|
header := parseRuleHeader(&t, src, &pos, length)
|
||||||
|
headerStr := src[headerStart:pos]
|
||||||
|
|
||||||
|
// Skip whitespace/comments before '{' (default header may end before '{').
|
||||||
|
newPos, skipErr = t.skipComments(pos, false, true)
|
||||||
|
if skipErr != nil {
|
||||||
|
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
|
||||||
|
}
|
||||||
|
pos = newPos
|
||||||
|
|
||||||
|
if pos >= length || src[pos] != '{' {
|
||||||
|
errs.AddSubjectf(ErrInvalidBlockSyntax, "expected '{' after rule header %q", headerStr)
|
||||||
|
return nil, errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find matching '}' (respecting quotes and env vars in do_body)
|
||||||
|
bodyStart := pos + 1
|
||||||
|
bodyEnd, err := t.findMatchingBrace(bodyStart)
|
||||||
|
if err != nil {
|
||||||
|
errs.AddSubjectf(err, "rule header %q", headerStr)
|
||||||
|
return nil, errs.Error()
|
||||||
|
}
|
||||||
|
pos = bodyEnd + 1
|
||||||
|
|
||||||
|
onExpr := header
|
||||||
|
|
||||||
|
doBody := ""
|
||||||
|
if bodyStart < bodyEnd {
|
||||||
|
doBody = src[bodyStart:bodyEnd]
|
||||||
|
}
|
||||||
|
// Normalize do body for the inner DSL parser:
|
||||||
|
// - strip comments (outside quotes/backticks)
|
||||||
|
// - trim block whitespace/indentation
|
||||||
|
// - expand ${ENV} in-place so cmd.raw is usable/debuggable
|
||||||
|
doBody, err = preprocessDoBody(doBody)
|
||||||
|
if err != nil {
|
||||||
|
errs.AddSubjectf(err, "rule header %q", headerStr)
|
||||||
|
return nil, errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := Rule{
|
||||||
|
Name: "", // auto-generate if empty
|
||||||
|
On: RuleOn{},
|
||||||
|
Do: Command{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header semantics:
|
||||||
|
// - "default" => default rule (matched when no other rules are matched)
|
||||||
|
// - "" => unconditional rule (always matches)
|
||||||
|
// - otherwise => conditional rule (on expression)
|
||||||
|
switch onExpr {
|
||||||
|
case "default":
|
||||||
|
rule.On.raw = OnDefault
|
||||||
|
case "":
|
||||||
|
// leave rule.On as zero value => checker=nil => always matches
|
||||||
|
default:
|
||||||
|
if parseErr := rule.On.Parse(onExpr); parseErr != nil {
|
||||||
|
errs.AddSubjectf(parseErr, "on")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if doBody != "" {
|
||||||
|
if parseErr := rule.Do.Parse(doBody); parseErr != nil {
|
||||||
|
errs.AddSubjectf(parseErr, "do")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errs.HasError() {
|
||||||
|
return nil, errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func preprocessDoBody(doBody string) (string, gperr.Error) {
|
||||||
|
doBody = strings.TrimSpace(doBody)
|
||||||
|
if doBody == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := doBody
|
||||||
|
// If comments are possible, strip them first while preserving line breaks.
|
||||||
|
if strings.ContainsAny(normalized, "#/") {
|
||||||
|
stripped, err := stripCommentsPreserveNewlines(normalized)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
normalized = stripped
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop lines that are empty after trimming, while preserving indentation of non-empty lines.
|
||||||
|
out := getStringBuffer(len(normalized))
|
||||||
|
|
||||||
|
lineStart := 0
|
||||||
|
wroteLine := false
|
||||||
|
for i := 0; i <= len(normalized); i++ {
|
||||||
|
if i < len(normalized) && normalized[i] != '\n' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
line := normalized[lineStart:i]
|
||||||
|
if strings.TrimSpace(line) != "" {
|
||||||
|
if wroteLine {
|
||||||
|
out.WriteByte('\n')
|
||||||
|
}
|
||||||
|
out.WriteString(line)
|
||||||
|
wroteLine = true
|
||||||
|
}
|
||||||
|
lineStart = i + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if !wroteLine {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
normalized = out.String()
|
||||||
|
|
||||||
|
// Expand env vars to keep Command.raw consistent with parsed semantics.
|
||||||
|
if !strings.Contains(normalized, "${") {
|
||||||
|
return normalized, nil
|
||||||
|
}
|
||||||
|
expanded, err := expandEnvVarsRaw(normalized)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return expanded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripCommentsPreserveNewlines removes //, #, and /* */ comments outside quotes/backticks.
|
||||||
|
// It preserves newlines so command line boundaries remain intact.
|
||||||
|
func stripCommentsPreserveNewlines(src string) (string, gperr.Error) {
|
||||||
|
if !strings.ContainsAny(src, "#/") {
|
||||||
|
return src, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := getStringBuffer(len(src))
|
||||||
|
|
||||||
|
quote := rune(0)
|
||||||
|
inLine := false
|
||||||
|
inBlock := false
|
||||||
|
atLineStart := true
|
||||||
|
prevIsSpace := true
|
||||||
|
|
||||||
|
for i := 0; i < len(src); {
|
||||||
|
c := src[i]
|
||||||
|
|
||||||
|
if inLine {
|
||||||
|
if c == '\n' {
|
||||||
|
inLine = false
|
||||||
|
out.WriteByte('\n')
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inBlock {
|
||||||
|
if c == '\n' {
|
||||||
|
out.WriteByte('\n')
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '*' && i+1 < len(src) && src[i+1] == '/' {
|
||||||
|
inBlock = false
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if quote != 0 {
|
||||||
|
out.WriteByte(c)
|
||||||
|
if c == '\\' && i+1 < len(src) {
|
||||||
|
// Write next char and skip it (escape sequence)
|
||||||
|
i++
|
||||||
|
out.WriteByte(src[i])
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = false
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rune(c) == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
if c == '\n' {
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
} else {
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = unicode.IsSpace(rune(c))
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not in quote/comment.
|
||||||
|
switch c {
|
||||||
|
case '\'', '"', '`':
|
||||||
|
quote = rune(c)
|
||||||
|
out.WriteByte(c)
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = false
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
case '#':
|
||||||
|
if atLineStart || prevIsSpace {
|
||||||
|
inLine = true
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case '/':
|
||||||
|
if i+1 < len(src) {
|
||||||
|
n := src[i+1]
|
||||||
|
if (atLineStart || prevIsSpace) && n == '/' {
|
||||||
|
inLine = true
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (atLineStart || prevIsSpace) && n == '*' {
|
||||||
|
inBlock = true
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out.WriteByte(c)
|
||||||
|
if c == '\n' {
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
} else {
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = unicode.IsSpace(rune(c))
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
if inBlock {
|
||||||
|
return "", ErrInvalidBlockSyntax.Withf("unterminated block comment")
|
||||||
|
}
|
||||||
|
return out.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRuleHeader parses the rule header (default or on expression).
|
||||||
|
// Returns the header string, or "" if parsing failed.
|
||||||
|
func parseRuleHeader(t *Tokenizer, src string, pos *int, length int) string {
|
||||||
|
start := *pos
|
||||||
|
|
||||||
|
// Check for 'default' keyword
|
||||||
|
if *pos+7 <= length && src[*pos:*pos+7] == "default" {
|
||||||
|
next := *pos + 7
|
||||||
|
if next >= length || unicode.IsSpace(rune(src[next])) {
|
||||||
|
*pos = next
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse on expression until we hit '{' outside quotes.
|
||||||
|
bracePos, err := t.scanToBrace(*pos)
|
||||||
|
if err != nil {
|
||||||
|
*pos = length
|
||||||
|
return strings.TrimSpace(src[start:*pos])
|
||||||
|
}
|
||||||
|
*pos = bracePos
|
||||||
|
return strings.TrimSpace(src[start:*pos])
|
||||||
|
}
|
||||||
48
internal/route/rules/block_parser_bench_test.go
Normal file
48
internal/route/rules/block_parser_bench_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func BenchmarkParseBlockRules(b *testing.B) {
|
||||||
|
const rulesString = `
|
||||||
|
default {
|
||||||
|
remove resp_header X-Secret
|
||||||
|
add resp_header X-Custom-Header custom-value
|
||||||
|
}
|
||||||
|
|
||||||
|
header X-Test-Header {
|
||||||
|
set header X-Remote-Type public
|
||||||
|
remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||||
|
set header X-Remote-Type private
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
path glob(/api/admin/*) {
|
||||||
|
cookie session-id {
|
||||||
|
set header X-Session-ID $cookie(session-id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
!remote 192.168.0.0/16 {
|
||||||
|
!header X-User-Role admin & !header X-User-Role user {
|
||||||
|
error 403 "Access denied"
|
||||||
|
} elif remote 127.0.0.1 {
|
||||||
|
header X-User-Role staff {
|
||||||
|
set header X-User-Role staff
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error 403 "Access denied"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
err := rules.Parse(rulesString)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
var rules Rules
|
||||||
|
_ = rules.Parse(rulesString)
|
||||||
|
}
|
||||||
|
}
|
||||||
390
internal/route/rules/block_parser_test.go
Normal file
390
internal/route/rules/block_parser_test.go
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/godoxy/internal/serialization"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testParseRules(t *testing.T, data string) Rules {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
|
||||||
|
require.True(t, convertible)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func testParseRulesError(t *testing.T, data string) error {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
|
||||||
|
require.True(t, convertible)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_DefaultRule(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `default {
|
||||||
|
upstream
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||||
|
assert.Equal(t, "upstream", rules[0].Do.raw)
|
||||||
|
assert.True(t, rules[0].Do.raw == CommandUpstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_ConditionalRule(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `path glob(/api/*) {
|
||||||
|
proxy http://localhost:8080
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
assert.Equal(t, "path glob(/api/*)", rules[0].On.raw)
|
||||||
|
assert.Equal(t, "proxy http://localhost:8080", rules[0].Do.raw)
|
||||||
|
require.Len(t, rules[0].Do.pre, 1)
|
||||||
|
_, ok := rules[0].Do.pre[0].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, rules[0].Do.post, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_MultipleRules(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `default {
|
||||||
|
bypass
|
||||||
|
}
|
||||||
|
|
||||||
|
path /api/* {
|
||||||
|
proxy http://localhost:8080
|
||||||
|
}
|
||||||
|
|
||||||
|
header Connection Upgrade &
|
||||||
|
header Upgrade websocket {
|
||||||
|
route ws-api
|
||||||
|
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 3)
|
||||||
|
|
||||||
|
// Default rule
|
||||||
|
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||||
|
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||||
|
|
||||||
|
// API rule
|
||||||
|
assert.Equal(t, "path /api/*", rules[1].On.raw)
|
||||||
|
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
|
||||||
|
|
||||||
|
// WebSocket rule
|
||||||
|
assert.Equal(t, "header Connection Upgrade &\nheader Upgrade websocket", rules[2].On.raw)
|
||||||
|
assert.Equal(t, `route ws-api
|
||||||
|
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"`, rules[2].Do.raw)
|
||||||
|
require.Len(t, rules[2].Do.pre, 2)
|
||||||
|
_, ok := rules[2].Do.pre[0].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
_, ok = rules[2].Do.pre[1].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, rules[2].Do.post, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_Comments(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `// This is a comment
|
||||||
|
default {
|
||||||
|
bypass // inline comment
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Block comment
|
||||||
|
spanning multiple lines */
|
||||||
|
path /admin/* {
|
||||||
|
require_auth
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 2)
|
||||||
|
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||||
|
assert.Equal(t, "path /admin/*", rules[1].On.raw)
|
||||||
|
assert.Equal(t, "require_auth", rules[1].Do.raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_HashComment(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `# YAML-style comment
|
||||||
|
default {
|
||||||
|
bypass
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||||
|
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_EnvVars(t *testing.T) {
|
||||||
|
t.Setenv("CUSTOM_HEADER", "test-header")
|
||||||
|
|
||||||
|
rules := testParseRules(t, `path /api/* {
|
||||||
|
set header X-Custom "${CUSTOM_HEADER}"
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
assert.Equal(t, "path /api/*", rules[0].On.raw)
|
||||||
|
assert.Equal(t, `set header X-Custom "test-header"`, rules[0].Do.raw)
|
||||||
|
require.Len(t, rules[0].Do.pre, 1)
|
||||||
|
_, ok := rules[0].Do.pre[0].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, rules[0].Do.post, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_YAMLFallback(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `- name: default
|
||||||
|
do: bypass
|
||||||
|
- name: api
|
||||||
|
on: path glob(/api/*)
|
||||||
|
do: proxy http://localhost:8080`)
|
||||||
|
require.Len(t, rules, 2)
|
||||||
|
assert.Equal(t, "default", rules[0].Name)
|
||||||
|
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||||
|
assert.Equal(t, "api", rules[1].Name)
|
||||||
|
assert.Equal(t, "path glob(/api/*)", rules[1].On.raw)
|
||||||
|
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
|
||||||
|
require.Len(t, rules[1].Do.pre, 1)
|
||||||
|
_, ok := rules[1].Do.pre[0].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, rules[1].Do.post, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_UnmatchedBrace(t *testing.T) {
|
||||||
|
t.Run("unquoted", func(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `path /api/* {
|
||||||
|
proxy http://localhost:8080}
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("quoted", func(t *testing.T) {
|
||||||
|
_ = testParseRules(t, `path /api/* {
|
||||||
|
error 403 "some message}"
|
||||||
|
}`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_UnterminatedBlockComment(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `/* unterminated block comment
|
||||||
|
default {
|
||||||
|
bypass
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `
|
||||||
|
header X-Test-Header {
|
||||||
|
set header X-Remote-Type public
|
||||||
|
remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||||
|
set header X-Remote-Type private
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
assert.Equal(t, "header X-Test-Header", rules[0].On.raw)
|
||||||
|
|
||||||
|
require.Len(t, rules[0].Do.pre, 2)
|
||||||
|
_, ok := rules[0].Do.pre[0].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, rules[0].Do.post, 0)
|
||||||
|
|
||||||
|
ifCmd, ok := rules[0].Do.pre[1].(IfBlockCommand)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "remote 127.0.0.1 | remote 192.168.0.0/16", ifCmd.On.raw)
|
||||||
|
|
||||||
|
require.Len(t, ifCmd.Do, 1)
|
||||||
|
|
||||||
|
upstream := func(http.ResponseWriter, *http.Request) {}
|
||||||
|
|
||||||
|
t.Run("condition matched executes nested content", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Test-Header", "1")
|
||||||
|
req.RemoteAddr = "127.0.0.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
|
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "private", req.Header.Get("X-Remote-Type"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("condition not matched skips nested content", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Test-Header", "1")
|
||||||
|
req.RemoteAddr = "10.0.0.1:12345"
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
|
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "public", req.Header.Get("X-Remote-Type"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_ElifElse(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `
|
||||||
|
header X-Test-Header {
|
||||||
|
set header X-Mode outer
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
} elif method POST {
|
||||||
|
set header X-Mode post
|
||||||
|
} else {
|
||||||
|
set header X-Mode other
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
|
||||||
|
require.Len(t, rules[0].Do.pre, 2)
|
||||||
|
|
||||||
|
ifCmd, ok := rules[0].Do.pre[1].(IfElseBlockCommand)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, ifCmd.Ifs, 2)
|
||||||
|
assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw)
|
||||||
|
assert.Equal(t, "method POST", ifCmd.Ifs[1].On.raw)
|
||||||
|
require.NotNil(t, ifCmd.Else)
|
||||||
|
|
||||||
|
upstream := func(http.ResponseWriter, *http.Request) {}
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "get branch", method: http.MethodGet, want: "get"},
|
||||||
|
{name: "post branch", method: http.MethodPost, want: "post"},
|
||||||
|
{name: "else branch", method: http.MethodPut, want: "other"},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(tc.method, "/", nil)
|
||||||
|
req.Header.Set("X-Test-Header", "1")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
|
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.want, req.Header.Get("X-Mode"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_DefaultRule_CommentBeforeBrace(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `default /* comment between header and brace */ {
|
||||||
|
bypass
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||||
|
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_StrayClosingBraceAtTopLevel(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `}
|
||||||
|
default {
|
||||||
|
bypass
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_ElifMustBeSameLine(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `header X-Test-Header {
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
}
|
||||||
|
elif method POST {
|
||||||
|
set header X-Mode post
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_ElseMustBeLastOnLine(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `header X-Test-Header {
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
} else {
|
||||||
|
set header X-Mode other
|
||||||
|
} set header X-After else
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unexpected token after else block")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_MultipleElse(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `header X-Test-Header {
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
} else {
|
||||||
|
set header X-Mode other
|
||||||
|
} else {
|
||||||
|
set header X-Mode other2
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
// assert.Contains(t, err.Error(), "multiple 'else' branches")
|
||||||
|
assert.Contains(t, err.Error(), "unexpected token after else block")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_ElifMissingOnExpr(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `header X-Test-Header {
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
} elif {
|
||||||
|
set header X-Mode post
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "expected on-expr after 'elif'")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_LineEndingBraceHeuristic(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `{
|
||||||
|
set header X-Literal "{"
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
require.Len(t, rules[0].Do.pre, 1)
|
||||||
|
_, ok := rules[0].Do.pre[0].(Handler)
|
||||||
|
require.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_LineEndingBraceWithTrailingSpaces(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `header X-Test-Header {
|
||||||
|
method GET {
|
||||||
|
set header X-Mode get
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
require.Len(t, rules[0].Do.pre, 1)
|
||||||
|
ifCmd, ok := rules[0].Do.pre[0].(IfBlockCommand)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "method GET", ifCmd.On.raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_LineEndingBraceWithTrailingComment(t *testing.T) {
|
||||||
|
rules := testParseRules(t, `header X-Test-Header {
|
||||||
|
method GET { // GET branch
|
||||||
|
set header X-Mode get
|
||||||
|
} else { # fallback branch
|
||||||
|
set header X-Mode other
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Len(t, rules, 1)
|
||||||
|
require.Len(t, rules[0].Do.pre, 1)
|
||||||
|
|
||||||
|
ifCmd, ok := rules[0].Do.pre[0].(IfElseBlockCommand)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, ifCmd.Ifs, 1)
|
||||||
|
assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw)
|
||||||
|
require.NotNil(t, ifCmd.Else)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseBlockRules_NestedBlocks_LineEndingBraceInterpretsAsBlock(t *testing.T) {
|
||||||
|
err := testParseRulesError(t, `{
|
||||||
|
set header X-Bad {
|
||||||
|
set header X-Test fail
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid `rule.on` target")
|
||||||
|
}
|
||||||
@@ -1,21 +1,25 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
CheckFunc func(w http.ResponseWriter, r *http.Request) bool
|
CheckFunc func(w *httputils.ResponseModifier, r *http.Request) bool
|
||||||
Checker interface {
|
Checker interface {
|
||||||
Check(w http.ResponseWriter, r *http.Request) bool
|
Check(w *httputils.ResponseModifier, r *http.Request) bool
|
||||||
}
|
}
|
||||||
CheckMatchSingle []Checker
|
CheckMatchSingle []Checker
|
||||||
CheckMatchAll []Checker
|
CheckMatchAll []Checker
|
||||||
)
|
)
|
||||||
|
|
||||||
func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool {
|
func (checker CheckFunc) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return checker(w, r)
|
return checker(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool {
|
func (checkers CheckMatchSingle) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
for _, check := range checkers {
|
for _, check := range checkers {
|
||||||
if check.Check(w, r) {
|
if check.Check(w, r) {
|
||||||
return true
|
return true
|
||||||
@@ -24,7 +28,7 @@ func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) b
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool {
|
func (checkers CheckMatchAll) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
for _, check := range checkers {
|
for _, check := range checkers {
|
||||||
if !check.Check(w, r) {
|
if !check.Check(w, r) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -1,79 +1,62 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errTerminateRule = errors.New("terminate rule")
|
||||||
|
|
||||||
type (
|
type (
|
||||||
handlerFunc func(w http.ResponseWriter, r *http.Request) error
|
HandlerFunc func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
|
||||||
|
Handler struct {
|
||||||
|
fn HandlerFunc
|
||||||
|
phase PhaseFlag
|
||||||
|
terminate bool
|
||||||
|
}
|
||||||
|
|
||||||
CommandHandler interface {
|
CommandHandler interface {
|
||||||
// CommandHandler can read and modify the values
|
// CommandHandler can read and modify the values
|
||||||
// then handle the request
|
// then handle the request
|
||||||
// finally proceed to next command (or return) base on situation
|
// finally proceed to next command (or return) base on situation
|
||||||
Handle(w http.ResponseWriter, r *http.Request) error
|
ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
|
||||||
IsResponseHandler() bool
|
Phase() PhaseFlag
|
||||||
}
|
}
|
||||||
// NonTerminatingCommand will run then proceed to next command or reverse proxy.
|
|
||||||
NonTerminatingCommand handlerFunc
|
|
||||||
// TerminatingCommand will run then return immediately.
|
|
||||||
TerminatingCommand handlerFunc
|
|
||||||
// OnResponseCommand will run then return based on the response.
|
|
||||||
OnResponseCommand handlerFunc
|
|
||||||
// BypassCommand will skip all the following commands
|
|
||||||
// and directly return to reverse proxy.
|
|
||||||
BypassCommand struct{}
|
|
||||||
// Commands is a slice of CommandHandler.
|
// Commands is a slice of CommandHandler.
|
||||||
Commands []CommandHandler
|
Commands []CommandHandler
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
func (h Handler) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
return c(w, r)
|
return h.fn(w, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c NonTerminatingCommand) IsResponseHandler() bool {
|
func (h Handler) Phase() PhaseFlag {
|
||||||
return false
|
return h.phase
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
func (h Handler) Terminates() bool {
|
||||||
if err := c(w, r); err != nil {
|
return h.terminate
|
||||||
return err
|
|
||||||
}
|
|
||||||
return errTerminated
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c TerminatingCommand) IsResponseHandler() bool {
|
func (c Commands) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
return c(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c OnResponseCommand) IsResponseHandler() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
return errTerminated
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c BypassCommand) IsResponseHandler() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
for _, cmd := range c {
|
for _, cmd := range c {
|
||||||
if err := cmd.Handle(w, r); err != nil {
|
err := cmd.ServeHTTP(w, r, upstream)
|
||||||
|
if err != nil {
|
||||||
|
// Terminating actions stop the command chain immediately.
|
||||||
|
// Will be handled by the caller.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Commands) IsResponseHandler() bool {
|
func (c Commands) Phase() PhaseFlag {
|
||||||
|
req := PhaseNone
|
||||||
for _, cmd := range c {
|
for _, cmd := range c {
|
||||||
if cmd.IsResponseHandler() {
|
req |= cmd.Phase()
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
return req
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -25,16 +24,16 @@ import (
|
|||||||
type (
|
type (
|
||||||
Command struct {
|
Command struct {
|
||||||
raw string
|
raw string
|
||||||
exec CommandHandler
|
pre Commands // runs before w.WriteHeader
|
||||||
isResponseHandler bool
|
post Commands
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (cmd *Command) IsResponseHandler() bool {
|
|
||||||
return cmd.isResponseHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
CommandUpstream = "upstream"
|
||||||
|
CommandUpstreamOld = "bypass"
|
||||||
|
CommandUpstreamOld2 = "pass"
|
||||||
|
|
||||||
CommandRequireAuth = "require_auth"
|
CommandRequireAuth = "require_auth"
|
||||||
CommandRewrite = "rewrite"
|
CommandRewrite = "rewrite"
|
||||||
CommandServe = "serve"
|
CommandServe = "serve"
|
||||||
@@ -48,8 +47,6 @@ const (
|
|||||||
CommandRemove = "remove"
|
CommandRemove = "remove"
|
||||||
CommandLog = "log"
|
CommandLog = "log"
|
||||||
CommandNotify = "notify"
|
CommandNotify = "notify"
|
||||||
CommandPass = "pass"
|
|
||||||
CommandPassAlt = "bypass"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
|
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
@@ -60,36 +57,60 @@ func InitAuthHandler(handler AuthHandler) {
|
|||||||
authHandler = handler
|
authHandler = handler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
commands[CommandUpstreamOld] = commands[CommandUpstream]
|
||||||
|
commands[CommandUpstreamOld2] = commands[CommandUpstream]
|
||||||
|
}
|
||||||
|
|
||||||
var commands = map[string]struct {
|
var commands = map[string]struct {
|
||||||
help Help
|
help Help
|
||||||
validate ValidateFunc
|
validate ValidateFunc
|
||||||
build func(args any) CommandHandler
|
build func(args any) HandlerFunc
|
||||||
isResponseHandler bool
|
terminate bool
|
||||||
}{
|
}{
|
||||||
|
CommandUpstream: {
|
||||||
|
help: Help{
|
||||||
|
command: CommandUpstream,
|
||||||
|
description: makeLines("Pass the request to the upstream"),
|
||||||
|
args: map[string]string{},
|
||||||
|
},
|
||||||
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
if len(args) != 0 {
|
||||||
|
return phase, nil, ErrExpectNoArg
|
||||||
|
}
|
||||||
|
return phase, nil, nil
|
||||||
|
},
|
||||||
|
build: func(args any) HandlerFunc {
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
|
upstream(w, r)
|
||||||
|
return errTerminateRule
|
||||||
|
}
|
||||||
|
},
|
||||||
|
terminate: true,
|
||||||
|
},
|
||||||
CommandRequireAuth: {
|
CommandRequireAuth: {
|
||||||
help: Help{
|
help: Help{
|
||||||
command: CommandRequireAuth,
|
command: CommandRequireAuth,
|
||||||
description: makeLines("Require HTTP authentication for incoming requests"),
|
description: makeLines("Require HTTP authentication for incoming requests"),
|
||||||
args: map[string]string{},
|
args: map[string]string{},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePre
|
||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, ErrExpectNoArg
|
return phase, nil, ErrExpectNoArg
|
||||||
}
|
}
|
||||||
//nolint:nilnil
|
return phase, nil, nil
|
||||||
return nil, nil
|
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
if authHandler == nil {
|
if authHandler == nil { // no auth handler configured, allow request to proceed
|
||||||
http.Error(w, "Auth handler not initialized", http.StatusInternalServerError)
|
return nil
|
||||||
return errTerminated
|
|
||||||
}
|
}
|
||||||
if !authHandler(w, r) {
|
if proceed := authHandler(w, r); !proceed {
|
||||||
return errTerminated
|
return errTerminateRule
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandRewrite: {
|
CommandRewrite: {
|
||||||
@@ -104,26 +125,27 @@ var commands = map[string]struct {
|
|||||||
"to": "the path to rewrite to, must start with /",
|
"to": "the path to rewrite to, must start with /",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePre
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return nil, ErrExpectTwoArgs
|
return phase, nil, ErrExpectTwoArgs
|
||||||
}
|
}
|
||||||
path1, err1 := validateURLPath(args[:1])
|
path1, err1 := validateURLPath(args[:1])
|
||||||
path2, err2 := validateURLPath(args[1:])
|
path2, err2 := validateURLPath(args[1:])
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
err1 = gperr.PrependSubject(err1, "from")
|
err1 = gperr.Errorf("from: %w", err1)
|
||||||
}
|
}
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
err2 = gperr.PrependSubject(err2, "to")
|
err2 = gperr.Errorf("to: %w", err2)
|
||||||
}
|
}
|
||||||
if err1 != nil || err2 != nil {
|
if err1 != nil || err2 != nil {
|
||||||
return nil, gperr.Join(err1, err2)
|
return phase, nil, gperr.Join(err1, err2)
|
||||||
}
|
}
|
||||||
return &StrTuple{path1.(string), path2.(string)}, nil
|
return phase, &StrTuple{path1.(string), path2.(string)}, nil
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
orig, repl := args.(*StrTuple).Unpack()
|
orig, repl := args.(*StrTuple).Unpack()
|
||||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
path := r.URL.Path
|
path := r.URL.Path
|
||||||
if len(path) > 0 && path[0] != '/' {
|
if len(path) > 0 && path[0] != '/' {
|
||||||
path = "/" + path
|
path = "/" + path
|
||||||
@@ -133,10 +155,10 @@ var commands = map[string]struct {
|
|||||||
}
|
}
|
||||||
path = repl + path[len(orig):]
|
path = repl + path[len(orig):]
|
||||||
r.URL.Path = path
|
r.URL.Path = path
|
||||||
r.URL.RawPath = r.URL.EscapedPath()
|
r.URL.RawPath = ""
|
||||||
r.RequestURI = r.URL.RequestURI()
|
r.RequestURI = ""
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandServe: {
|
CommandServe: {
|
||||||
@@ -150,14 +172,19 @@ var commands = map[string]struct {
|
|||||||
"root": "the file system path to serve, must be an existing directory",
|
"root": "the file system path to serve, must be an existing directory",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateFSPath,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
build: func(args any) CommandHandler {
|
phase = PhasePre
|
||||||
root := args.(string)
|
parsedArgs, err = validateFSPath(args)
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return
|
||||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
},
|
},
|
||||||
|
build: func(args any) HandlerFunc {
|
||||||
|
root := args.(string)
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
|
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||||
|
return errTerminateRule
|
||||||
|
}
|
||||||
|
},
|
||||||
|
terminate: true,
|
||||||
},
|
},
|
||||||
CommandRedirect: {
|
CommandRedirect: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -170,14 +197,19 @@ var commands = map[string]struct {
|
|||||||
"to": "the url to redirect to, can be relative or absolute URL",
|
"to": "the url to redirect to, can be relative or absolute URL",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateURL,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
build: func(args any) CommandHandler {
|
phase = PhasePre
|
||||||
target := args.(*nettypes.URL).String()
|
parsedArgs, err = validateURL(args)
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return
|
||||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
},
|
},
|
||||||
|
build: func(args any) HandlerFunc {
|
||||||
|
target := args.(*nettypes.URL).String()
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
|
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||||
|
return errTerminateRule
|
||||||
|
}
|
||||||
|
},
|
||||||
|
terminate: true,
|
||||||
},
|
},
|
||||||
CommandRoute: {
|
CommandRoute: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -190,15 +222,16 @@ var commands = map[string]struct {
|
|||||||
"route": "the route to route to",
|
"route": "the route to route to",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePre
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return phase, nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
return args[0], nil
|
return phase, args[0], nil
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
route := args.(string)
|
route := args.(string)
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
|
return func(w *httputils.ResponseModifier, req *http.Request, upstream http.HandlerFunc) error {
|
||||||
ep := entrypoint.FromCtx(req.Context())
|
ep := entrypoint.FromCtx(req.Context())
|
||||||
r, ok := ep.HTTPRoutes().Get(route)
|
r, ok := ep.HTTPRoutes().Get(route)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -212,9 +245,10 @@ var commands = map[string]struct {
|
|||||||
} else {
|
} else {
|
||||||
http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound)
|
http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound)
|
||||||
}
|
}
|
||||||
return nil
|
return errTerminateRule
|
||||||
})
|
}
|
||||||
},
|
},
|
||||||
|
terminate: true,
|
||||||
},
|
},
|
||||||
CommandError: {
|
CommandError: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -228,34 +262,40 @@ var commands = map[string]struct {
|
|||||||
"text": "the error message to return",
|
"text": "the error message to return",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePre
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return nil, ErrExpectTwoArgs
|
return phase, nil, ErrExpectTwoArgs
|
||||||
}
|
}
|
||||||
codeStr, text := args[0], args[1]
|
codeStr, text := args[0], args[1]
|
||||||
code, err := strconv.Atoi(codeStr)
|
code, err := strconv.Atoi(codeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidArguments.With(err)
|
return phase, nil, ErrInvalidArguments.With(err)
|
||||||
}
|
}
|
||||||
if !httputils.IsStatusCodeValid(code) {
|
if !httputils.IsStatusCodeValid(code) {
|
||||||
return nil, ErrInvalidArguments.Subject(codeStr)
|
return phase, nil, ErrInvalidArguments.Subject(codeStr)
|
||||||
}
|
}
|
||||||
textTmpl, err := validateTemplate(text, true)
|
tmplReq, textTmpl, err := validateTemplate(text, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidArguments.With(err)
|
return phase, nil, ErrInvalidArguments.With(err)
|
||||||
}
|
}
|
||||||
return &Tuple[int, templateString]{code, textTmpl}, nil
|
phase |= tmplReq
|
||||||
|
return phase, &Tuple[int, templateString]{code, textTmpl}, nil
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
|
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
// error command should overwrite the response body
|
// error command should overwrite the response body
|
||||||
httputils.GetInitResponseModifier(w).ResetBody()
|
w.ResetBody()
|
||||||
w.WriteHeader(code)
|
w.WriteHeader(code)
|
||||||
err := textTmpl.ExpandVars(w, r, w)
|
_, err := textTmpl.ExpandVars(w, r, w.BodyBuffer())
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
})
|
}
|
||||||
|
return errTerminateRule
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
terminate: true,
|
||||||
},
|
},
|
||||||
CommandRequireBasicAuth: {
|
CommandRequireBasicAuth: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -268,20 +308,22 @@ var commands = map[string]struct {
|
|||||||
"realm": "the authentication realm",
|
"realm": "the authentication realm",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePre
|
||||||
if len(args) == 1 {
|
if len(args) == 1 {
|
||||||
return args[0], nil
|
return phase, args[0], nil
|
||||||
}
|
}
|
||||||
return nil, ErrExpectOneArg
|
return phase, nil, ErrExpectOneArg
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
realm := args.(string)
|
realm := args.(string)
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, realm))
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
return nil
|
return errTerminateRule
|
||||||
})
|
}
|
||||||
},
|
},
|
||||||
|
terminate: true,
|
||||||
},
|
},
|
||||||
CommandProxy: {
|
CommandProxy: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -294,14 +336,19 @@ var commands = map[string]struct {
|
|||||||
"to": "the url to proxy to, must be an absolute URL",
|
"to": "the url to proxy to, must be an absolute URL",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateURL,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
build: func(args any) CommandHandler {
|
phase = PhasePre
|
||||||
|
parsedArgs, err = validateURL(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
|
build: func(args any) HandlerFunc {
|
||||||
target := args.(*nettypes.URL)
|
target := args.(*nettypes.URL)
|
||||||
if target.Scheme == "" {
|
if target.Scheme == "" {
|
||||||
target.Scheme = "http"
|
target.Scheme = "http"
|
||||||
}
|
}
|
||||||
if target.Host == "" {
|
if target.Host == "" {
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
rawPath := target.EscapedPath()
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
url := target.URL
|
url := target.URL
|
||||||
url.Host = routes.TryGetUpstreamHostPort(r)
|
url.Host = routes.TryGetUpstreamHostPort(r)
|
||||||
if url.Host == "" {
|
if url.Host == "" {
|
||||||
@@ -309,18 +356,19 @@ var commands = map[string]struct {
|
|||||||
}
|
}
|
||||||
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
|
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
|
||||||
r.URL.Path = target.Path
|
r.URL.Path = target.Path
|
||||||
r.URL.RawPath = r.URL.EscapedPath()
|
r.URL.RawPath = rawPath
|
||||||
r.RequestURI = r.URL.RequestURI()
|
r.RequestURI = ""
|
||||||
rp.ServeHTTP(w, r)
|
rp.ServeHTTP(w, r)
|
||||||
return nil
|
return errTerminateRule
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
|
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
|
||||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
rp.ServeHTTP(w, r)
|
rp.ServeHTTP(w, r)
|
||||||
return nil
|
return errTerminateRule
|
||||||
})
|
}
|
||||||
},
|
},
|
||||||
|
terminate: true,
|
||||||
},
|
},
|
||||||
CommandSet: {
|
CommandSet: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -335,11 +383,11 @@ var commands = map[string]struct {
|
|||||||
"value": "the value to set",
|
"value": "the value to set",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
return validateModField(ModFieldSet, args)
|
return validateModField(ModFieldSet, args)
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
return args.(CommandHandler)
|
return args.(HandlerFunc)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandAdd: {
|
CommandAdd: {
|
||||||
@@ -355,11 +403,11 @@ var commands = map[string]struct {
|
|||||||
"value": "the value to add",
|
"value": "the value to add",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
return validateModField(ModFieldAdd, args)
|
return validateModField(ModFieldAdd, args)
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
return args.(CommandHandler)
|
return args.(HandlerFunc)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandRemove: {
|
CommandRemove: {
|
||||||
@@ -374,15 +422,14 @@ var commands = map[string]struct {
|
|||||||
"field": "the field to remove",
|
"field": "the field to remove",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
return validateModField(ModFieldRemove, args)
|
return validateModField(ModFieldRemove, args)
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
return args.(CommandHandler)
|
return args.(HandlerFunc)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandLog: {
|
CommandLog: {
|
||||||
isResponseHandler: true,
|
|
||||||
help: Help{
|
help: Help{
|
||||||
command: CommandLog,
|
command: CommandLog,
|
||||||
description: makeLines(
|
description: makeLines(
|
||||||
@@ -399,46 +446,57 @@ var commands = map[string]struct {
|
|||||||
"template": "the template to log",
|
"template": "the template to log",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return nil, ErrExpectThreeArgs
|
return phase, nil, ErrExpectThreeArgs
|
||||||
}
|
}
|
||||||
tmpl, err := validateTemplate(args[2], true)
|
phase, tmpl, err := validateTemplate(args[2], true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return phase, nil, err
|
||||||
}
|
}
|
||||||
level, err := validateLevel(args[0])
|
level, err := validateLevel(args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return phase, nil, err
|
||||||
}
|
}
|
||||||
// NOTE: file will stay opened forever
|
// NOTE: file will stay opened forever
|
||||||
// it leverages accesslog.NewFileIO so
|
// it leverages accesslog.NewFileIO so
|
||||||
// it will be opened only once for the same path
|
// it will be opened only once for the same path
|
||||||
f, err := openFile(args[1])
|
f, err := openFile(args[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return phase, nil, err
|
||||||
}
|
}
|
||||||
return &onLogArgs{level, f, tmpl}, nil
|
return phase, &onLogArgs{level, f, tmpl}, nil
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
level, f, tmpl := args.(*onLogArgs).Unpack()
|
level, f, tmpl := args.(*onLogArgs).Unpack()
|
||||||
var logger io.Writer
|
var logger io.Writer
|
||||||
if f == stdout || f == stderr {
|
isStdLogger := f == stdout || f == stderr
|
||||||
|
if isStdLogger {
|
||||||
logger = logging.NewLoggerWithFixedLevel(level, f)
|
logger = logging.NewLoggerWithFixedLevel(level, f)
|
||||||
} else {
|
} else {
|
||||||
logger = f
|
logger = f
|
||||||
}
|
}
|
||||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
err := tmpl.ExpandVars(w, r, logger)
|
if isStdLogger {
|
||||||
if err != nil {
|
bufPool := w.BufPool()
|
||||||
|
buf := bufPool.GetBuffer()
|
||||||
|
defer bufPool.PutBuffer(buf)
|
||||||
|
|
||||||
|
if _, err := tmpl.ExpandVars(w, r, buf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if buf.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
_, err := logger.Write(buf.Bytes())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := tmpl.ExpandVars(w, r, logger)
|
||||||
|
return err
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
CommandNotify: {
|
CommandNotify: {
|
||||||
isResponseHandler: true,
|
|
||||||
help: Help{
|
help: Help{
|
||||||
command: CommandNotify,
|
command: CommandNotify,
|
||||||
description: makeLines(
|
description: makeLines(
|
||||||
@@ -456,22 +514,24 @@ var commands = map[string]struct {
|
|||||||
"body": "the body of the notification",
|
"body": "the body of the notification",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 4 {
|
if len(args) != 4 {
|
||||||
return nil, ErrExpectFourArgs
|
return phase, nil, ErrExpectFourArgs
|
||||||
}
|
}
|
||||||
titleTmpl, err := validateTemplate(args[2], false)
|
req1, titleTmpl, err := validateTemplate(args[2], false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return phase, nil, err
|
||||||
}
|
}
|
||||||
bodyTmpl, err := validateTemplate(args[3], false)
|
req2, bodyTmpl, err := validateTemplate(args[3], false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return phase, nil, err
|
||||||
}
|
}
|
||||||
level, err := validateLevel(args[0])
|
level, err := validateLevel(args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return phase, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
phase |= req1 | req2
|
||||||
// TODO: validate provider
|
// TODO: validate provider
|
||||||
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
|
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
|
||||||
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
|
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
|
||||||
@@ -480,34 +540,34 @@ var commands = map[string]struct {
|
|||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
// }
|
// }
|
||||||
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
return phase, &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
||||||
},
|
},
|
||||||
build: func(args any) CommandHandler {
|
build: func(args any) HandlerFunc {
|
||||||
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
|
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
|
||||||
to := []string{provider}
|
to := []string{provider}
|
||||||
|
|
||||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
|
var respBuf strings.Builder
|
||||||
|
|
||||||
err := titleTmpl.ExpandVars(w, r, respBuf)
|
_, err := titleTmpl.ExpandVars(w, r, &respBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
titleLen := respBuf.Len()
|
titleLen := respBuf.Len()
|
||||||
err = bodyTmpl.ExpandVars(w, r, respBuf)
|
_, err = bodyTmpl.ExpandVars(w, r, &respBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
b := respBuf.Bytes()
|
s := respBuf.String()
|
||||||
notif.Notify(¬if.LogMessage{
|
notif.Notify(¬if.LogMessage{
|
||||||
Level: level,
|
Level: level,
|
||||||
Title: string(b[:titleLen]),
|
Title: s[:titleLen],
|
||||||
Body: notif.MessageBodyBytes(b[titleLen:]),
|
Body: notif.MessageBodyBytes(s[titleLen:]),
|
||||||
To: to,
|
To: to,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -519,121 +579,29 @@ type (
|
|||||||
|
|
||||||
// Parse implements strutils.Parser.
|
// Parse implements strutils.Parser.
|
||||||
func (cmd *Command) Parse(v string) error {
|
func (cmd *Command) Parse(v string) error {
|
||||||
executors := make([]CommandHandler, 0)
|
executors, parseErr := parseDoWithBlocks(v)
|
||||||
isResponseHandler := false
|
if parseErr != nil {
|
||||||
for line := range strings.SplitSeq(v, "\n") {
|
return parseErr
|
||||||
if line == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
directive, args, err := parse(line)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if directive == CommandPass || directive == CommandPassAlt {
|
|
||||||
if len(args) != 0 {
|
|
||||||
return ErrExpectNoArg
|
|
||||||
}
|
|
||||||
executors = append(executors, BypassCommand{})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
builder, ok := commands[directive]
|
|
||||||
if !ok {
|
|
||||||
return ErrUnknownDirective.Subject(directive)
|
|
||||||
}
|
|
||||||
validArgs, err := builder.validate(args)
|
|
||||||
if err != nil {
|
|
||||||
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
|
|
||||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := builder.build(validArgs)
|
|
||||||
executors = append(executors, handler)
|
|
||||||
if builder.isResponseHandler || handler.IsResponseHandler() {
|
|
||||||
isResponseHandler = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(executors) == 0 {
|
if len(executors) == 0 {
|
||||||
cmd.raw = v
|
cmd.raw = v
|
||||||
cmd.exec = nil
|
cmd.pre = nil
|
||||||
cmd.isResponseHandler = false
|
cmd.post = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
exec, err := buildCmd(executors)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.raw = v
|
cmd.raw = v
|
||||||
cmd.exec = exec
|
for _, executor := range executors {
|
||||||
if exec.IsResponseHandler() {
|
if executor.Phase().IsPostRule() {
|
||||||
isResponseHandler = true
|
cmd.post = append(cmd.post, executor)
|
||||||
|
} else {
|
||||||
|
cmd.pre = append(cmd.pre, executor)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cmd.isResponseHandler = isResponseHandler
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
|
|
||||||
// Validate the execution order.
|
|
||||||
//
|
|
||||||
// This allows sequences like:
|
|
||||||
// route ws-api
|
|
||||||
// log info /dev/stdout "..."
|
|
||||||
// where the first command is request-phase and the last is response-phase.
|
|
||||||
lastNonResp := -1
|
|
||||||
seenResp := false
|
|
||||||
for i, exec := range executors {
|
|
||||||
if exec.IsResponseHandler() {
|
|
||||||
seenResp = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if seenResp {
|
|
||||||
return nil, ErrInvalidCommandSequence.Withf("response handlers must be the last commands")
|
|
||||||
}
|
|
||||||
lastNonResp = i
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, exec := range executors {
|
|
||||||
if i > lastNonResp {
|
|
||||||
break // response-handler tail
|
|
||||||
}
|
|
||||||
switch exec.(type) {
|
|
||||||
case TerminatingCommand, BypassCommand:
|
|
||||||
if i != lastNonResp {
|
|
||||||
return nil, ErrInvalidCommandSequence.
|
|
||||||
Withf("a response handler or terminating/bypass command must be the last command")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Commands(executors), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Command is purely "bypass" or empty.
|
|
||||||
func (cmd *Command) isBypass() bool {
|
|
||||||
if cmd == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
switch cmd := cmd.exec.(type) {
|
|
||||||
case BypassCommand:
|
|
||||||
return true
|
|
||||||
case Commands:
|
|
||||||
// bypass command is always the last one
|
|
||||||
_, ok := cmd[len(cmd)-1].(BypassCommand)
|
|
||||||
return ok
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
return cmd.exec.Handle(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cmd *Command) String() string {
|
func (cmd *Command) String() string {
|
||||||
return cmd.raw
|
return cmd.raw
|
||||||
}
|
}
|
||||||
|
|||||||
436
internal/route/rules/do_blocks.go
Normal file
436
internal/route/rules/do_blocks.go
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IfBlockCommand is an inline conditional block inside a do-body.
|
||||||
|
//
|
||||||
|
// Syntax (within a rule do block):
|
||||||
|
//
|
||||||
|
// <on-expr> { <do...> }
|
||||||
|
//
|
||||||
|
// Semantics:
|
||||||
|
// - Evaluated in the same phase the parent rule runs.
|
||||||
|
// - If <on-expr> matches, run the nested commands in-order.
|
||||||
|
// - Otherwise do nothing.
|
||||||
|
//
|
||||||
|
// NOTE: Per current design decision, we keep this permissive:
|
||||||
|
// nested blocks may use response matchers and response commands; no extra phase validation is performed.
|
||||||
|
type IfBlockCommand struct {
|
||||||
|
On RuleOn
|
||||||
|
Do []CommandHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c IfBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
|
if c.Do == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// If On.checker is nil, treat as unconditional (should not happen if parsed).
|
||||||
|
if c.On.checker == nil {
|
||||||
|
return Commands(c.Do).ServeHTTP(w, r, upstream)
|
||||||
|
}
|
||||||
|
if c.On.checker.Check(w, r) {
|
||||||
|
return Commands(c.Do).ServeHTTP(w, r, upstream)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c IfBlockCommand) Phase() PhaseFlag {
|
||||||
|
phase := c.On.phase
|
||||||
|
for _, cmd := range c.Do {
|
||||||
|
phase |= cmd.Phase()
|
||||||
|
}
|
||||||
|
return phase
|
||||||
|
}
|
||||||
|
|
||||||
|
// IfElseBlockCommand is a chained conditional block inside a do-body.
|
||||||
|
//
|
||||||
|
// Syntax (within a rule do block):
|
||||||
|
//
|
||||||
|
// <on-expr> { <do...> } elif <on-expr> { <do...> } ... else { <do...> }
|
||||||
|
//
|
||||||
|
// NOTE: `elif`/`else` must appear on the same line as the preceding closing brace (`}`),
|
||||||
|
// e.g. `} elif ... {` and `} else {`.
|
||||||
|
type IfElseBlockCommand struct {
|
||||||
|
Ifs []IfBlockCommand
|
||||||
|
Else []CommandHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c IfElseBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
|
for _, br := range c.Ifs {
|
||||||
|
// If On.checker is nil, treat as unconditional.
|
||||||
|
if br.On.checker == nil {
|
||||||
|
if br.Do == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return Commands(br.Do).ServeHTTP(w, r, upstream)
|
||||||
|
}
|
||||||
|
if br.On.checker.Check(w, r) {
|
||||||
|
if br.Do == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return Commands(br.Do).ServeHTTP(w, r, upstream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(c.Else) > 0 {
|
||||||
|
return Commands(c.Else).ServeHTTP(w, r, upstream)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c IfElseBlockCommand) Phase() PhaseFlag {
|
||||||
|
phase := PhaseNone
|
||||||
|
for _, br := range c.Ifs {
|
||||||
|
phase |= br.Phase()
|
||||||
|
}
|
||||||
|
if len(c.Else) > 0 {
|
||||||
|
phase |= Commands(c.Else).Phase()
|
||||||
|
}
|
||||||
|
return phase
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipSameLineSpace(src string, pos int) int {
|
||||||
|
for pos < len(src) {
|
||||||
|
switch src[pos] {
|
||||||
|
case '\n':
|
||||||
|
return pos
|
||||||
|
case '\r':
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
case ' ', '\t':
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
return pos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAtBlockChain(src string, blockPos int) (CommandHandler, int, error) {
|
||||||
|
length := len(src)
|
||||||
|
headerStart := blockPos
|
||||||
|
|
||||||
|
parseBranch := func(onExpr string, bodyStart int, bodyEnd int) (RuleOn, []CommandHandler, error) {
|
||||||
|
var on RuleOn
|
||||||
|
if err := on.Parse(onExpr); err != nil {
|
||||||
|
return RuleOn{}, nil, err
|
||||||
|
}
|
||||||
|
innerSrc := ""
|
||||||
|
if bodyStart < bodyEnd {
|
||||||
|
innerSrc = src[bodyStart:bodyEnd]
|
||||||
|
}
|
||||||
|
inner, err := parseDoWithBlocks(innerSrc)
|
||||||
|
if err != nil {
|
||||||
|
return RuleOn{}, nil, err
|
||||||
|
}
|
||||||
|
if len(inner) == 0 {
|
||||||
|
return on, nil, nil
|
||||||
|
}
|
||||||
|
return on, inner, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
onExpr, bracePos, herr := parseHeaderToBrace(src, headerStart)
|
||||||
|
if herr != nil {
|
||||||
|
return nil, 0, herr
|
||||||
|
}
|
||||||
|
if onExpr == "" {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr before '{'")
|
||||||
|
}
|
||||||
|
if bracePos >= length || src[bracePos] != '{' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after nested block header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse first <on-expr> { ... }
|
||||||
|
p := bracePos
|
||||||
|
bodyStart := p + 1
|
||||||
|
bodyEnd, ferr := findMatchingBrace(src, &p, bodyStart)
|
||||||
|
if ferr != nil {
|
||||||
|
return nil, 0, ferr
|
||||||
|
}
|
||||||
|
firstOn, firstDo, berr := parseBranch(onExpr, bodyStart, bodyEnd)
|
||||||
|
if berr != nil {
|
||||||
|
return nil, 0, berr
|
||||||
|
}
|
||||||
|
|
||||||
|
ifs := []IfBlockCommand{{On: firstOn, Do: firstDo}}
|
||||||
|
var elseDo []CommandHandler
|
||||||
|
hasChain := false
|
||||||
|
hasElse := false
|
||||||
|
|
||||||
|
for {
|
||||||
|
q := skipSameLineSpace(src, p)
|
||||||
|
if q >= length || src[q] == '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// elif <on-expr> { ... }
|
||||||
|
if strings.HasPrefix(src[q:], "elif") {
|
||||||
|
next := q + len("elif")
|
||||||
|
if next >= length {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||||
|
}
|
||||||
|
if src[next] == '\n' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||||
|
}
|
||||||
|
if !unicode.IsSpace(rune(src[next])) {
|
||||||
|
if src[next] == '{' || src[next] == '}' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||||
|
}
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected whitespace after 'elif'")
|
||||||
|
}
|
||||||
|
next++
|
||||||
|
for next < length {
|
||||||
|
c := src[next]
|
||||||
|
if c == '\n' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
|
||||||
|
}
|
||||||
|
if c == '\r' {
|
||||||
|
next++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !unicode.IsSpace(rune(c)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next++
|
||||||
|
}
|
||||||
|
|
||||||
|
p2 := next
|
||||||
|
elifOnExpr, bracePos, herr := parseHeaderToBrace(src, p2)
|
||||||
|
if herr != nil {
|
||||||
|
return nil, 0, herr
|
||||||
|
}
|
||||||
|
if elifOnExpr == "" {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||||
|
}
|
||||||
|
if bracePos >= length || src[bracePos] != '{' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
|
||||||
|
}
|
||||||
|
p2 = bracePos
|
||||||
|
elifBodyStart := p2 + 1
|
||||||
|
elifBodyEnd, ferr := findMatchingBrace(src, &p2, elifBodyStart)
|
||||||
|
if ferr != nil {
|
||||||
|
return nil, 0, ferr
|
||||||
|
}
|
||||||
|
elifOn, elifDo, berr := parseBranch(elifOnExpr, elifBodyStart, elifBodyEnd)
|
||||||
|
if berr != nil {
|
||||||
|
return nil, 0, berr
|
||||||
|
}
|
||||||
|
ifs = append(ifs, IfBlockCommand{On: elifOn, Do: elifDo})
|
||||||
|
hasChain = true
|
||||||
|
p = p2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// else { ... }
|
||||||
|
if strings.HasPrefix(src[q:], "else") {
|
||||||
|
if hasElse {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("multiple 'else' branches")
|
||||||
|
}
|
||||||
|
next := q + len("else")
|
||||||
|
for next < length {
|
||||||
|
c := src[next]
|
||||||
|
if c == '\n' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
|
||||||
|
}
|
||||||
|
if c == '\r' {
|
||||||
|
next++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !unicode.IsSpace(rune(c)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next++
|
||||||
|
}
|
||||||
|
if next >= length || src[next] != '{' {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
|
||||||
|
}
|
||||||
|
|
||||||
|
elseBodyStart := next + 1
|
||||||
|
p2 := next
|
||||||
|
elseBodyEnd, ferr := findMatchingBrace(src, &p2, elseBodyStart)
|
||||||
|
if ferr != nil {
|
||||||
|
return nil, 0, ferr
|
||||||
|
}
|
||||||
|
innerSrc := ""
|
||||||
|
if elseBodyStart < elseBodyEnd {
|
||||||
|
innerSrc = src[elseBodyStart:elseBodyEnd]
|
||||||
|
}
|
||||||
|
inner, ierr := parseDoWithBlocks(innerSrc)
|
||||||
|
if ierr != nil {
|
||||||
|
return nil, 0, ierr
|
||||||
|
}
|
||||||
|
if len(inner) == 0 {
|
||||||
|
elseDo = nil
|
||||||
|
} else {
|
||||||
|
elseDo = inner
|
||||||
|
}
|
||||||
|
hasChain = true
|
||||||
|
hasElse = true
|
||||||
|
p = p2
|
||||||
|
|
||||||
|
// else must be the last branch on that line.
|
||||||
|
for q2 := skipSameLineSpace(src, p); q2 < length && src[q2] != '\n'; q2 = skipSameLineSpace(src, q2) {
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after else block")
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after nested block; expected 'elif'/'else' or newline")
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasChain {
|
||||||
|
return IfElseBlockCommand{Ifs: ifs, Else: elseDo}, p, nil
|
||||||
|
}
|
||||||
|
return IfBlockCommand{On: ifs[0].On, Do: ifs[0].Do}, p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func lineEndsWithUnquotedOpenBrace(src string, lineStart int, lineEnd int) bool {
|
||||||
|
quote := byte(0)
|
||||||
|
lastSignificant := byte(0)
|
||||||
|
atLineStart := true
|
||||||
|
prevIsSpace := true
|
||||||
|
|
||||||
|
for i := lineStart; i < lineEnd; i++ {
|
||||||
|
c := src[i]
|
||||||
|
if quote != 0 {
|
||||||
|
if c == '\\' && i+1 < lineEnd {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if quoteChars[c] {
|
||||||
|
quote = c
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '#' && (atLineStart || prevIsSpace) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if c == '/' && i+1 < lineEnd {
|
||||||
|
n := rune(src[i+1])
|
||||||
|
if (atLineStart || prevIsSpace) && (n == '/' || n == '*') {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if unicode.IsSpace(rune(c)) {
|
||||||
|
prevIsSpace = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lastSignificant = c
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = false
|
||||||
|
}
|
||||||
|
return quote == 0 && lastSignificant == '{'
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDoWithBlocks parses a do-body containing plain command lines and nested blocks.
|
||||||
|
// It returns the outer command handlers and the require phase.
|
||||||
|
//
|
||||||
|
// A nested block is recognized when a line ends with an unquoted '{' (ignoring trailing whitespace).
|
||||||
|
func parseDoWithBlocks(src string) (handlers []CommandHandler, err error) {
|
||||||
|
pos := 0
|
||||||
|
length := len(src)
|
||||||
|
lineStart := true
|
||||||
|
handlers = make([]CommandHandler, 0, strings.Count(src, "\n")+1)
|
||||||
|
|
||||||
|
appendLineCommand := func(line string) error {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
directive, args, err := parse(line)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
builder, ok := commands[directive]
|
||||||
|
if !ok {
|
||||||
|
return ErrUnknownDirective.Subject(directive)
|
||||||
|
}
|
||||||
|
|
||||||
|
phase, validArgs, err := builder.validate(args)
|
||||||
|
if err != nil {
|
||||||
|
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
h := builder.build(validArgs)
|
||||||
|
handlers = append(handlers, Handler{fn: h, phase: phase, terminate: builder.terminate})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for pos < length {
|
||||||
|
// Handle newlines
|
||||||
|
switch src[pos] {
|
||||||
|
case '\n':
|
||||||
|
pos++
|
||||||
|
lineStart = true
|
||||||
|
continue
|
||||||
|
case '\r':
|
||||||
|
// tolerate CRLF
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if lineStart {
|
||||||
|
// Find first non-space on the line.
|
||||||
|
linePos := pos
|
||||||
|
for linePos < length {
|
||||||
|
c := rune(src[linePos])
|
||||||
|
if c == '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !unicode.IsSpace(c) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
linePos++
|
||||||
|
}
|
||||||
|
|
||||||
|
lineEnd := linePos
|
||||||
|
for lineEnd < length && src[lineEnd] != '\n' {
|
||||||
|
lineEnd++
|
||||||
|
}
|
||||||
|
|
||||||
|
if linePos < length && lineEndsWithUnquotedOpenBrace(src, linePos, lineEnd) {
|
||||||
|
h, next, err := parseAtBlockChain(src, linePos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
handlers = append(handlers, h)
|
||||||
|
pos = next
|
||||||
|
lineStart = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a nested block; parse the rest of this line as a command.
|
||||||
|
if lerr := appendLineCommand(src[pos:lineEnd]); lerr != nil {
|
||||||
|
return nil, lerr
|
||||||
|
}
|
||||||
|
pos = lineEnd
|
||||||
|
lineStart = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not at line start; advance to the next line boundary.
|
||||||
|
for pos < length && src[pos] != '\n' {
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
lineStart = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return handlers, nil
|
||||||
|
}
|
||||||
73
internal/route/rules/do_blocks_test.go
Normal file
73
internal/route/rules/do_blocks_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoNotFallsThrough(t *testing.T) {
|
||||||
|
elseCalled := false
|
||||||
|
cmd := IfElseBlockCommand{
|
||||||
|
Ifs: []IfBlockCommand{
|
||||||
|
{
|
||||||
|
On: RuleOn{},
|
||||||
|
Do: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Else: []CommandHandler{
|
||||||
|
Handler{
|
||||||
|
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
|
||||||
|
elseCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
phase: PhaseNone,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
|
err := cmd.ServeHTTP(rm, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, elseCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoNotFallsThrough(t *testing.T) {
|
||||||
|
elseCalled := false
|
||||||
|
cmd := IfElseBlockCommand{
|
||||||
|
Ifs: []IfBlockCommand{
|
||||||
|
{
|
||||||
|
On: RuleOn{
|
||||||
|
checker: CheckFunc(func(_ *httputils.ResponseModifier, _ *http.Request) bool {
|
||||||
|
return true
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
Do: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Else: []CommandHandler{
|
||||||
|
Handler{
|
||||||
|
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
|
||||||
|
elseCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
phase: PhaseNone,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
rm := httputils.NewResponseModifier(w)
|
||||||
|
|
||||||
|
err := cmd.ServeHTTP(rm, req, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, elseCalled)
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -8,6 +9,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -37,7 +39,7 @@ func parseRules(data string, target *Rules) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLogCommand_TemporaryFile(t *testing.T) {
|
func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||||
upstream := mockUpstreamWithHeaders(200, "success response", http.Header{
|
upstream := mockUpstreamWithHeaders(http.StatusOK, "success response", http.Header{
|
||||||
"Content-Type": []string{"application/json"},
|
"Content-Type": []string{"application/json"},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -45,10 +47,9 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
|||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(fmt.Sprintf(`
|
err := parseRules(fmt.Sprintf(`
|
||||||
- name: log-request-response
|
default {
|
||||||
do: |
|
|
||||||
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
||||||
`, logFile), &rules)
|
}`, logFile), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(upstream)
|
handler := rules.BuildHandler(upstream)
|
||||||
@@ -59,7 +60,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
|||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, "success response", w.Body.String())
|
assert.Equal(t, "success response", w.Body.String())
|
||||||
|
|
||||||
// Read and verify log content
|
// Read and verify log content
|
||||||
@@ -70,16 +71,25 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||||
|
originalStdout := stdout
|
||||||
|
originalStderr := stderr
|
||||||
|
var stdoutBuf bytes.Buffer
|
||||||
|
var stderrBuf bytes.Buffer
|
||||||
|
stdout = noopWriteCloser{&stdoutBuf}
|
||||||
|
stderr = noopWriteCloser{&stderrBuf}
|
||||||
|
defer func() {
|
||||||
|
stdout = originalStdout
|
||||||
|
stderr = originalStderr
|
||||||
|
}()
|
||||||
|
|
||||||
upstream := mockUpstream(http.StatusOK, "success")
|
upstream := mockUpstream(http.StatusOK, "success")
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(`
|
err := parseRules(`
|
||||||
- name: log-stdout
|
default {
|
||||||
do: |
|
|
||||||
log info /dev/stdout "stdout: $req_method $status_code"
|
log info /dev/stdout "stdout: $req_method $status_code"
|
||||||
- name: log-stderr
|
|
||||||
do: |
|
|
||||||
log error /dev/stderr "stderr: $req_path $status_code"
|
log error /dev/stderr "stderr: $req_path $status_code"
|
||||||
|
}
|
||||||
`, &rules)
|
`, &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -90,9 +100,13 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
|||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 200, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
// Note: We can't easily capture stdout/stderr in unit tests,
|
require.Eventually(t, func() bool {
|
||||||
// but we can verify no errors occurred and the handler completed
|
return strings.Contains(stdoutBuf.String(), "stdout: GET 200") &&
|
||||||
|
strings.Contains(stderrBuf.String(), "stderr: /test 200")
|
||||||
|
}, time.Second, 10*time.Millisecond)
|
||||||
|
assert.Equal(t, 1, strings.Count(stdoutBuf.String(), "stdout: GET 200"))
|
||||||
|
assert.Equal(t, 1, strings.Count(stderrBuf.String(), "stderr: /test 200"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||||
@@ -104,26 +118,22 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
|||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(fmt.Sprintf(`
|
err := parseRules(fmt.Sprintf(`
|
||||||
- name: log-info
|
default {
|
||||||
do: |
|
|
||||||
log info %s "INFO: $req_method $status_code"
|
log info %s "INFO: $req_method $status_code"
|
||||||
- name: log-warn
|
|
||||||
do: |
|
|
||||||
log warn %s "WARN: $req_path $status_code"
|
log warn %s "WARN: $req_path $status_code"
|
||||||
- name: log-error
|
|
||||||
do: |
|
|
||||||
log error %s "ERROR: $req_method $req_path $status_code"
|
log error %s "ERROR: $req_method $req_path $status_code"
|
||||||
|
}
|
||||||
`, infoFile, warnFile, errorFile), &rules)
|
`, infoFile, warnFile, errorFile), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(upstream)
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
req := httptest.NewRequest("DELETE", "/api/resource/123", nil)
|
req := httptest.NewRequest(http.MethodDelete, "/api/resource/123", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 404, w.Code)
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
|
||||||
// Verify each log file
|
// Verify each log file
|
||||||
infoContent := TestFileContent(infoFile)
|
infoContent := TestFileContent(infoFile)
|
||||||
@@ -148,22 +158,22 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
|
|||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(fmt.Sprintf(`
|
err := parseRules(fmt.Sprintf(`
|
||||||
- name: log-with-templates
|
default {
|
||||||
do: |
|
|
||||||
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
||||||
|
}
|
||||||
`, tempFile), &rules)
|
`, tempFile), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(upstream)
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
req := httptest.NewRequest("PUT", "/api/resource", nil)
|
req := httptest.NewRequest(http.MethodPut, "/api/resource", nil)
|
||||||
req.Header.Set("User-Agent", "test-client/1.0")
|
req.Header.Set("User-Agent", "test-client/1.0")
|
||||||
req.Host = "example.com"
|
req.Host = "example.com"
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, 201, w.Code)
|
assert.Equal(t, http.StatusCreated, w.Code)
|
||||||
|
|
||||||
// Verify log content
|
// Verify log content
|
||||||
content := TestFileContent(tempFile)
|
content := TestFileContent(tempFile)
|
||||||
@@ -192,14 +202,12 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
|||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(fmt.Sprintf(`
|
err := parseRules(fmt.Sprintf(`
|
||||||
- name: log-success
|
status 2xx {
|
||||||
on: status 2xx
|
|
||||||
do: |
|
|
||||||
log info %q "SUCCESS: $req_method $req_path $status_code"
|
log info %q "SUCCESS: $req_method $req_path $status_code"
|
||||||
- name: log-error
|
}
|
||||||
on: status 4xx | status 5xx
|
status 4xx | status 5xx {
|
||||||
do: |
|
|
||||||
log error %q "ERROR: $req_method $req_path $status_code"
|
log error %q "ERROR: $req_method $req_path $status_code"
|
||||||
|
}
|
||||||
`, successFile, errorFile), &rules)
|
`, successFile, errorFile), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -244,9 +252,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
|||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(fmt.Sprintf(`
|
err := parseRules(fmt.Sprintf(`
|
||||||
- name: log-multiple
|
default {
|
||||||
do: |
|
log info %q "$req_method $req_path $status_code"
|
||||||
log info %q "$req_method $req_path $status_code"`, tempFile), &rules)
|
}
|
||||||
|
`, tempFile), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(upstream)
|
handler := rules.BuildHandler(upstream)
|
||||||
@@ -256,10 +265,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
|||||||
method string
|
method string
|
||||||
path string
|
path string
|
||||||
}{
|
}{
|
||||||
{"GET", "/users"},
|
{http.MethodGet, "/users"},
|
||||||
{"POST", "/users"},
|
{http.MethodPost, "/users"},
|
||||||
{"PUT", "/users/1"},
|
{http.MethodPost, "/users/1"},
|
||||||
{"DELETE", "/users/1"},
|
{http.MethodDelete, "/users/1"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, reqInfo := range requests {
|
for _, reqInfo := range requests {
|
||||||
@@ -287,8 +296,9 @@ func TestLogCommand_InvalidTemplate(t *testing.T) {
|
|||||||
|
|
||||||
// Test with invalid template syntax
|
// Test with invalid template syntax
|
||||||
err := parseRules(`
|
err := parseRules(`
|
||||||
- name: log-invalid
|
default {
|
||||||
do: |
|
log info /dev/stdout "$invalid_var"
|
||||||
log info /dev/stdout "$invalid_var"`, &rules)
|
}
|
||||||
assert.ErrorIs(t, err, ErrUnexpectedVar)
|
`, &rules)
|
||||||
|
require.ErrorIs(t, err, ErrUnexpectedVar)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
FieldHandler struct {
|
FieldHandler struct {
|
||||||
set, add, remove CommandHandler
|
set, add, remove HandlerFunc
|
||||||
}
|
}
|
||||||
FieldModifier string
|
FieldModifier string
|
||||||
)
|
)
|
||||||
@@ -49,30 +49,30 @@ var modFields = map[string]struct {
|
|||||||
"value": "the header template",
|
"value": "the header template",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKeyValueTemplate,
|
validate: validatePreRequestKVTemplate,
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.Header[k] = []string{v}
|
r.Header[k] = []string{v}
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
r.Header[k] = append(r.Header[k], v)
|
r.Header[k] = append(r.Header[k], v)
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
delete(r.Header, k)
|
delete(r.Header, k)
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -84,30 +84,30 @@ var modFields = map[string]struct {
|
|||||||
"value": "the response header template",
|
"value": "the response header template",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKeyValueTemplate,
|
validate: validatePostResponseKVTemplate,
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header()[k] = []string{v}
|
w.Header()[k] = []string{v}
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.Header()[k] = append(w.Header()[k], v)
|
w.Header()[k] = append(w.Header()[k], v)
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
delete(w.Header(), k)
|
delete(w.Header(), k)
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -119,36 +119,36 @@ var modFields = map[string]struct {
|
|||||||
"value": "the query template",
|
"value": "the query template",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKeyValueTemplate,
|
validate: validatePreRequestKVTemplate,
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||||
queries.Set(k, v)
|
queries.Set(k, v)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||||
queries.Add(k, v)
|
queries.Add(k, v)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||||
queries.Del(k)
|
queries.Del(k)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -160,16 +160,16 @@ var modFields = map[string]struct {
|
|||||||
"value": "the cookie value",
|
"value": "the cookie value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKeyValueTemplate,
|
validate: validatePreRequestKVTemplate,
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
for i, c := range cookies {
|
for i, c := range cookies {
|
||||||
if c.Name == k {
|
if c.Name == k {
|
||||||
cookies[i].Value = v
|
cookies[i].Value = v
|
||||||
@@ -179,19 +179,19 @@ var modFields = map[string]struct {
|
|||||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
v, err := tmpl.ExpandVarsToString(w, r)
|
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||||
index := -1
|
index := -1
|
||||||
for i, c := range cookies {
|
for i, c := range cookies {
|
||||||
if c.Name == k {
|
if c.Name == k {
|
||||||
@@ -208,7 +208,7 @@ var modFields = map[string]struct {
|
|||||||
return cookies
|
return cookies
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -227,24 +227,27 @@ var modFields = map[string]struct {
|
|||||||
"template": "the body template",
|
"template": "the body template",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return 0, nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
return validateTemplate(args[0], true)
|
phase = PhasePre
|
||||||
|
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
|
||||||
|
phase |= tmplReq
|
||||||
|
return
|
||||||
},
|
},
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
tmpl := args.(templateString)
|
tmpl := args.(templateString)
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
if r.Body != nil {
|
if r.Body != nil {
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
r.Body = nil
|
r.Body = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
bufPool := httputils.GetInitResponseModifier(w).BufPool()
|
bufPool := w.BufPool()
|
||||||
b := bufPool.GetBuffer()
|
b := bufPool.GetBuffer()
|
||||||
err := tmpl.ExpandVars(w, r, b)
|
_, err := tmpl.ExpandVars(w, r, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -252,7 +255,7 @@ var modFields = map[string]struct {
|
|||||||
bufPool.PutBuffer(b)
|
bufPool.PutBuffer(b)
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -272,20 +275,26 @@ var modFields = map[string]struct {
|
|||||||
"template": "the response body template",
|
"template": "the response body template",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return 0, nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
return validateTemplate(args[0], true)
|
phase = PhasePost
|
||||||
|
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
|
||||||
|
phase |= tmplReq
|
||||||
|
return
|
||||||
},
|
},
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
tmpl := args.(templateString)
|
tmpl := args.(templateString)
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
rm := httputils.GetInitResponseModifier(w)
|
w.ResetBody()
|
||||||
rm.ResetBody()
|
_, err := tmpl.ExpandVars(w, r, w)
|
||||||
return tmpl.ExpandVars(w, r, rm)
|
if err != nil {
|
||||||
}),
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -300,26 +309,27 @@ var modFields = map[string]struct {
|
|||||||
"code": "the status code",
|
"code": "the status code",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return phase, nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
|
phase = PhasePost
|
||||||
status, err := strconv.Atoi(args[0])
|
status, err := strconv.Atoi(args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidArguments.With(err)
|
return phase, nil, ErrInvalidArguments.With(err)
|
||||||
}
|
}
|
||||||
if status < 100 || status > 599 {
|
if status < 100 || status > 599 {
|
||||||
return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
return phase, nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
||||||
}
|
}
|
||||||
return status, nil
|
return phase, status, nil
|
||||||
},
|
},
|
||||||
builder: func(args any) *FieldHandler {
|
builder: func(args any) *FieldHandler {
|
||||||
status := args.(int)
|
status := args.(int)
|
||||||
return &FieldHandler{
|
return &FieldHandler{
|
||||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||||
httputils.GetInitResponseModifier(w).WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
return nil
|
return nil
|
||||||
}),
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -72,12 +71,12 @@ func TestFieldHandler_Header(t *testing.T) {
|
|||||||
tt.setup(req)
|
tt.setup(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.value, false)
|
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
t.Fatalf("Failed to validate template: %v", tErr)
|
t.Fatalf("Failed to validate template: %v", tErr)
|
||||||
}
|
}
|
||||||
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||||
var cmd CommandHandler
|
var cmd HandlerFunc
|
||||||
switch tt.modifier {
|
switch tt.modifier {
|
||||||
case ModFieldSet:
|
case ModFieldSet:
|
||||||
cmd = handler.set
|
cmd = handler.set
|
||||||
@@ -87,7 +86,7 @@ func TestFieldHandler_Header(t *testing.T) {
|
|||||||
cmd = handler.remove
|
cmd = handler.remove
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cmd.Handle(w, req)
|
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -150,12 +149,12 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
|||||||
tt.setup(w)
|
tt.setup(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.value, false)
|
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
t.Fatalf("Failed to validate template: %v", tErr)
|
t.Fatalf("Failed to validate template: %v", tErr)
|
||||||
}
|
}
|
||||||
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||||
var cmd CommandHandler
|
var cmd HandlerFunc
|
||||||
switch tt.modifier {
|
switch tt.modifier {
|
||||||
case ModFieldSet:
|
case ModFieldSet:
|
||||||
cmd = handler.set
|
cmd = handler.set
|
||||||
@@ -165,7 +164,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
|||||||
cmd = handler.remove
|
cmd = handler.remove
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cmd.Handle(w, req)
|
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -237,12 +236,12 @@ func TestFieldHandler_Query(t *testing.T) {
|
|||||||
tt.setup(req)
|
tt.setup(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.value, false)
|
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
t.Fatalf("Failed to validate template: %v", tErr)
|
t.Fatalf("Failed to validate template: %v", tErr)
|
||||||
}
|
}
|
||||||
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
|
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
|
||||||
var cmd CommandHandler
|
var cmd HandlerFunc
|
||||||
switch tt.modifier {
|
switch tt.modifier {
|
||||||
case ModFieldSet:
|
case ModFieldSet:
|
||||||
cmd = handler.set
|
cmd = handler.set
|
||||||
@@ -252,7 +251,7 @@ func TestFieldHandler_Query(t *testing.T) {
|
|||||||
cmd = handler.remove
|
cmd = handler.remove
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cmd.Handle(w, req)
|
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -335,12 +334,12 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
|||||||
tt.setup(req)
|
tt.setup(req)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.value, false)
|
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
t.Fatalf("Failed to validate template: %v", tErr)
|
t.Fatalf("Failed to validate template: %v", tErr)
|
||||||
}
|
}
|
||||||
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
|
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
|
||||||
var cmd CommandHandler
|
var cmd HandlerFunc
|
||||||
switch tt.modifier {
|
switch tt.modifier {
|
||||||
case ModFieldSet:
|
case ModFieldSet:
|
||||||
cmd = handler.set
|
cmd = handler.set
|
||||||
@@ -350,7 +349,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
|||||||
cmd = handler.remove
|
cmd = handler.remove
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cmd.Handle(w, req)
|
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -371,7 +370,7 @@ func TestFieldHandler_Body(t *testing.T) {
|
|||||||
name: "set body with template",
|
name: "set body with template",
|
||||||
template: "Hello $req_method $req_path",
|
template: "Hello $req_method $req_path",
|
||||||
setup: func(r *http.Request) {
|
setup: func(r *http.Request) {
|
||||||
r.Method = "POST"
|
r.Method = http.MethodPost
|
||||||
r.URL.Path = "/test"
|
r.URL.Path = "/test"
|
||||||
},
|
},
|
||||||
verify: func(r *http.Request) {
|
verify: func(r *http.Request) {
|
||||||
@@ -399,15 +398,15 @@ func TestFieldHandler_Body(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
tt.setup(req)
|
tt.setup(req)
|
||||||
w := httptest.NewRecorder()
|
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.template, false)
|
_, tmpl, tErr := validateTemplate(tt.template, false)
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
t.Fatalf("Failed to parse template: %v", tErr)
|
t.Fatalf("Failed to parse template: %v", tErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := modFields[FieldBody].builder(tmpl)
|
handler := modFields[FieldBody].builder(tmpl)
|
||||||
err := handler.set.Handle(w, req)
|
err := handler.set(w, req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -428,7 +427,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
|||||||
name: "set response body with template",
|
name: "set response body with template",
|
||||||
template: "Response: $req_method $req_path",
|
template: "Response: $req_method $req_path",
|
||||||
setup: func(r *http.Request) {
|
setup: func(r *http.Request) {
|
||||||
r.Method = "GET"
|
r.Method = http.MethodGet
|
||||||
r.URL.Path = "/api/test"
|
r.URL.Path = "/api/test"
|
||||||
},
|
},
|
||||||
verify: func(rm *httputils.ResponseModifier) {
|
verify: func(rm *httputils.ResponseModifier) {
|
||||||
@@ -443,23 +442,20 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
tt.setup(req)
|
tt.setup(req)
|
||||||
w := httptest.NewRecorder()
|
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
// Create ResponseModifier wrapper
|
_, tmpl, tErr := validateTemplate(tt.template, false)
|
||||||
rm := httputils.NewResponseModifier(w)
|
|
||||||
|
|
||||||
tmpl, tErr := validateTemplate(tt.template, false)
|
|
||||||
if tErr != nil {
|
if tErr != nil {
|
||||||
t.Fatalf("Failed to parse template: %v", tErr)
|
t.Fatalf("Failed to parse template: %v", tErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := modFields[FieldResponseBody].builder(tmpl)
|
handler := modFields[FieldResponseBody].builder(tmpl)
|
||||||
err := handler.set.Handle(rm, req)
|
err := handler.set(w, req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tt.verify(rm)
|
tt.verify(w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -472,23 +468,23 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "set status code 200",
|
name: "set status code 200",
|
||||||
status: 200,
|
status: http.StatusOK,
|
||||||
verify: func(w *httptest.ResponseRecorder) {
|
verify: func(w *httptest.ResponseRecorder) {
|
||||||
assert.Equal(t, 200, w.Code, "Expected status code 200")
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set status code 404",
|
name: "set status code 404",
|
||||||
status: 404,
|
status: http.StatusNotFound,
|
||||||
verify: func(w *httptest.ResponseRecorder) {
|
verify: func(w *httptest.ResponseRecorder) {
|
||||||
assert.Equal(t, 404, w.Code, "Expected status code 404")
|
assert.Equal(t, http.StatusNotFound, w.Code, "Expected status code 404")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set status code 500",
|
name: "set status code 500",
|
||||||
status: 500,
|
status: http.StatusInternalServerError,
|
||||||
verify: func(w *httptest.ResponseRecorder) {
|
verify: func(w *httptest.ResponseRecorder) {
|
||||||
assert.Equal(t, 500, w.Code, "Expected status code 500")
|
assert.Equal(t, http.StatusInternalServerError, w.Code, "Expected status code 500")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -503,12 +499,11 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
err = cmd.ServeHTTP(rm, req)
|
err = cmd.post.ServeHTTP(rm, req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Handler returned error: %v", err)
|
t.Fatalf("Handler returned error: %v", err)
|
||||||
}
|
}
|
||||||
rm.FlushRelease()
|
rm.FlushRelease()
|
||||||
|
|
||||||
tt.verify(w)
|
tt.verify(w)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -600,7 +595,7 @@ func TestFieldValidation(t *testing.T) {
|
|||||||
field, exists := modFields[tt.field]
|
field, exists := modFields[tt.field]
|
||||||
assert.True(t, exists, "Field %s does not exist", tt.field)
|
assert.True(t, exists, "Field %s does not exist", tt.field)
|
||||||
|
|
||||||
_, err := field.validate(tt.args)
|
_, _, err := field.validate(tt.args)
|
||||||
if tt.wantError {
|
if tt.wantError {
|
||||||
assert.Error(t, err, "Expected error but got none")
|
assert.Error(t, err, "Expected error but got none")
|
||||||
} else {
|
} else {
|
||||||
@@ -610,25 +605,6 @@ func TestFieldValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllFields(t *testing.T) {
|
|
||||||
expectedFields := []string{
|
|
||||||
FieldHeader,
|
|
||||||
FieldResponseHeader,
|
|
||||||
FieldQuery,
|
|
||||||
FieldCookie,
|
|
||||||
FieldBody,
|
|
||||||
FieldResponseBody,
|
|
||||||
FieldStatusCode,
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
|
|
||||||
|
|
||||||
for _, expected := range expectedFields {
|
|
||||||
found := slices.Contains(AllFields, expected)
|
|
||||||
assert.True(t, found, "Expected field %s not found in AllFields", expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModFields(t *testing.T) {
|
func TestModFields(t *testing.T) {
|
||||||
for fieldName, field := range modFields {
|
for fieldName, field := range modFields {
|
||||||
// Test that each field has required components
|
// Test that each field has required components
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ var (
|
|||||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
|
||||||
ErrMultipleDefaultRules = gperr.New("multiple default rules")
|
ErrMultipleDefaultRules = gperr.New("multiple default rules")
|
||||||
|
ErrDeadRule = gperr.New("dead rule")
|
||||||
|
|
||||||
// vars errors
|
// vars errors
|
||||||
ErrNoArgProvided = gperr.New("no argument provided")
|
ErrNoArgProvided = gperr.New("no argument provided")
|
||||||
@@ -31,5 +32,5 @@ var (
|
|||||||
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
|
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
|
||||||
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
|
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
|
||||||
|
|
||||||
errTerminated = gperr.New("terminated")
|
ErrInvalidBlockSyntax = gperr.New("invalid block syntax") // TODO: struct this error
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -131,12 +131,13 @@ Error generates help string as error, e.g.
|
|||||||
from: the path to rewrite, must start with /
|
from: the path to rewrite, must start with /
|
||||||
to: the path to rewrite to, must start with /
|
to: the path to rewrite to, must start with /
|
||||||
*/
|
*/
|
||||||
func (h *Help) Error() error {
|
func (h *Help) Error() gperr.Error {
|
||||||
var lines gperr.MultilineError
|
help := gperr.New(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||||
|
for _, line := range h.description {
|
||||||
|
help = help.Withf("%s", line)
|
||||||
|
}
|
||||||
|
|
||||||
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
args := gperr.New("args")
|
||||||
lines.AddStrings(h.description...)
|
|
||||||
lines.Adds(" args:")
|
|
||||||
|
|
||||||
argKeys := make([]string, 0, len(h.args))
|
argKeys := make([]string, 0, len(h.args))
|
||||||
longestArg := 0
|
longestArg := 0
|
||||||
@@ -151,7 +152,9 @@ func (h *Help) Error() error {
|
|||||||
slices.Sort(argKeys)
|
slices.Sort(argKeys)
|
||||||
for _, arg := range argKeys {
|
for _, arg := range argKeys {
|
||||||
desc := h.args[arg]
|
desc := h.args[arg]
|
||||||
lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc)
|
paddedArg := fmt.Sprintf("%-"+strconv.Itoa(longestArg)+"s", arg)
|
||||||
|
args = args.Withf("%s%s", ansi.WithANSI(paddedArg, ansi.HighlightCyan)+": ", desc)
|
||||||
}
|
}
|
||||||
return &lines
|
|
||||||
|
return help.With(args)
|
||||||
}
|
}
|
||||||
|
|||||||
1356
internal/route/rules/http_flow_block_test.go
Normal file
1356
internal/route/rules/http_flow_block_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -23,8 +23,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// mockUpstream creates a simple upstream handler for testing
|
// mockUpstream creates a simple upstream handler for testing
|
||||||
func mockUpstream(body string) http.HandlerFunc {
|
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(status)
|
||||||
w.Write([]byte(body))
|
w.Write([]byte(body))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -47,7 +48,7 @@ func parseRules(data string, target *Rules) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
func TestHTTPFlow_BasicPreRulesYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -74,8 +75,8 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
|||||||
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_BypassRule(t *testing.T) {
|
func TestHTTPFlow_BypassRuleYAML(t *testing.T) {
|
||||||
upstream := mockUpstream("upstream response")
|
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(`
|
err := parseRules(`
|
||||||
@@ -99,8 +100,8 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
|
|||||||
assert.Equal(t, "upstream response", w.Body.String())
|
assert.Equal(t, "upstream response", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
func TestHTTPFlow_TerminatingCommandYAML(t *testing.T) {
|
||||||
upstream := mockUpstream("should not be called")
|
upstream := mockUpstream(http.StatusOK, "should not be called")
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(`
|
err := parseRules(`
|
||||||
@@ -120,13 +121,13 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
|||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
assert.Equal(t, 403, w.Code)
|
||||||
assert.Equal(t, "Forbidden\n", w.Body.String())
|
assert.Equal(t, "Forbidden\n", w.Body.String())
|
||||||
assert.Empty(t, w.Header().Get("X-Header"))
|
assert.Empty(t, w.Header().Get("X-Header"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
func TestHTTPFlow_RedirectFlowYAML(t *testing.T) {
|
||||||
upstream := mockUpstream("should not be called")
|
upstream := mockUpstream(http.StatusOK, "should not be called")
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(`
|
err := parseRules(`
|
||||||
@@ -143,11 +144,11 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
|||||||
|
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
|
assert.Equal(t, 307, w.Code) // TemporaryRedirect
|
||||||
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
func TestHTTPFlow_RewriteFlowYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("path: " + r.URL.Path))
|
w.Write([]byte("path: " + r.URL.Path))
|
||||||
@@ -172,7 +173,7 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
|||||||
assert.Equal(t, "path: /v1/users", w.Body.String())
|
assert.Equal(t, "path: /v1/users", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
func TestHTTPFlow_MultiplePreRulesYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
||||||
@@ -201,7 +202,7 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
|||||||
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
func TestHTTPFlow_PostResponseRuleYAML(t *testing.T) {
|
||||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
|
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
|
||||||
"X-Upstream": []string{"upstream-value"},
|
"X-Upstream": []string{"upstream-value"},
|
||||||
})
|
})
|
||||||
@@ -229,11 +230,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
|||||||
|
|
||||||
// Check log file
|
// Check log file
|
||||||
content := TestFileContent(tempFile)
|
content := TestFileContent(tempFile)
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, "GET 200\n", string(content))
|
assert.Equal(t, "GET 200\n", string(content))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
func TestHTTPFlow_ResponseRuleWithStatusConditionYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/success" {
|
if r.URL.Path == "/success" {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -246,14 +246,15 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
|||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
|
|
||||||
// Create a temporary file for logging
|
errorLog := TestRandomFileName()
|
||||||
tempFile := TestRandomFileName()
|
infoLog := TestRandomFileName()
|
||||||
|
|
||||||
err := parseRules(fmt.Sprintf(`
|
err := parseRules(fmt.Sprintf(`
|
||||||
- name: log-errors
|
- on: status 4xx
|
||||||
on: status 4xx
|
|
||||||
do: log error %s "$req_url returned $status_code"
|
do: log error %s "$req_url returned $status_code"
|
||||||
`, tempFile), &rules)
|
- on: status 200
|
||||||
|
do: log info %s "$req_url returned $status_code"
|
||||||
|
`, errorLog, infoLog), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(upstream)
|
handler := rules.BuildHandler(upstream)
|
||||||
@@ -273,14 +274,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||||
|
|
||||||
// Check log file
|
// Check log file
|
||||||
content := TestFileContent(tempFile)
|
content := TestFileContent(errorLog)
|
||||||
require.NoError(t, err)
|
|
||||||
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
|
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
|
||||||
require.Len(t, lines, 1, "only 4xx requests should be logged")
|
require.Len(t, lines, 1, "only 4xx requests should be logged")
|
||||||
assert.Equal(t, "/notfound returned 404", lines[0])
|
assert.Equal(t, "/notfound returned 404", lines[0])
|
||||||
|
|
||||||
|
infoContent := TestFileContent(infoLog)
|
||||||
|
lines = strings.Split(strings.TrimSpace(string(infoContent)), "\n")
|
||||||
|
require.Len(t, lines, 1, "only 200 requests should be logged")
|
||||||
|
assert.Equal(t, "/success returned 200", lines[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
func TestHTTPFlow_ConditionalRulesYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
||||||
@@ -320,22 +325,21 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
|||||||
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
func TestHTTPFlow_ComplexFlowWithPreAndPostRulesYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Simulate different responses based on path
|
// Simulate different responses based on path
|
||||||
if r.URL.Path == "/protected" {
|
if r.URL.Path == "/protected" {
|
||||||
if r.Header.Get("X-Auth") != "valid" {
|
if r.Header.Get("X-Auth") != "valid" {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
w.Write([]byte("unauthorized"))
|
fmt.Fprint(w, "unauthorized")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.Header().Set("X-Response-Time", "100ms")
|
w.Header().Set("X-Response-Time", "100ms")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("success"))
|
fmt.Fprint(w, "success")
|
||||||
})
|
})
|
||||||
|
|
||||||
// Create temporary files for logging
|
|
||||||
logFile := TestRandomFileName()
|
logFile := TestRandomFileName()
|
||||||
errorLogFile := TestRandomFileName()
|
errorLogFile := TestRandomFileName()
|
||||||
|
|
||||||
@@ -402,8 +406,8 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
|||||||
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
|
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_DefaultRule(t *testing.T) {
|
func TestHTTPFlow_DefaultRuleYAML(t *testing.T) {
|
||||||
upstream := mockUpstream("upstream response")
|
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
err := parseRules(`
|
err := parseRules(`
|
||||||
@@ -426,21 +430,57 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
|
|||||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||||
|
|
||||||
// Test special rule + default rule
|
// Test special rule (default should not run)
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||||
w2 := httptest.NewRecorder()
|
w2 := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w2, req2)
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w2.Code)
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
|
assert.Empty(t, w2.Header().Get("X-Default-Applied"))
|
||||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
func TestHTTPFlow_DefaultRuleWithOnDefaultYAML(t *testing.T) {
|
||||||
|
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
err := parseRules(`
|
||||||
|
- name: default-on-rule
|
||||||
|
on: default
|
||||||
|
do: set resp_header X-Default-Applied true
|
||||||
|
- name: special-rule
|
||||||
|
on: path /special
|
||||||
|
do: set resp_header X-Special-Handled true
|
||||||
|
`, &rules)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
|
// Test default rule on regular request
|
||||||
|
req1 := httptest.NewRequest(http.MethodGet, "/regular", nil)
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w1, req1)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w1.Code)
|
||||||
|
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||||
|
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||||
|
|
||||||
|
// Test special rule on matching request (default should not run)
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w2.Code)
|
||||||
|
assert.Empty(t, w2.Header().Get("X-Default-Applied"))
|
||||||
|
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPFlow_HeaderManipulationYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Echo back a header
|
// Echo back a header
|
||||||
headerValue := r.Header.Get("X-Test-Header")
|
headerValue := r.Header.Get("X-Test-Header")
|
||||||
w.Header().Set("X-Echoed-Header", headerValue)
|
w.Header().Set("X-Echoed-Header", headerValue)
|
||||||
|
w.Header().Set("X-Secret", "sensitive-data")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("header echoed"))
|
w.Write([]byte("header echoed"))
|
||||||
})
|
})
|
||||||
@@ -460,7 +500,6 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
|||||||
handler := rules.BuildHandler(upstream)
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.Header.Set("X-Secret", "secret-value")
|
|
||||||
req.Header.Set("X-Test-Header", "original-value")
|
req.Header.Set("X-Test-Header", "original-value")
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -469,11 +508,10 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
||||||
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
||||||
// Ensure the secret header was removed and not passed to upstream
|
assert.Empty(t, w.Header().Get("X-Secret"))
|
||||||
// (we can't directly test this, but the upstream shouldn't see it)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
func TestHTTPFlow_QueryParameterHandlingYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
query := r.URL.Query()
|
query := r.URL.Query()
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
@@ -500,13 +538,15 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
|||||||
assert.Equal(t, "query: added-value", w.Body.String())
|
assert.Equal(t, "query: added-value", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ServeCommand(t *testing.T) {
|
func TestHTTPFlow_ServeCommandYAML(t *testing.T) {
|
||||||
// Create a temporary directory with test files
|
// Create a temporary directory with test files
|
||||||
tempDir := t.TempDir()
|
tempDir, err := os.MkdirTemp("", "test-serve-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
// Create test files directly in the temp directory
|
// Create test files directly in the temp directory
|
||||||
testFile := filepath.Join(tempDir, "index.html")
|
testFile := filepath.Join(tempDir, "index.html")
|
||||||
err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
@@ -517,7 +557,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
|||||||
`, tempDir), &rules)
|
`, tempDir), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
|
||||||
|
|
||||||
// Test serving a file - serve command serves files relative to the root directory
|
// Test serving a file - serve command serves files relative to the root directory
|
||||||
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
||||||
@@ -546,7 +586,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
func TestHTTPFlow_ProxyCommandYAML(t *testing.T) {
|
||||||
// Create a mock upstream server
|
// Create a mock upstream server
|
||||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Upstream-Header", "upstream-value")
|
w.Header().Set("X-Upstream-Header", "upstream-value")
|
||||||
@@ -563,7 +603,7 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
|||||||
`, upstreamServer.URL), &rules)
|
`, upstreamServer.URL), &rules)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -576,11 +616,28 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
|||||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_NotifyCommand(t *testing.T) {
|
func TestHTTPFlow_NotifyCommandYAML(t *testing.T) {
|
||||||
// TODO:
|
upstream := mockUpstream(http.StatusOK, "ok")
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
err := parseRules(`
|
||||||
|
- name: notify-rule
|
||||||
|
on: path /notify
|
||||||
|
do: notify info test-provider "title $req_method" "body $req_url $status_code"
|
||||||
|
`, &rules)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/notify", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Equal(t, "ok", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_FormConditions(t *testing.T) {
|
func TestHTTPFlow_FormConditionsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("form processed"))
|
w.Write([]byte("form processed"))
|
||||||
@@ -620,7 +677,7 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
|
|||||||
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
func TestHTTPFlow_RemoteConditionsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("remote processed"))
|
w.Write([]byte("remote processed"))
|
||||||
@@ -654,11 +711,11 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
|||||||
w2 := httptest.NewRecorder()
|
w2 := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w2, req2)
|
handler.ServeHTTP(w2, req2)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusForbidden, w2.Code)
|
assert.Equal(t, 403, w2.Code)
|
||||||
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
func TestHTTPFlow_BasicAuthConditionsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("auth processed"))
|
w.Write([]byte("auth processed"))
|
||||||
@@ -702,7 +759,7 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
|||||||
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_RouteConditions(t *testing.T) {
|
func TestHTTPFlow_RouteConditionsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("route processed"))
|
w.Write([]byte("route processed"))
|
||||||
@@ -742,10 +799,10 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
|
|||||||
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
func TestHTTPFlow_ResponseStatusConditionsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
w.Write([]byte("method not allowed"))
|
fmt.Fprint(w, "method not allowed")
|
||||||
})
|
})
|
||||||
|
|
||||||
var rules Rules
|
var rules Rules
|
||||||
@@ -767,11 +824,11 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
|||||||
assert.Equal(t, "error\n", w.Body.String())
|
assert.Equal(t, "error\n", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
func TestHTTPFlow_ResponseHeaderConditionsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Response-Header", "response header")
|
w.Header().Set("X-Response-Header", "response header")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("processed"))
|
fmt.Fprint(w, "processed")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("any_value", func(t *testing.T) {
|
t.Run("any_value", func(t *testing.T) {
|
||||||
@@ -831,7 +888,65 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMatchersYAML(t *testing.T) {
|
||||||
|
upstreamCalled := false
|
||||||
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
upstreamCalled = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprint(w, "upstream")
|
||||||
|
})
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
err := parseRules(`
|
||||||
|
- on: path /
|
||||||
|
do: error 403 blocked
|
||||||
|
- on: path /
|
||||||
|
do: set resp_header X-Late should-run
|
||||||
|
- on: status 4xx
|
||||||
|
do: set resp_header X-Post true
|
||||||
|
`, &rules)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.False(t, upstreamCalled)
|
||||||
|
assert.Equal(t, 403, w.Code)
|
||||||
|
assert.Equal(t, "blocked\n", w.Body.String())
|
||||||
|
assert.Equal(t, "should-run", w.Header().Get("X-Late"))
|
||||||
|
assert.Equal(t, "true", w.Header().Get("X-Post"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPFlow_PostRuleTermination_StopsRemainingCommandsInRuleYAML(t *testing.T) {
|
||||||
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("ok"))
|
||||||
|
})
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
err := parseRules(`
|
||||||
|
- on: status 200
|
||||||
|
do: |
|
||||||
|
error 500 failed
|
||||||
|
set resp_header X-After should-not-run
|
||||||
|
`, &rules)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
handler := rules.BuildHandler(upstream)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
assert.Equal(t, "failed\n", w.Body.String())
|
||||||
|
assert.Empty(t, w.Header().Get("X-After"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPFlow_ComplexRuleCombinationsYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("complex processed"))
|
w.Write([]byte("complex processed"))
|
||||||
@@ -887,12 +1002,12 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
|||||||
w3 := httptest.NewRecorder()
|
w3 := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w3, req3)
|
handler.ServeHTTP(w3, req3)
|
||||||
|
|
||||||
assert.Equal(t, 200, w3.Code)
|
assert.Equal(t, http.StatusOK, w3.Code)
|
||||||
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
|
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
|
||||||
assert.Empty(t, w3.Header()["X-API-Version"])
|
assert.Empty(t, w3.Header()["X-API-Version"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
func TestHTTPFlow_ResponseModifierYAML(t *testing.T) {
|
||||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("original response"))
|
w.Write([]byte("original response"))
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/yusing/godoxy/internal/common"
|
"github.com/yusing/godoxy/internal/common"
|
||||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
type noopWriteCloser struct {
|
type noopWriteCloser struct {
|
||||||
@@ -30,7 +31,7 @@ var (
|
|||||||
testFilesLock sync.Mutex
|
testFilesLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
func openFile(path string) (io.WriteCloser, error) {
|
func openFile(path string) (io.WriteCloser, gperr.Error) {
|
||||||
switch path {
|
switch path {
|
||||||
case "/dev/stdout":
|
case "/dev/stdout":
|
||||||
return stdout, nil
|
return stdout, nil
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gobwas/glob"
|
"github.com/gobwas/glob"
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,6 +14,8 @@ type (
|
|||||||
MatcherType string
|
MatcherType string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var matcherCache = xsync.NewMap[string, Matcher]() // map[string]Matcher
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MatcherTypeString MatcherType = "string"
|
MatcherTypeString MatcherType = "string"
|
||||||
MatcherTypeGlob MatcherType = "glob"
|
MatcherTypeGlob MatcherType = "glob"
|
||||||
@@ -59,7 +62,12 @@ func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||||
|
if cached, ok := matcherCache.Load(expr); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
|
||||||
negate := false
|
negate := false
|
||||||
|
origExpr := expr
|
||||||
if strings.HasPrefix(expr, "!") {
|
if strings.HasPrefix(expr, "!") {
|
||||||
negate = true
|
negate = true
|
||||||
expr = expr[1:]
|
expr = expr[1:]
|
||||||
@@ -72,11 +80,23 @@ func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
|||||||
|
|
||||||
switch t {
|
switch t {
|
||||||
case MatcherTypeString:
|
case MatcherTypeString:
|
||||||
return StringMatcher(expr, negate)
|
m, err := StringMatcher(expr, negate)
|
||||||
|
if err == nil {
|
||||||
|
matcherCache.Store(origExpr, m)
|
||||||
|
}
|
||||||
|
return m, err
|
||||||
case MatcherTypeGlob:
|
case MatcherTypeGlob:
|
||||||
return GlobMatcher(expr, negate)
|
m, err := GlobMatcher(expr, negate)
|
||||||
|
if err == nil {
|
||||||
|
matcherCache.Store(origExpr, m)
|
||||||
|
}
|
||||||
|
return m, err
|
||||||
case MatcherTypeRegex:
|
case MatcherTypeRegex:
|
||||||
return RegexMatcher(expr, negate)
|
m, err := RegexMatcher(expr, negate)
|
||||||
|
if err == nil {
|
||||||
|
matcherCache.Store(origExpr, m)
|
||||||
|
}
|
||||||
|
return m, err
|
||||||
}
|
}
|
||||||
// won't reach here
|
// won't reach here
|
||||||
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
||||||
|
|||||||
@@ -14,17 +14,17 @@ import (
|
|||||||
type RuleOn struct {
|
type RuleOn struct {
|
||||||
raw string
|
raw string
|
||||||
checker Checker
|
checker Checker
|
||||||
isResponseChecker bool
|
phase PhaseFlag
|
||||||
}
|
|
||||||
|
|
||||||
func (on *RuleOn) IsResponseChecker() bool {
|
|
||||||
return on.isResponseChecker
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
|
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||||
return on.checker.Check(w, r)
|
if on.checker == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return on.checker.Check(httputils.GetInitResponseModifier(w), r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// on request
|
||||||
const (
|
const (
|
||||||
OnDefault = "default"
|
OnDefault = "default"
|
||||||
OnHeader = "header"
|
OnHeader = "header"
|
||||||
@@ -39,9 +39,10 @@ const (
|
|||||||
OnRemote = "remote"
|
OnRemote = "remote"
|
||||||
OnBasicAuth = "basic_auth"
|
OnBasicAuth = "basic_auth"
|
||||||
OnRoute = "route"
|
OnRoute = "route"
|
||||||
|
)
|
||||||
|
|
||||||
// on response
|
// on response
|
||||||
|
const (
|
||||||
OnResponseHeader = "resp_header"
|
OnResponseHeader = "resp_header"
|
||||||
OnStatus = "status"
|
OnStatus = "status"
|
||||||
)
|
)
|
||||||
@@ -50,24 +51,24 @@ var checkers = map[string]struct {
|
|||||||
help Help
|
help Help
|
||||||
validate ValidateFunc
|
validate ValidateFunc
|
||||||
builder func(args any) CheckFunc
|
builder func(args any) CheckFunc
|
||||||
isResponseChecker bool
|
|
||||||
}{
|
}{
|
||||||
OnDefault: {
|
OnDefault: {
|
||||||
help: Help{
|
help: Help{
|
||||||
command: OnDefault,
|
command: OnDefault,
|
||||||
description: makeLines(
|
description: makeLines(
|
||||||
"The default rule is matched when no other rules are matched.",
|
"Select the default (fallback) rule.",
|
||||||
),
|
),
|
||||||
args: map[string]string{},
|
args: map[string]string{},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, ErrExpectNoArg
|
return phase, nil, ErrExpectNoArg
|
||||||
}
|
}
|
||||||
//nolint:nilnil
|
return phase, nil, nil
|
||||||
return nil, nil
|
},
|
||||||
|
builder: func(args any) CheckFunc {
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request) bool { return true }
|
||||||
},
|
},
|
||||||
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
|
|
||||||
},
|
},
|
||||||
OnHeader: {
|
OnHeader: {
|
||||||
help: Help{
|
help: Help{
|
||||||
@@ -83,21 +84,23 @@ var checkers = map[string]struct {
|
|||||||
"[value]": "the header value",
|
"[value]": "the header value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKVOptionalVMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return len(r.Header[k]) > 0
|
return len(r.Header[k]) > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return slices.ContainsFunc(r.Header[k], matcher)
|
return slices.ContainsFunc(r.Header[k], matcher)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnResponseHeader: {
|
OnResponseHeader: {
|
||||||
isResponseChecker: true,
|
|
||||||
help: Help{
|
help: Help{
|
||||||
command: OnResponseHeader,
|
command: OnResponseHeader,
|
||||||
description: makeLines(
|
description: makeLines(
|
||||||
@@ -111,16 +114,20 @@ var checkers = map[string]struct {
|
|||||||
"[value]": "the response header value",
|
"[value]": "the response header value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKVOptionalVMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePost
|
||||||
|
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
|
return len(w.Header()[k]) > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
|
return slices.ContainsFunc(w.Header()[k], matcher)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -138,16 +145,19 @@ var checkers = map[string]struct {
|
|||||||
"[value]": "the query value",
|
"[value]": "the query value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKVOptionalVMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
|
return len(w.SharedData().GetQueries(r)[k]) > 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
|
return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -165,12 +175,15 @@ var checkers = map[string]struct {
|
|||||||
"[value]": "the cookie value",
|
"[value]": "the cookie value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKVOptionalVMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
cookies := w.SharedData().GetCookies(r)
|
||||||
for _, cookie := range cookies {
|
for _, cookie := range cookies {
|
||||||
if cookie.Name == k {
|
if cookie.Name == k {
|
||||||
return true
|
return true
|
||||||
@@ -179,8 +192,8 @@ var checkers = map[string]struct {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
cookies := w.SharedData().GetCookies(r)
|
||||||
for _, cookie := range cookies {
|
for _, cookie := range cookies {
|
||||||
if cookie.Name == k {
|
if cookie.Name == k {
|
||||||
if matcher(cookie.Value) {
|
if matcher(cookie.Value) {
|
||||||
@@ -192,6 +205,7 @@ var checkers = map[string]struct {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
//nolint:dupl
|
||||||
OnForm: {
|
OnForm: {
|
||||||
help: Help{
|
help: Help{
|
||||||
command: OnForm,
|
command: OnForm,
|
||||||
@@ -206,15 +220,18 @@ var checkers = map[string]struct {
|
|||||||
"[value]": "the form value",
|
"[value]": "the form value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKVOptionalVMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return r.FormValue(k) != ""
|
return r.FormValue(k) != ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return matcher(r.FormValue(k))
|
return matcher(r.FormValue(k))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -233,15 +250,18 @@ var checkers = map[string]struct {
|
|||||||
"[value]": "the form value",
|
"[value]": "the form value",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: toKVOptionalVMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||||
if matcher == nil {
|
if matcher == nil {
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return r.PostFormValue(k) != ""
|
return r.PostFormValue(k) != ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return matcher(r.PostFormValue(k))
|
return matcher(r.PostFormValue(k))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -250,32 +270,46 @@ var checkers = map[string]struct {
|
|||||||
help: Help{
|
help: Help{
|
||||||
command: OnProto,
|
command: OnProto,
|
||||||
args: map[string]string{
|
args: map[string]string{
|
||||||
"proto": "the http protocol (http, https, h3)",
|
"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: func(args []string) (any, error) {
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return phase, nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
proto := args[0]
|
proto := args[0]
|
||||||
if proto != "http" && proto != "https" && proto != "h3" {
|
switch proto {
|
||||||
return nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
case "http", "https", "h1", "h2", "h2c", "h3":
|
||||||
|
return phase, proto, nil
|
||||||
|
default:
|
||||||
|
return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||||
}
|
}
|
||||||
return proto, nil
|
|
||||||
},
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
proto := args.(string)
|
proto := args.(string)
|
||||||
switch proto {
|
switch proto {
|
||||||
case "http":
|
case "http":
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return r.TLS == nil
|
return r.TLS == nil
|
||||||
}
|
}
|
||||||
case "https":
|
case "https":
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return r.TLS != nil
|
return r.TLS != nil
|
||||||
}
|
}
|
||||||
|
case "h1":
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
|
return r.TLS == nil && r.ProtoMajor == 1
|
||||||
|
}
|
||||||
|
case "h2":
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
|
return r.TLS != nil && r.ProtoMajor == 2
|
||||||
|
}
|
||||||
|
case "h2c":
|
||||||
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
|
return r.TLS == nil && r.ProtoMajor == 2
|
||||||
|
}
|
||||||
default: // h3
|
default: // h3
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return r.TLS != nil && r.ProtoMajor == 3
|
return r.TLS != nil && r.ProtoMajor == 3
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -288,10 +322,13 @@ var checkers = map[string]struct {
|
|||||||
"method": "the http method",
|
"method": "the http method",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateMethod,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = validateMethod(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
method := args.(string)
|
method := args.(string)
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return r.Method == method
|
return r.Method == method
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -310,10 +347,13 @@ var checkers = map[string]struct {
|
|||||||
"host": "the host name",
|
"host": "the host name",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateSingleMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = validateSingleMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
matcher := args.(Matcher)
|
matcher := args.(Matcher)
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return matcher(r.Host)
|
return matcher(r.Host)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -332,10 +372,13 @@ var checkers = map[string]struct {
|
|||||||
"path": "the request path",
|
"path": "the request path",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateURLPathMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = validateURLPathMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
matcher := args.(Matcher)
|
matcher := args.(Matcher)
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
reqPath := r.URL.Path
|
reqPath := r.URL.Path
|
||||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||||
reqPath = "/" + reqPath
|
reqPath = "/" + reqPath
|
||||||
@@ -351,22 +394,25 @@ var checkers = map[string]struct {
|
|||||||
"ip|cidr": "the remote ip or cidr",
|
"ip|cidr": "the remote ip or cidr",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateCIDR,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = validateCIDR(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
ipnet := args.(*net.IPNet)
|
ipnet := args.(*net.IPNet)
|
||||||
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
||||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||||
wantIP := ipnet.IP
|
wantIP := ipnet.IP
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
ip := w.SharedData().GetRemoteIP(r)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return ip.Equal(wantIP)
|
return ip.Equal(wantIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
ip := w.SharedData().GetRemoteIP(r)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -382,11 +428,14 @@ var checkers = map[string]struct {
|
|||||||
"password": "the password encrypted with bcrypt",
|
"password": "the password encrypted with bcrypt",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateUserBCryptPassword,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = validateUserBCryptPassword(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
cred := args.(*HashedCrendentials)
|
cred := args.(*HashedCrendentials)
|
||||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
|
return cred.Match(w.SharedData().GetBasicAuth(r))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -403,16 +452,18 @@ var checkers = map[string]struct {
|
|||||||
"route": "the route name",
|
"route": "the route name",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateSingleMatcher,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
parsedArgs, err = validateSingleMatcher(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
matcher := args.(Matcher)
|
matcher := args.(Matcher)
|
||||||
return func(_ http.ResponseWriter, r *http.Request) bool {
|
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return matcher(routes.TryGetUpstreamName(r))
|
return matcher(routes.TryGetUpstreamName(r))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnStatus: {
|
OnStatus: {
|
||||||
isResponseChecker: true,
|
|
||||||
help: Help{
|
help: Help{
|
||||||
command: OnStatus,
|
command: OnStatus,
|
||||||
description: makeLines(
|
description: makeLines(
|
||||||
@@ -429,16 +480,20 @@ var checkers = map[string]struct {
|
|||||||
"status": "the status code range",
|
"status": "the status code range",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
validate: validateStatusRange,
|
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
phase = PhasePost
|
||||||
|
parsedArgs, err = validateStatusRange(args)
|
||||||
|
return
|
||||||
|
},
|
||||||
builder: func(args any) CheckFunc {
|
builder: func(args any) CheckFunc {
|
||||||
beg, end := args.(*IntTuple).Unpack()
|
beg, end := args.(*IntTuple).Unpack()
|
||||||
if beg == end {
|
if beg == end {
|
||||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||||
return httputils.GetInitResponseModifier(w).StatusCode() == beg
|
return w.StatusCode() == beg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||||
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
|
statusCode := w.StatusCode()
|
||||||
return statusCode >= beg && statusCode <= end
|
return statusCode >= beg && statusCode <= end
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -515,85 +570,90 @@ func splitPipe(s string) []string {
|
|||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []string
|
result := []string{}
|
||||||
var current strings.Builder
|
forEachPipePart(s, func(part string) {
|
||||||
escaped := false
|
result = append(result, part)
|
||||||
quote := rune(0)
|
})
|
||||||
brackets := 0
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
for _, r := range s {
|
func forEachAndPart(s string, fn func(part string)) {
|
||||||
if escaped {
|
start := 0
|
||||||
current.WriteRune(r)
|
for i := 0; i <= len(s); i++ {
|
||||||
escaped = false
|
if i < len(s) && andSeps[s[i]] == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
part := strings.TrimSpace(s[start:i])
|
||||||
|
if part != "" {
|
||||||
|
fn(part)
|
||||||
|
}
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch r {
|
func forEachPipePart(s string, fn func(part string)) {
|
||||||
|
quote := byte(0)
|
||||||
|
brackets := 0
|
||||||
|
start := 0
|
||||||
|
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch s[i] {
|
||||||
case '\\':
|
case '\\':
|
||||||
escaped = true
|
if i+1 < len(s) {
|
||||||
current.WriteRune(r)
|
i++
|
||||||
|
}
|
||||||
case '"', '\'', '`':
|
case '"', '\'', '`':
|
||||||
if quote == 0 && brackets == 0 {
|
if quote == 0 && brackets == 0 {
|
||||||
quote = r
|
quote = s[i]
|
||||||
} else if r == quote {
|
} else if s[i] == quote {
|
||||||
quote = 0
|
quote = 0
|
||||||
}
|
}
|
||||||
current.WriteRune(r)
|
|
||||||
case '(':
|
case '(':
|
||||||
brackets++
|
brackets++
|
||||||
current.WriteRune(r)
|
|
||||||
case ')':
|
case ')':
|
||||||
if brackets > 0 {
|
if brackets > 0 {
|
||||||
brackets--
|
brackets--
|
||||||
}
|
}
|
||||||
current.WriteRune(r)
|
|
||||||
case '|':
|
case '|':
|
||||||
if quote == 0 && brackets == 0 {
|
if quote == 0 && brackets == 0 {
|
||||||
// Found a pipe outside quotes/brackets, split here
|
if part := strings.TrimSpace(s[start:i]); part != "" {
|
||||||
result = append(result, strings.TrimSpace(current.String()))
|
fn(part)
|
||||||
current.Reset()
|
|
||||||
} else {
|
|
||||||
current.WriteRune(r)
|
|
||||||
}
|
}
|
||||||
default:
|
start = i + 1
|
||||||
current.WriteRune(r)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the last part
|
|
||||||
if current.Len() > 0 {
|
|
||||||
result = append(result, strings.TrimSpace(current.String()))
|
|
||||||
}
|
}
|
||||||
|
if start < len(s) {
|
||||||
return result
|
if part := strings.TrimSpace(s[start:]); part != "" {
|
||||||
|
fn(part)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse implements strutils.Parser.
|
// Parse implements strutils.Parser.
|
||||||
func (on *RuleOn) Parse(v string) error {
|
func (on *RuleOn) Parse(v string) error {
|
||||||
on.raw = v
|
on.raw = v
|
||||||
|
|
||||||
rules := splitAnd(v)
|
ruleCount := 0
|
||||||
checkAnd := make(CheckMatchAll, 0, len(rules))
|
forEachAndPart(v, func(_ string) {
|
||||||
|
ruleCount++
|
||||||
|
})
|
||||||
|
checkAnd := make(CheckMatchAll, 0, ruleCount)
|
||||||
|
|
||||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||||
isResponseChecker := false
|
i := 0
|
||||||
for i, rule := range rules {
|
forEachAndPart(v, func(rule string) {
|
||||||
if rule == "" {
|
i++
|
||||||
continue
|
parsed, phase, err := parseOn(rule)
|
||||||
}
|
|
||||||
parsed, isResp, err := parseOn(rule)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs.AddSubjectf(err, "line %d", i+1)
|
errs.AddSubjectf(err, "line %d", i)
|
||||||
continue
|
return
|
||||||
}
|
|
||||||
if isResp {
|
|
||||||
isResponseChecker = true
|
|
||||||
}
|
}
|
||||||
|
on.phase |= phase
|
||||||
checkAnd = append(checkAnd, parsed)
|
checkAnd = append(checkAnd, parsed)
|
||||||
}
|
})
|
||||||
|
|
||||||
on.checker = checkAnd
|
on.checker = checkAnd
|
||||||
on.isResponseChecker = isResponseChecker
|
|
||||||
return errs.Error()
|
return errs.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -605,33 +665,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
|||||||
return []byte(on.String()), nil
|
return []byte(on.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOn(line string) (Checker, bool, error) {
|
func parseOn(line string) (Checker, PhaseFlag, error) {
|
||||||
ors := splitPipe(line)
|
orCount := 0
|
||||||
|
forEachPipePart(line, func(_ string) {
|
||||||
if len(ors) > 1 {
|
orCount++
|
||||||
|
})
|
||||||
|
if orCount > 1 {
|
||||||
|
var phase PhaseFlag
|
||||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||||
checkOr := make(CheckMatchSingle, len(ors))
|
checkOr := make(CheckMatchSingle, orCount)
|
||||||
isResponseChecker := false
|
i := 0
|
||||||
for i, or := range ors {
|
forEachPipePart(line, func(or string) {
|
||||||
curCheckers, isResp, err := parseOn(or)
|
i++
|
||||||
|
checkFunc, req, err := parseOnAtom(or)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs.Add(err)
|
errs.AddSubjectf(err, "or[%d]", i)
|
||||||
continue
|
return
|
||||||
}
|
|
||||||
if isResp {
|
|
||||||
isResponseChecker = true
|
|
||||||
}
|
|
||||||
checkOr[i] = curCheckers.(CheckFunc)
|
|
||||||
}
|
}
|
||||||
|
checkOr[i-1] = checkFunc
|
||||||
|
phase |= req
|
||||||
|
})
|
||||||
if err := errs.Error(); err != nil {
|
if err := errs.Error(); err != nil {
|
||||||
return nil, false, err
|
return nil, phase, err
|
||||||
}
|
}
|
||||||
return checkOr, isResponseChecker, nil
|
return checkOr, phase, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return parseOnAtom(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
|
||||||
|
var phase PhaseFlag
|
||||||
subject, args, err := parse(line)
|
subject, args, err := parse(line)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, phase, err
|
||||||
}
|
}
|
||||||
|
|
||||||
negate := false
|
negate := false
|
||||||
@@ -642,20 +709,21 @@ func parseOn(line string) (Checker, bool, error) {
|
|||||||
|
|
||||||
checker, ok := checkers[subject]
|
checker, ok := checkers[subject]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false, ErrInvalidOnTarget.Subject(subject)
|
return nil, phase, ErrInvalidOnTarget.Subject(subject)
|
||||||
}
|
}
|
||||||
|
|
||||||
validArgs, err := checker.validate(args)
|
req, validArgs, err := checker.validate(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, gperr.Wrap(err).With(checker.help.Error())
|
return nil, phase, gperr.Wrap(err).With(checker.help.Error())
|
||||||
}
|
}
|
||||||
|
phase |= req
|
||||||
|
|
||||||
checkFunc := checker.builder(validArgs)
|
checkFunc := checker.builder(validArgs)
|
||||||
if negate {
|
if negate {
|
||||||
origCheckFunc := checkFunc
|
origCheckFunc := checkFunc
|
||||||
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
|
checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
return !origCheckFunc(w, r)
|
return !origCheckFunc(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return checkFunc, checker.isResponseChecker, nil
|
return checkFunc, phase, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func TestSplitPipe(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "empty_segments",
|
name: "empty_segments",
|
||||||
input: "rule1 || rule2 | | rule3",
|
input: "rule1 || rule2 | | rule3",
|
||||||
want: []string{"rule1", "", "rule2", "", "rule3"},
|
want: []string{"rule1", "rule2", "rule3"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/yusing/godoxy/internal/route"
|
"github.com/yusing/godoxy/internal/route"
|
||||||
"github.com/yusing/godoxy/internal/route/routes"
|
"github.com/yusing/godoxy/internal/route/routes"
|
||||||
. "github.com/yusing/godoxy/internal/route/rules"
|
. "github.com/yusing/godoxy/internal/route/rules"
|
||||||
|
httputils "github.com/yusing/goutils/http"
|
||||||
expect "github.com/yusing/goutils/testing"
|
expect "github.com/yusing/goutils/testing"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
@@ -386,7 +387,7 @@ func TestOnCorrectness(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
var on RuleOn
|
var on RuleOn
|
||||||
err := on.Parse(tt.checker)
|
err := on.Parse(tt.checker)
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"strings"
|
||||||
"fmt"
|
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
"github.com/yusing/goutils/env"
|
"github.com/yusing/goutils/env"
|
||||||
@@ -25,6 +24,76 @@ var quoteChars = [256]bool{
|
|||||||
'`': true,
|
'`': true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseSimple(v string) (subject string, args []string, err error, ok bool) {
|
||||||
|
brackets := 0
|
||||||
|
for i := range len(v) {
|
||||||
|
switch v[i] {
|
||||||
|
case '\\', '$', '"', '\'', '`', '\t', '\r', '\n':
|
||||||
|
return "", nil, nil, false
|
||||||
|
case '(':
|
||||||
|
brackets++
|
||||||
|
case ')':
|
||||||
|
if brackets == 0 {
|
||||||
|
return "", nil, ErrUnterminatedBrackets, true
|
||||||
|
}
|
||||||
|
brackets--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if brackets != 0 {
|
||||||
|
return "", nil, ErrUnterminatedBrackets, true
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(v) && v[i] == ' ' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i >= len(v) {
|
||||||
|
return "", nil, nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
start := i
|
||||||
|
for i < len(v) && v[i] != ' ' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
subject = v[start:i]
|
||||||
|
|
||||||
|
if i >= len(v) {
|
||||||
|
return subject, nil, nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
argCount := 0
|
||||||
|
for j := i; j < len(v); {
|
||||||
|
for j < len(v) && v[j] == ' ' {
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
if j >= len(v) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
argCount++
|
||||||
|
for j < len(v) && v[j] != ' ' {
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if argCount == 0 {
|
||||||
|
return subject, nil, nil, true
|
||||||
|
}
|
||||||
|
args = make([]string, 0, argCount)
|
||||||
|
for i < len(v) {
|
||||||
|
for i < len(v) && v[i] == ' ' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i >= len(v) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start = i
|
||||||
|
for i < len(v) && v[i] != ' ' {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
args = append(args, v[start:i])
|
||||||
|
}
|
||||||
|
return subject, args, nil, true
|
||||||
|
}
|
||||||
|
|
||||||
// parse expression to subject and args
|
// parse expression to subject and args
|
||||||
// with support for quotes, escaped chars, and env substitution, e.g.
|
// with support for quotes, escaped chars, and env substitution, e.g.
|
||||||
//
|
//
|
||||||
@@ -32,14 +101,21 @@ var quoteChars = [256]bool{
|
|||||||
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
||||||
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
||||||
func parse(v string) (subject string, args []string, err error) {
|
func parse(v string) (subject string, args []string, err error) {
|
||||||
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
|
if subject, args, err, ok := parseSimple(v); ok {
|
||||||
|
return subject, args, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := getStringBuffer(len(v))
|
||||||
|
args = make([]string, 0, 4)
|
||||||
|
|
||||||
escaped := false
|
escaped := false
|
||||||
quote := rune(0)
|
quote := rune(0)
|
||||||
brackets := 0
|
brackets := 0
|
||||||
|
|
||||||
var envVar bytes.Buffer
|
var (
|
||||||
var missingEnvVars []string
|
envVar strings.Builder
|
||||||
|
missingEnvVars []string
|
||||||
|
)
|
||||||
inEnvVar := false
|
inEnvVar := false
|
||||||
expectingBrace := false
|
expectingBrace := false
|
||||||
|
|
||||||
@@ -71,7 +147,8 @@ func parse(v string) (subject string, args []string, err error) {
|
|||||||
if ch, ok := escapedChars[r]; ok {
|
if ch, ok := escapedChars[r]; ok {
|
||||||
buf.WriteRune(ch)
|
buf.WriteRune(ch)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(buf, `\%c`, r)
|
buf.WriteRune('\\')
|
||||||
|
buf.WriteRune(r)
|
||||||
}
|
}
|
||||||
escaped = false
|
escaped = false
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
expect "github.com/yusing/goutils/testing"
|
expect "github.com/yusing/goutils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,6 +14,7 @@ func TestParser(t *testing.T) {
|
|||||||
input string
|
input string
|
||||||
subject string
|
subject string
|
||||||
args []string
|
args []string
|
||||||
|
wantErr gperr.Error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "basic",
|
name: "basic",
|
||||||
@@ -90,6 +92,10 @@ func TestParser(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
subject, args, err := parse(tt.input)
|
subject, args, err := parse(tt.input)
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
expect.ErrorIs(t, tt.wantErr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
// t.Log(subject, args, err)
|
// t.Log(subject, args, err)
|
||||||
expect.NoError(t, err)
|
expect.NoError(t, err)
|
||||||
expect.Equal(t, subject, tt.subject)
|
expect.Equal(t, subject, tt.subject)
|
||||||
|
|||||||
29
internal/route/rules/phase.go
Normal file
29
internal/route/rules/phase.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type PhaseFlag uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
PhaseNone PhaseFlag = 0
|
||||||
|
PhasePre PhaseFlag = 1 << (iota - 1)
|
||||||
|
PhasePost
|
||||||
|
)
|
||||||
|
|
||||||
|
func (phase PhaseFlag) IsPostRule() bool {
|
||||||
|
return phase&PhasePost != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (phase PhaseFlag) String() string {
|
||||||
|
if phase == PhaseNone {
|
||||||
|
return "none"
|
||||||
|
}
|
||||||
|
var flags []string
|
||||||
|
if phase&PhasePre != 0 {
|
||||||
|
flags = append(flags, "PhasePre")
|
||||||
|
}
|
||||||
|
if phase&PhasePost != 0 {
|
||||||
|
flags = append(flags, "PhasePost")
|
||||||
|
}
|
||||||
|
return strings.Join(flags, ",")
|
||||||
|
}
|
||||||
@@ -4,9 +4,16 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/quic-go/quic-go/http3"
|
"github.com/quic-go/quic-go/http3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/yusing/godoxy/internal/serialization"
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
httputils "github.com/yusing/goutils/http"
|
httputils "github.com/yusing/goutils/http"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
|
||||||
@@ -15,6 +22,8 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
/*
|
/*
|
||||||
|
Rules is a list of rules.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
proxy.app1.rules: |
|
proxy.app1.rules: |
|
||||||
@@ -35,17 +44,14 @@ type (
|
|||||||
on: method POST | method PUT
|
on: method POST | method PUT
|
||||||
do: error 403 Forbidden
|
do: error 403 Forbidden
|
||||||
*/
|
*/
|
||||||
|
//nolint:recvcheck
|
||||||
Rules []Rule
|
Rules []Rule
|
||||||
/*
|
// Rule represents a reverse proxy rule.
|
||||||
Rule is a rule for a reverse proxy.
|
// The `Do` field is executed when `On` matches.
|
||||||
It do `Do` when `On` matches.
|
//
|
||||||
|
// - A rule may have multiple lines in the `On` section.
|
||||||
A rule can have multiple lines of on.
|
// - All `On` lines must match for the rule to trigger.
|
||||||
|
// - Each line can have several checks—one match per line is enough for that line.
|
||||||
All lines of on must match,
|
|
||||||
but each line can have multiple checks that
|
|
||||||
one match means this line is matched.
|
|
||||||
*/
|
|
||||||
Rule struct {
|
Rule struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
On RuleOn `json:"on" swaggertype:"string"`
|
On RuleOn `json:"on" swaggertype:"string"`
|
||||||
@@ -53,210 +59,395 @@ type (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (rule *Rule) IsResponseRule() bool {
|
func isDefaultRule(rule Rule) bool {
|
||||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
return rule.Name == "default" || rule.On.raw == OnDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rules Rules) Validate() error {
|
func (rules Rules) Validate() gperr.Error {
|
||||||
var defaultRulesFound []int
|
var defaultRulesFound []int
|
||||||
for i, rule := range rules {
|
for i := range rules {
|
||||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
rule := rules[i]
|
||||||
|
if isDefaultRule(rule) {
|
||||||
defaultRulesFound = append(defaultRulesFound, i)
|
defaultRulesFound = append(defaultRulesFound, i)
|
||||||
}
|
}
|
||||||
|
if rules[i].Name == "" {
|
||||||
|
// set name to index if name is empty
|
||||||
|
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if len(defaultRulesFound) > 1 {
|
if len(defaultRulesFound) > 1 {
|
||||||
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
||||||
}
|
}
|
||||||
|
for i := range rules {
|
||||||
|
r1 := rules[i]
|
||||||
|
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sig1, ok := matcherSignature(r1.On.raw)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for j := i + 1; j < len(rules); j++ {
|
||||||
|
r2 := rules[j]
|
||||||
|
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sig2, ok := matcherSignature(r2.On.raw)
|
||||||
|
if !ok || sig1 != sig2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rule Rule) doesTerminateInPre() bool {
|
||||||
|
return commandsTerminateInPre(rule.Do.pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
func commandsTerminateInPre(cmds []CommandHandler) bool {
|
||||||
|
return slices.ContainsFunc(cmds, commandTerminatesInPre)
|
||||||
|
}
|
||||||
|
|
||||||
|
func commandTerminatesInPre(cmd CommandHandler) bool {
|
||||||
|
switch c := cmd.(type) {
|
||||||
|
case Handler:
|
||||||
|
return c.Terminates()
|
||||||
|
case *Handler:
|
||||||
|
return c.Terminates()
|
||||||
|
case IfBlockCommand:
|
||||||
|
return ruleOnAlwaysTrue(c.On) && commandsTerminateInPre(c.Do)
|
||||||
|
case *IfBlockCommand:
|
||||||
|
return c != nil && ruleOnAlwaysTrue(c.On) && commandsTerminateInPre(c.Do)
|
||||||
|
case IfElseBlockCommand:
|
||||||
|
return ifElseBlockTerminatesInPre(c)
|
||||||
|
case *IfElseBlockCommand:
|
||||||
|
return c != nil && ifElseBlockTerminatesInPre(*c)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ifElseBlockTerminatesInPre(cmd IfElseBlockCommand) bool {
|
||||||
|
hasFallback := len(cmd.Else) > 0
|
||||||
|
for _, br := range cmd.Ifs {
|
||||||
|
if !commandsTerminateInPre(br.Do) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ruleOnAlwaysTrue(br.On) {
|
||||||
|
hasFallback = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasFallback {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(cmd.Else) > 0 && !commandsTerminateInPre(cmd.Else) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ruleOnAlwaysTrue(on RuleOn) bool {
|
||||||
|
return strings.TrimSpace(on.raw) == OnDefault || on.checker == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func matcherSignature(raw string) (string, bool) {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return "(any)", true // unconditional rule
|
||||||
|
}
|
||||||
|
|
||||||
|
andParts := splitAnd(raw)
|
||||||
|
if len(andParts) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
canonAnd := make([]string, 0, len(andParts))
|
||||||
|
for _, andPart := range andParts {
|
||||||
|
orParts := splitPipe(andPart)
|
||||||
|
if len(orParts) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
canonOr := make([]string, 0, len(orParts))
|
||||||
|
for _, atom := range orParts {
|
||||||
|
subject, args, err := parse(strings.TrimSpace(atom))
|
||||||
|
if err != nil || subject == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
||||||
|
}
|
||||||
|
slices.Sort(canonOr)
|
||||||
|
canonOr = slices.Compact(canonOr)
|
||||||
|
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.Sort(canonAnd)
|
||||||
|
canonAnd = slices.Compact(canonAnd)
|
||||||
|
if len(canonAnd) == 0 {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return strings.Join(canonAnd, "&"), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses a rule configuration string.
|
||||||
|
// It first tries the block syntax (if the string contains a top-level '{'),
|
||||||
|
// then falls back to YAML syntax.
|
||||||
|
func (rules *Rules) Parse(config string) error {
|
||||||
|
config = strings.TrimSpace(config)
|
||||||
|
if config == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
blockTried := false
|
||||||
|
var blockErr gperr.Error
|
||||||
|
|
||||||
|
// Prefer block syntax if it looks like block syntax.
|
||||||
|
if hasTopLevelLBrace(config) {
|
||||||
|
blockTried = true
|
||||||
|
blockRules, err := parseBlockRules(config)
|
||||||
|
if err == nil {
|
||||||
|
*rules = blockRules
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
blockErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
// YAML fallback
|
||||||
|
var anySlice []any
|
||||||
|
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
||||||
|
if yamlErr == nil {
|
||||||
|
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If YAML fails and we haven't tried block syntax yet, try it now.
|
||||||
|
if !blockTried {
|
||||||
|
blockRules, err := parseBlockRules(config)
|
||||||
|
if err == nil {
|
||||||
|
*rules = blockRules
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
blockErr = err
|
||||||
|
}
|
||||||
|
return blockErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments.
|
||||||
|
// Used to decide whether to prioritize the block syntax.
|
||||||
|
func hasTopLevelLBrace(s string) bool {
|
||||||
|
quote := rune(0)
|
||||||
|
inLine := false
|
||||||
|
inBlock := false
|
||||||
|
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
c := s[i]
|
||||||
|
|
||||||
|
if inLine {
|
||||||
|
if c == '\n' {
|
||||||
|
inLine = false
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inBlock {
|
||||||
|
if c == '*' && i+1 < len(s) && s[i+1] == '/' {
|
||||||
|
inBlock = false
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if quote != 0 {
|
||||||
|
if quote != '`' && c == '\\' && i+1 < len(s) {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rune(c) == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c {
|
||||||
|
case '\'', '"', '`':
|
||||||
|
quote = rune(c)
|
||||||
|
continue
|
||||||
|
case '{':
|
||||||
|
return true
|
||||||
|
case '#':
|
||||||
|
inLine = true
|
||||||
|
continue
|
||||||
|
case '/':
|
||||||
|
if i+1 < len(s) && s[i+1] == '/' {
|
||||||
|
inLine = true
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if i+1 < len(s) && s[i+1] == '*' {
|
||||||
|
inBlock = true
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if unicode.IsSpace(rune(c)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
||||||
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
||||||
if len(rules) == 0 {
|
if len(rules) == 0 {
|
||||||
return up
|
return up
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultRule := Rule{
|
var defaultRule *Rule
|
||||||
Name: "default",
|
|
||||||
Do: Command{
|
|
||||||
raw: "pass",
|
|
||||||
exec: BypassCommand{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var nonDefaultRules Rules
|
var nonDefaultRules Rules
|
||||||
hasDefaultRule := false
|
for _, rule := range rules {
|
||||||
for i, rule := range rules {
|
if isDefaultRule(rule) {
|
||||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
r := rule
|
||||||
defaultRule = rule
|
defaultRule = &r
|
||||||
hasDefaultRule = true
|
|
||||||
} else {
|
} else {
|
||||||
// set name to index if name is empty
|
|
||||||
if rule.Name == "" {
|
|
||||||
rule.Name = fmt.Sprintf("rule[%d]", i)
|
|
||||||
}
|
|
||||||
nonDefaultRules = append(nonDefaultRules, rule)
|
nonDefaultRules = append(nonDefaultRules, rule)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nonDefaultRules) == 0 {
|
if len(nonDefaultRules) == 0 {
|
||||||
if defaultRule.Do.isBypass() {
|
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
||||||
return up
|
return up
|
||||||
}
|
}
|
||||||
if defaultRule.IsResponseRule() {
|
}
|
||||||
|
|
||||||
|
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||||
|
return cmd.pre.ServeHTTP(w, r, up)
|
||||||
|
}
|
||||||
|
|
||||||
|
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||||
|
return cmd.post.ServeHTTP(w, r, up)
|
||||||
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
rm := httputils.NewResponseModifier(w)
|
rm := httputils.NewResponseModifier(w)
|
||||||
defer func() {
|
defer func() {
|
||||||
if _, err := rm.FlushRelease(); err != nil {
|
if _, err := rm.FlushRelease(); err != nil {
|
||||||
logError(err, r)
|
logFlushError(err, r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
w = rm
|
|
||||||
up(w, r)
|
var hasError bool
|
||||||
err := defaultRule.Do.exec.Handle(w, r)
|
|
||||||
if err != nil && !errors.Is(err, errTerminated) {
|
executedPre := make([]bool, len(nonDefaultRules))
|
||||||
appendRuleError(rm, &defaultRule, err)
|
terminatedInPre := make([]bool, len(nonDefaultRules))
|
||||||
|
matchedNonDefaultPre := false
|
||||||
|
preTerminated := false
|
||||||
|
for i, rule := range nonDefaultRules {
|
||||||
|
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matchedNonDefaultPre = true
|
||||||
|
if preTerminated {
|
||||||
|
// Preserve post-only commands (e.g. logging) even after
|
||||||
|
// pre-phase termination.
|
||||||
|
if len(rule.Do.pre) == 0 {
|
||||||
|
executedPre[i] = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
executedPre[i] = true
|
||||||
|
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
||||||
|
if errors.Is(err, errTerminateRule) {
|
||||||
|
terminatedInPre[i] = true
|
||||||
|
preTerminated = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isUnexpectedError(err) {
|
||||||
|
// will logged by logFlushError after FlushRelease
|
||||||
|
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
||||||
|
}
|
||||||
|
hasError = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default rule is a fallback: run only when no non-default pre rule matched.
|
||||||
|
defaultExecutedPre := false
|
||||||
|
defaultTerminatedInPre := false
|
||||||
|
if defaultRule != nil && !matchedNonDefaultPre && !defaultRule.On.phase.IsPostRule() && defaultRule.On.Check(rm, r) {
|
||||||
|
defaultExecutedPre = true
|
||||||
|
if err := execPreCommand(defaultRule.Do, rm, r); err != nil {
|
||||||
|
if errors.Is(err, errTerminateRule) {
|
||||||
|
defaultTerminatedInPre = true
|
||||||
|
} else {
|
||||||
|
if isUnexpectedError(err) {
|
||||||
|
// will logged by logFlushError after FlushRelease
|
||||||
|
rm.AppendError("executing pre rule (%s): %w", defaultRule.Do.raw, err)
|
||||||
}
|
}
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
hasError = true
|
||||||
rm := httputils.NewResponseModifier(w)
|
|
||||||
defer func() {
|
|
||||||
if _, err := rm.FlushRelease(); err != nil {
|
|
||||||
logError(err, r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
w = rm
|
|
||||||
err := defaultRule.Do.exec.Handle(w, r)
|
|
||||||
if err == nil {
|
|
||||||
up(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !errors.Is(err, errTerminated) {
|
|
||||||
appendRuleError(rm, &defaultRule, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
preRules := make(Rules, 0, len(nonDefaultRules))
|
if !rm.HasStatus() {
|
||||||
postRules := make(Rules, 0, len(nonDefaultRules))
|
if hasError {
|
||||||
|
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
||||||
|
up(rm, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run post commands for rules that actually executed in pre phase,
|
||||||
|
// unless that same rule terminated in pre phase.
|
||||||
|
for i, rule := range nonDefaultRules {
|
||||||
|
if !executedPre[i] || terminatedInPre[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||||
|
if errors.Is(err, errTerminateRule) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isUnexpectedError(err) {
|
||||||
|
// will logged by logFlushError after FlushRelease
|
||||||
|
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if defaultExecutedPre && !defaultTerminatedInPre {
|
||||||
|
if err := execPostCommand(defaultRule.Do, rm, r); err != nil {
|
||||||
|
if !errors.Is(err, errTerminateRule) && isUnexpectedError(err) {
|
||||||
|
// will logged by logFlushError after FlushRelease
|
||||||
|
rm.AppendError("executing post rule (%s): %w", defaultRule.Do.raw, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run true post-matcher rules after response is available.
|
||||||
for _, rule := range nonDefaultRules {
|
for _, rule := range nonDefaultRules {
|
||||||
if rule.IsResponseRule() {
|
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||||
postRules = append(postRules, rule)
|
continue
|
||||||
} else {
|
}
|
||||||
preRules = append(preRules, rule)
|
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
||||||
|
// as "pre" for requirement purposes still need to run in this phase.
|
||||||
|
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
||||||
|
if errors.Is(err, errTerminateRule) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isUnexpectedError(err) {
|
||||||
|
// will logged by logFlushError after FlushRelease
|
||||||
|
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||||
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
|
if errors.Is(err, errTerminateRule) {
|
||||||
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
continue
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
rm := httputils.NewResponseModifier(w)
|
|
||||||
defer func() {
|
|
||||||
if _, err := rm.FlushRelease(); err != nil {
|
|
||||||
logError(err, r)
|
|
||||||
}
|
}
|
||||||
}()
|
if isUnexpectedError(err) {
|
||||||
|
// will logged by logFlushError after FlushRelease
|
||||||
w = rm
|
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
||||||
|
|
||||||
shouldCallUpstream := true
|
|
||||||
preMatched := false
|
|
||||||
|
|
||||||
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
|
|
||||||
if defaultRule.Do.isBypass() {
|
|
||||||
// continue to upstream
|
|
||||||
} else {
|
|
||||||
err := defaultRule.Handle(w, r)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, errTerminated) {
|
|
||||||
appendRuleError(rm, &defaultRule, err)
|
|
||||||
}
|
|
||||||
shouldCallUpstream = false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if shouldCallUpstream {
|
|
||||||
for _, rule := range preRules {
|
|
||||||
if rule.Check(w, r) {
|
|
||||||
preMatched = true
|
|
||||||
if rule.Do.isBypass() {
|
|
||||||
break // post rules should still execute
|
|
||||||
}
|
|
||||||
err := rule.Handle(w, r)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, errTerminated) {
|
|
||||||
appendRuleError(rm, &rule, err)
|
|
||||||
}
|
|
||||||
shouldCallUpstream = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
|
|
||||||
if defaultRule.Do.isBypass() {
|
|
||||||
// continue to upstream
|
|
||||||
} else {
|
|
||||||
err := defaultRule.Handle(w, r)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, errTerminated) {
|
|
||||||
appendRuleError(rm, &defaultRule, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
shouldCallUpstream = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldCallUpstream {
|
|
||||||
up(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no post rules, we are done here
|
|
||||||
if len(postRules) == 0 && !isDefaultRulePost {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range postRules {
|
|
||||||
if rule.Check(w, r) {
|
|
||||||
err := rule.Handle(w, r)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, errTerminated) {
|
|
||||||
appendRuleError(rm, &rule, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if isDefaultRulePost {
|
|
||||||
err := defaultRule.Handle(w, r)
|
|
||||||
if err != nil && !errors.Is(err, errTerminated) {
|
|
||||||
appendRuleError(rm, &defaultRule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) {
|
|
||||||
// rm.AppendError("rule: %s, error: %w", rule.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func isTerminatingHandler(handler CommandHandler) bool {
|
|
||||||
switch h := handler.(type) {
|
|
||||||
case TerminatingCommand:
|
|
||||||
return true
|
|
||||||
case Commands:
|
|
||||||
if len(h) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return isTerminatingHandler(h[len(h)-1])
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,41 +455,41 @@ func (rule *Rule) String() string {
|
|||||||
return rule.Name
|
return rule.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
|
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||||
if rule.On.checker == nil {
|
if rule.On.checker == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
v := rule.On.checker.Check(w, r)
|
return rule.On.Check(w, r)
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
return rule.Do.exec.Handle(w, r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
|
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
|
||||||
var errStreamClosed error
|
var errStreamClosed error
|
||||||
|
|
||||||
func logError(err error, r *http.Request) {
|
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
|
||||||
if errors.Is(err, errStreamClosed) {
|
var errClientDisconnected error
|
||||||
return
|
|
||||||
|
func isUnexpectedError(err error) bool {
|
||||||
|
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
var h2Err http2.StreamError
|
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
||||||
if errors.As(err, &h2Err) {
|
|
||||||
// ignore these errors
|
// ignore these errors
|
||||||
if h2Err.Code == http2.ErrCodeStreamClosed {
|
if h2Err.Code == http2.ErrCodeStreamClosed {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var h3Err *http3.Error
|
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
||||||
if errors.As(err, &h3Err) {
|
|
||||||
// ignore these errors
|
// ignore these errors
|
||||||
switch h3Err.ErrorCode {
|
switch h3Err.ErrorCode {
|
||||||
case
|
case
|
||||||
http3.ErrCodeNoError,
|
http3.ErrCodeNoError,
|
||||||
http3.ErrCodeRequestCanceled:
|
http3.ErrCodeRequestCanceled:
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func logFlushError(err error, r *http.Request) {
|
||||||
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,28 +19,133 @@ func TestRulesValidate(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "no default rule",
|
name: "no default rule",
|
||||||
rules: `
|
rules: `
|
||||||
- name: rule1
|
header Host example.com {
|
||||||
on: header Host example.com
|
pass
|
||||||
do: pass
|
}`,
|
||||||
`,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple default rules",
|
name: "multiple default rules",
|
||||||
rules: `
|
rules: `
|
||||||
- name: default
|
default {
|
||||||
do: pass
|
pass
|
||||||
- name: rule1
|
}
|
||||||
on: default
|
|
||||||
do: pass
|
default {
|
||||||
`,
|
pass
|
||||||
|
}`,
|
||||||
want: ErrMultipleDefaultRules,
|
want: ErrMultipleDefaultRules,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "multiple responses on same condition",
|
||||||
|
rules: `
|
||||||
|
header Host example.com {
|
||||||
|
error 404 "not found"
|
||||||
|
}
|
||||||
|
|
||||||
|
header Host example.com {
|
||||||
|
error 403 "forbidden"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: ErrDeadRule,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same condition different formatting error then proxy",
|
||||||
|
rules: `
|
||||||
|
header Host example.com & method GET {
|
||||||
|
error 404 "not found"
|
||||||
|
}
|
||||||
|
|
||||||
|
method GET
|
||||||
|
header Host example.com {
|
||||||
|
proxy http://127.0.0.1:8080
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: ErrDeadRule,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same condition with non terminating first rule",
|
||||||
|
rules: `
|
||||||
|
header Host example.com {
|
||||||
|
set resp_header X-Test first
|
||||||
|
}
|
||||||
|
|
||||||
|
header Host example.com {
|
||||||
|
error 403 "forbidden"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same condition with terminating handler inside if block",
|
||||||
|
rules: `
|
||||||
|
header Host example.com {
|
||||||
|
default {
|
||||||
|
error 404 "not found"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
header Host example.com {
|
||||||
|
error 403 "forbidden"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: ErrDeadRule,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same condition with terminating handler across if else block",
|
||||||
|
rules: `
|
||||||
|
header Host example.com {
|
||||||
|
method GET {
|
||||||
|
error 404 "not found"
|
||||||
|
} else {
|
||||||
|
redirect https://example.com
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
header Host example.com {
|
||||||
|
error 403 "forbidden"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: ErrDeadRule,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same condition with non terminating if branch in if else block",
|
||||||
|
rules: `
|
||||||
|
header Host example.com {
|
||||||
|
method GET {
|
||||||
|
set resp_header X-Test first
|
||||||
|
} else {
|
||||||
|
error 404 "not found"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
header Host example.com {
|
||||||
|
error 403 "forbidden"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unconditional terminating rule shadows later unconditional rule",
|
||||||
|
rules: `
|
||||||
|
{
|
||||||
|
error 404 "not found"
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
error 403 "forbidden"
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
want: ErrDeadRule,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var rules Rules
|
var rules Rules
|
||||||
convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules))
|
convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules))
|
||||||
require.True(t, convertible)
|
require.True(t, convertible)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = rules.Validate()
|
||||||
|
|
||||||
if tt.want == nil {
|
if tt.want == nil {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -50,3 +155,50 @@ func TestRulesValidate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHasTopLevelLBrace(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "escaped quote inside double quoted string",
|
||||||
|
in: `"test\"more{"`,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped quote inside single quoted string",
|
||||||
|
in: "'test\\'more{'",
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top-level brace outside quoted string",
|
||||||
|
in: `"test\"more" {`,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "backtick keeps existing behavior",
|
||||||
|
in: "`test\\`more{`",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.want, hasTopLevelLBrace(tt.in))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesParse_BlockTriedThenYAMLFails_ReturnsBlockError(t *testing.T) {
|
||||||
|
input := `default {`
|
||||||
|
|
||||||
|
_, blockErr := parseBlockRules(input)
|
||||||
|
require.Error(t, blockErr)
|
||||||
|
|
||||||
|
var rules Rules
|
||||||
|
err := rules.Parse(input)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Equal(t, blockErr.Error(), err.Error())
|
||||||
|
}
|
||||||
|
|||||||
273
internal/route/rules/scanner.go
Normal file
273
internal/route/rules/scanner.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tokenizer provides utilities for parsing rule syntax with proper handling
|
||||||
|
// of quotes, comments, and env vars.
|
||||||
|
//
|
||||||
|
// This is intentionally reusable by both the top-level rule block parser and
|
||||||
|
// the nested do-block parser.
|
||||||
|
type Tokenizer struct {
|
||||||
|
src string
|
||||||
|
length int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTokenizer creates a tokenizer for the given source.
|
||||||
|
func newTokenizer(src string) Tokenizer {
|
||||||
|
return Tokenizer{src: src, length: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// skipComments skips whitespace, line comments, and block comments.
|
||||||
|
// It returns the new position and an error if a block comment is unterminated.
|
||||||
|
func (t *Tokenizer) skipComments(pos int, atLineStart bool, prevIsSpace bool) (int, gperr.Error) {
|
||||||
|
for pos < t.length {
|
||||||
|
c := t.src[pos]
|
||||||
|
|
||||||
|
// Skip whitespace
|
||||||
|
if unicode.IsSpace(rune(c)) {
|
||||||
|
pos++
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for line comment: // or #
|
||||||
|
if c == '/' {
|
||||||
|
if pos+1 < t.length && t.src[pos+1] == '/' {
|
||||||
|
// Skip to end of line
|
||||||
|
for pos < t.length && t.src[pos] != '\n' {
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c == '#' && (atLineStart || prevIsSpace) {
|
||||||
|
// Skip to end of line
|
||||||
|
for pos < t.length && t.src[pos] != '\n' {
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for block comment: /*
|
||||||
|
if c == '/' && pos+1 < t.length && t.src[pos+1] == '*' {
|
||||||
|
pos += 2
|
||||||
|
closed := false
|
||||||
|
for pos+1 < t.length {
|
||||||
|
if t.src[pos] == '*' && t.src[pos+1] == '/' {
|
||||||
|
pos += 2
|
||||||
|
closed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
if !closed {
|
||||||
|
return 0, ErrInvalidBlockSyntax.Withf("unterminated block comment")
|
||||||
|
}
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return pos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanToBrace scans from pos until it finds '{' outside quotes, or returns an error.
|
||||||
|
func (t *Tokenizer) scanToBrace(pos int) (int, gperr.Error) {
|
||||||
|
quote := rune(0)
|
||||||
|
for pos < t.length {
|
||||||
|
c := rune(t.src[pos])
|
||||||
|
if quote != 0 {
|
||||||
|
if c == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '"' || c == '\'' || c == '`' {
|
||||||
|
quote = c
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '{' {
|
||||||
|
return pos, nil
|
||||||
|
}
|
||||||
|
if c == '}' {
|
||||||
|
return 0, ErrInvalidBlockSyntax.Withf("unmatched '}' in block header")
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
return 0, ErrInvalidBlockSyntax.Withf("expected '{' after block header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingBrace finds the matching '}' for a '{' starting at startPos.
|
||||||
|
// It respects quotes/backticks and ${...} env vars.
|
||||||
|
func (t *Tokenizer) findMatchingBrace(startPos int) (int, gperr.Error) {
|
||||||
|
pos := startPos
|
||||||
|
braceDepth := 1
|
||||||
|
quote := rune(0)
|
||||||
|
inLine := false
|
||||||
|
inBlock := false
|
||||||
|
atLineStart := true
|
||||||
|
prevIsSpace := true
|
||||||
|
|
||||||
|
for pos < t.length {
|
||||||
|
c := rune(t.src[pos])
|
||||||
|
|
||||||
|
if inLine {
|
||||||
|
if c == '\n' {
|
||||||
|
inLine = false
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inBlock {
|
||||||
|
if c == '*' && pos+1 < t.length && t.src[pos+1] == '/' {
|
||||||
|
pos += 2
|
||||||
|
inBlock = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '\n' {
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if quote != 0 {
|
||||||
|
if c == quote {
|
||||||
|
quote = 0
|
||||||
|
}
|
||||||
|
if c == '\n' {
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
} else {
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = unicode.IsSpace(c)
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '"' || c == '\'' || c == '`' {
|
||||||
|
quote = c
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = false
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comments (only outside quotes) at token boundary
|
||||||
|
if c == '#' && (atLineStart || prevIsSpace) {
|
||||||
|
inLine = true
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '/' && pos+1 < t.length {
|
||||||
|
n := rune(t.src[pos+1])
|
||||||
|
if (atLineStart || prevIsSpace) && n == '/' {
|
||||||
|
inLine = true
|
||||||
|
pos += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (atLineStart || prevIsSpace) && n == '*' {
|
||||||
|
inBlock = true
|
||||||
|
pos += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '$' && pos+1 < t.length && t.src[pos+1] == '{' {
|
||||||
|
// Skip env var ${...}
|
||||||
|
pos += 2
|
||||||
|
envBraceDepth := 1
|
||||||
|
envQuote := rune(0)
|
||||||
|
for pos < t.length {
|
||||||
|
ec := rune(t.src[pos])
|
||||||
|
if envQuote != 0 {
|
||||||
|
if ec == envQuote {
|
||||||
|
envQuote = 0
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ec == '"' || ec == '\'' || ec == '`' {
|
||||||
|
envQuote = ec
|
||||||
|
pos++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ec == '{' {
|
||||||
|
envBraceDepth++
|
||||||
|
} else if ec == '}' {
|
||||||
|
envBraceDepth--
|
||||||
|
if envBraceDepth == 0 {
|
||||||
|
pos++ // Move past the closing '}'
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c {
|
||||||
|
case '{':
|
||||||
|
braceDepth++
|
||||||
|
case '}':
|
||||||
|
braceDepth--
|
||||||
|
if braceDepth == 0 {
|
||||||
|
return pos, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '\n' {
|
||||||
|
atLineStart = true
|
||||||
|
prevIsSpace = true
|
||||||
|
} else {
|
||||||
|
atLineStart = false
|
||||||
|
prevIsSpace = unicode.IsSpace(c)
|
||||||
|
}
|
||||||
|
pos++
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, ErrInvalidBlockSyntax.Withf("unmatched '{' at position %d", startPos)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHeaderToBrace parses an expression/header starting at start and returns:
|
||||||
|
// - header: trimmed src[start:bracePos]
|
||||||
|
// - bracePos: position of '{' (outside quotes/backticks)
|
||||||
|
func parseHeaderToBrace(src string, start int) (header string, bracePos int, err gperr.Error) {
|
||||||
|
t := newTokenizer(src)
|
||||||
|
bracePos, err = t.scanToBrace(start)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(src[start:bracePos]), bracePos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingBrace finds the matching '}' for a '{' at position startPos.
|
||||||
|
// It respects quotes/backticks and ${...} env vars in do_body.
|
||||||
|
func findMatchingBrace(src string, pos *int, startPos int) (int, gperr.Error) {
|
||||||
|
t := newTokenizer(src)
|
||||||
|
endPos, err := t.findMatchingBrace(startPos)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
*pos = endPos + 1
|
||||||
|
return endPos, nil
|
||||||
|
}
|
||||||
39
internal/route/rules/scanner_test.go
Normal file
39
internal/route/rules/scanner_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package rules
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenizer_skipComments_UnterminatedBlockComment(t *testing.T) {
|
||||||
|
tok := newTokenizer("/* unterminated")
|
||||||
|
pos, err := tok.skipComments(0, true, true)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, 0, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenizer_skipComments_SkipsLineAndBlockComments(t *testing.T) {
|
||||||
|
src := " // line\n /* block */\n # hash\n default"
|
||||||
|
tok := newTokenizer(src)
|
||||||
|
pos, err := tok.skipComments(0, true, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, strings.Index(src, "default"), pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenizer_scanToBrace_IgnoresQuotedBraces(t *testing.T) {
|
||||||
|
src := "cond \"{\" {" // the brace inside quotes must be ignored
|
||||||
|
tok := newTokenizer(src)
|
||||||
|
bracePos, err := tok.scanToBrace(0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, strings.LastIndex(src, "{"), bracePos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenizer_findMatchingBrace_IgnoresQuotedClosingBrace(t *testing.T) {
|
||||||
|
src := `{ "}" }`
|
||||||
|
tok := newTokenizer(src)
|
||||||
|
endPos, err := tok.findMatchingBrace(1) // body starts after the first '{'
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, strings.LastIndex(src, "}"), endPos)
|
||||||
|
}
|
||||||
@@ -4,13 +4,13 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
httputils "github.com/yusing/goutils/http"
|
httputils "github.com/yusing/goutils/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type templateString struct {
|
type templateString struct {
|
||||||
string
|
string
|
||||||
|
|
||||||
isTemplate bool
|
isTemplate bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,32 +23,28 @@ func (tmpl *keyValueTemplate) Unpack() (string, templateString) {
|
|||||||
return tmpl.key, tmpl.tmpl
|
return tmpl.key, tmpl.tmpl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error {
|
func (tmpl *templateString) ExpandVars(w *httputils.ResponseModifier, req *http.Request, dst io.Writer) (phase PhaseFlag, err error) {
|
||||||
if !tmpl.isTemplate {
|
if !tmpl.isTemplate {
|
||||||
_, err := dstW.Write(strtobNoCopy(tmpl.string))
|
_, err := asBytesBufferLike(dst).WriteString(tmpl.string)
|
||||||
return err
|
return PhaseNone, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW)
|
return ExpandVars(w, req, tmpl.string, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
|
func (tmpl *templateString) ExpandVarsToString(w *httputils.ResponseModifier, r *http.Request) (string, PhaseFlag, error) {
|
||||||
if !tmpl.isTemplate {
|
if !tmpl.isTemplate {
|
||||||
return tmpl.string, nil
|
return tmpl.string, PhaseNone, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf strings.Builder
|
var buf strings.Builder
|
||||||
err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf)
|
phase, err := tmpl.ExpandVars(w, r, &buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", PhaseNone, err
|
||||||
}
|
}
|
||||||
return buf.String(), nil
|
return buf.String(), phase, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tmpl *templateString) Len() int {
|
func (tmpl *templateString) Len() int {
|
||||||
return len(tmpl.string)
|
return len(tmpl.string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func strtobNoCopy(s string) []byte {
|
|
||||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
@@ -16,7 +17,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
ValidateFunc func(args []string) (any, error)
|
ValidateFunc func(args []string) (phase PhaseFlag, parsedArgs any, err error)
|
||||||
Tuple[T1, T2 any] struct {
|
Tuple[T1, T2 any] struct {
|
||||||
First T1
|
First T1
|
||||||
Second T2
|
Second T2
|
||||||
@@ -37,6 +38,8 @@ type (
|
|||||||
MapValueMatcher = Tuple[string, Matcher]
|
MapValueMatcher = Tuple[string, Matcher]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var cidrCache = xsync.NewMap[string, *net.IPNet]()
|
||||||
|
|
||||||
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
||||||
return t.First, t.Second
|
return t.First, t.Second
|
||||||
}
|
}
|
||||||
@@ -62,7 +65,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||||
func validateSingleMatcher(args []string) (any, error) {
|
func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
@@ -70,7 +73,7 @@ func validateSingleMatcher(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||||
func toKVOptionalVMatcher(args []string) (any, error) {
|
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||||
switch len(args) {
|
switch len(args) {
|
||||||
case 1:
|
case 1:
|
||||||
return &MapValueMatcher{args[0], nil}, nil
|
return &MapValueMatcher{args[0], nil}, nil
|
||||||
@@ -85,20 +88,8 @@ func toKVOptionalVMatcher(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toKeyValueTemplate(args []string) (any, error) {
|
|
||||||
if len(args) != 2 {
|
|
||||||
return nil, ErrExpectTwoArgs
|
|
||||||
}
|
|
||||||
|
|
||||||
isTemplate, err := validateTemplate(args[1], false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &keyValueTemplate{args[0], isTemplate}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateURL returns types.URL with the URL validated.
|
// validateURL returns types.URL with the URL validated.
|
||||||
func validateURL(args []string) (any, error) {
|
func validateURL(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
@@ -116,22 +107,27 @@ func validateURL(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateCIDR returns types.CIDR with the CIDR validated.
|
// validateCIDR returns types.CIDR with the CIDR validated.
|
||||||
func validateCIDR(args []string) (any, error) {
|
func validateCIDR(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
if !strings.Contains(args[0], "/") {
|
cidr := args[0]
|
||||||
args[0] += "/32"
|
if !strings.Contains(cidr, "/") {
|
||||||
|
cidr += "/32"
|
||||||
}
|
}
|
||||||
_, ipnet, err := net.ParseCIDR(args[0])
|
if cached, ok := cidrCache.Load(cidr); ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(cidr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidArguments.With(err)
|
return nil, ErrInvalidArguments.With(err)
|
||||||
}
|
}
|
||||||
|
cidrCache.Store(cidr, ipnet)
|
||||||
return ipnet, nil
|
return ipnet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateURLPath returns string with the path validated.
|
// validateURLPath returns string with the path validated.
|
||||||
func validateURLPath(args []string) (any, error) {
|
func validateURLPath(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
@@ -148,7 +144,7 @@ func validateURLPath(args []string) (any, error) {
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateURLPathMatcher(args []string) (any, error) {
|
func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||||
path, err := validateURLPath(args)
|
path, err := validateURLPath(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -157,7 +153,7 @@ func validateURLPathMatcher(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateFSPath returns string with the path validated.
|
// validateFSPath returns string with the path validated.
|
||||||
func validateFSPath(args []string) (any, error) {
|
func validateFSPath(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
@@ -169,7 +165,7 @@ func validateFSPath(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateMethod returns string with the method validated.
|
// validateMethod returns string with the method validated.
|
||||||
func validateMethod(args []string) (any, error) {
|
func validateMethod(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
@@ -200,7 +196,7 @@ func validateStatusCode(status string) (int, error) {
|
|||||||
// - 3xx
|
// - 3xx
|
||||||
// - 4xx
|
// - 4xx
|
||||||
// - 5xx
|
// - 5xx
|
||||||
func validateStatusRange(args []string) (any, error) {
|
func validateStatusRange(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return nil, ErrExpectOneArg
|
return nil, ErrExpectOneArg
|
||||||
}
|
}
|
||||||
@@ -232,7 +228,7 @@ func validateStatusRange(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||||
func validateUserBCryptPassword(args []string) (any, error) {
|
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return nil, ErrExpectTwoArgs
|
return nil, ErrExpectTwoArgs
|
||||||
}
|
}
|
||||||
@@ -240,64 +236,93 @@ func validateUserBCryptPassword(args []string) (any, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateModField returns CommandHandler with the field validated.
|
// validateModField returns CommandHandler with the field validated.
|
||||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, error) {
|
func validateModField(mod FieldModifier, args []string) (phase PhaseFlag, handler HandlerFunc, err error) {
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
return nil, ErrExpectTwoOrThreeArgs
|
return phase, nil, ErrExpectTwoOrThreeArgs
|
||||||
}
|
}
|
||||||
setField, ok := modFields[args[0]]
|
setField, ok := modFields[args[0]]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrUnknownModField.Subject(args[0])
|
return phase, nil, ErrUnknownModField.Subject(args[0])
|
||||||
}
|
}
|
||||||
if mod == ModFieldRemove {
|
if mod == ModFieldRemove {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return nil, ErrExpectTwoArgs
|
return phase, nil, ErrExpectTwoArgs
|
||||||
}
|
}
|
||||||
// setField expect validateStrTuple
|
// setField expect validateStrTuple
|
||||||
args = append(args, "")
|
args = append(args, "")
|
||||||
}
|
}
|
||||||
validArgs, err := setField.validate(args[1:])
|
phase, validArgs, err := setField.validate(args[1:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, gperr.Wrap(err).With(setField.help.Error())
|
return phase, nil, gperr.Wrap(err).With(setField.help.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
modder := setField.builder(validArgs)
|
modder := setField.builder(validArgs)
|
||||||
switch mod {
|
switch mod {
|
||||||
case ModFieldAdd:
|
case ModFieldAdd:
|
||||||
add := modder.add
|
add := modder.add
|
||||||
if add == nil {
|
if add == nil {
|
||||||
return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod)
|
return phase, nil, ErrInvalidArguments.Withf("add is not supported for field %s", args[0])
|
||||||
}
|
}
|
||||||
return add, nil
|
return phase, add, nil
|
||||||
case ModFieldRemove:
|
case ModFieldRemove:
|
||||||
remove := modder.remove
|
remove := modder.remove
|
||||||
if remove == nil {
|
if remove == nil {
|
||||||
return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod)
|
return phase, nil, ErrInvalidArguments.Withf("remove is not supported for field %s", args[0])
|
||||||
}
|
}
|
||||||
return remove, nil
|
return phase, remove, nil
|
||||||
}
|
}
|
||||||
set := modder.set
|
set := modder.set
|
||||||
if set == nil {
|
if set == nil {
|
||||||
return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod)
|
return phase, nil, ErrInvalidArguments.Withf("set is not supported for field %s", args[0])
|
||||||
}
|
}
|
||||||
return set, nil
|
return phase, set, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateTemplate(tmplStr string, newline bool) (templateString, error) {
|
func validateTemplate(tmplStr string, newline bool) (phase PhaseFlag, tmpl templateString, err error) {
|
||||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||||
tmplStr += "\n"
|
tmplStr += "\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
if !NeedExpandVars(tmplStr) {
|
if !NeedExpandVars(tmplStr) {
|
||||||
return templateString{tmplStr, false}, nil
|
return phase, templateString{tmplStr, false}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ValidateVars(tmplStr)
|
phase, err = ValidateVars(tmplStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return templateString{}, err
|
return phase, templateString{}, gperr.Wrap(err)
|
||||||
}
|
}
|
||||||
return templateString{tmplStr, true}, nil
|
return phase, templateString{tmplStr, true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateLevel(level string) (zerolog.Level, error) {
|
func validatePreRequestKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
if len(args) != 2 {
|
||||||
|
return phase, nil, ErrExpectTwoArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
phase = PhasePre
|
||||||
|
tmplReq, tmpl, err := validateTemplate(args[1], false)
|
||||||
|
if err != nil {
|
||||||
|
return phase, nil, err
|
||||||
|
}
|
||||||
|
phase |= tmplReq
|
||||||
|
return phase, &keyValueTemplate{args[0], tmpl}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePostResponseKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||||
|
if len(args) != 2 {
|
||||||
|
return phase, nil, ErrExpectTwoArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
phase = PhasePost
|
||||||
|
tmplReq, tmpl, err := validateTemplate(args[1], false)
|
||||||
|
if err != nil {
|
||||||
|
return phase, nil, err
|
||||||
|
}
|
||||||
|
phase |= tmplReq
|
||||||
|
return phase, &keyValueTemplate{args[0], tmpl}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||||
l, err := zerolog.ParseLevel(level)
|
l, err := zerolog.ParseLevel(level)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func BenchmarkExpandVars(b *testing.B) {
|
|||||||
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
|
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
package rules
|
package rules
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
httputils "github.com/yusing/goutils/http"
|
httputils "github.com/yusing/goutils/http"
|
||||||
ioutils "github.com/yusing/goutils/io"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: remove middleware/vars.go and use this instead
|
// TODO: remove middleware/vars.go and use this instead
|
||||||
@@ -45,41 +46,84 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type bytesBufferLike interface {
|
||||||
|
io.Writer
|
||||||
|
WriteByte(c byte) error
|
||||||
|
WriteString(s string) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type bytesBufferAdapter struct {
|
||||||
|
io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b bytesBufferAdapter) WriteByte(c byte) error {
|
||||||
|
buf := [1]byte{c}
|
||||||
|
_, err := b.Write(buf[:])
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b bytesBufferAdapter) WriteString(s string) (int, error) {
|
||||||
|
return b.Write(unsafe.Slice(unsafe.StringData(s), len(s))) // avoid copy
|
||||||
|
}
|
||||||
|
|
||||||
|
func asBytesBufferLike(w io.Writer) bytesBufferLike {
|
||||||
|
switch w := w.(type) {
|
||||||
|
case *bytes.Buffer:
|
||||||
|
return w
|
||||||
|
case bytesBufferLike:
|
||||||
|
return w
|
||||||
|
default:
|
||||||
|
return bytesBufferAdapter{w}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateVars validates the variables in the given string.
|
// ValidateVars validates the variables in the given string.
|
||||||
// It returns ErrUnexpectedVar if any invalid variable is found.
|
// It returns the phase that the variables require and an error if any error occurs.
|
||||||
func ValidateVars(s string) error {
|
//
|
||||||
|
// Possible errors:
|
||||||
|
// - ErrUnexpectedVar: if any invalid variable is found
|
||||||
|
// - ErrUnterminatedEnvVar: missing closing }
|
||||||
|
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
||||||
|
// - ErrUnterminatedParenthesis: missing closing )
|
||||||
|
func ValidateVars(s string) (phase PhaseFlag, err error) {
|
||||||
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
|
// ExpandVars expands the variables in the given string and writes the result to the given writer.
|
||||||
dst := ioutils.NewBufferedWriter(dstW, 1024)
|
// It returns the phase that the variables require and an error if any error occurs.
|
||||||
defer dst.Close()
|
//
|
||||||
|
// Possible errors:
|
||||||
|
// - ErrUnexpectedVar: if any invalid variable is found
|
||||||
|
// - ErrUnterminatedEnvVar: missing closing }
|
||||||
|
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
||||||
|
// - ErrUnterminatedParenthesis: missing closing )
|
||||||
|
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) (phase PhaseFlag, err error) {
|
||||||
|
dst := asBytesBufferLike(dstW)
|
||||||
for i := 0; i < len(src); i++ {
|
for i := 0; i < len(src); i++ {
|
||||||
ch := src[i]
|
ch := src[i]
|
||||||
if ch != '$' {
|
if ch != '$' {
|
||||||
if err := dst.WriteByte(ch); err != nil {
|
if err = dst.WriteByte(ch); err != nil {
|
||||||
return err
|
return phase, err
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Look ahead
|
// Look ahead
|
||||||
if i+1 >= len(src) {
|
if i+1 >= len(src) {
|
||||||
return ErrUnterminatedEnvVar
|
return phase, ErrUnterminatedEnvVar
|
||||||
}
|
}
|
||||||
j := i + 1
|
j := i + 1
|
||||||
|
|
||||||
switch src[j] {
|
switch src[j] {
|
||||||
case '$': // $$ -> literal '$'
|
case '$': // $$ -> literal '$'
|
||||||
if err := dst.WriteByte('$'); err != nil {
|
if err := dst.WriteByte('$'); err != nil {
|
||||||
return err
|
return phase, err
|
||||||
}
|
}
|
||||||
i = j
|
i = j
|
||||||
continue
|
continue
|
||||||
case '{': // ${...} pass through as-is
|
case '{': // ${...} pass through as-is
|
||||||
if _, err := dst.WriteString("${"); err != nil {
|
if _, err := dst.WriteString("${"); err != nil {
|
||||||
return err
|
return phase, err
|
||||||
}
|
}
|
||||||
i = j // we've consumed the '{' too
|
i = j // we've consumed the '{' too
|
||||||
continue
|
continue
|
||||||
@@ -102,24 +146,26 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
|
|||||||
if getter, ok := dynamicVarSubsMap[name]; ok {
|
if getter, ok := dynamicVarSubsMap[name]; ok {
|
||||||
// Function-like variables
|
// Function-like variables
|
||||||
isStatic = false
|
isStatic = false
|
||||||
|
phase |= getter.phase
|
||||||
args, nextIdx, err := extractArgs(src, j, name)
|
args, nextIdx, err := extractArgs(src, j, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return phase, err
|
||||||
}
|
}
|
||||||
i = nextIdx
|
i = nextIdx
|
||||||
actual, err = getter(args, w, req)
|
actual, err = getter.get(args, w, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return phase, err
|
||||||
}
|
}
|
||||||
} else if getter, ok := staticReqVarSubsMap[name]; ok {
|
} else if getter, ok := staticReqVarSubsMap[name]; ok { // always available
|
||||||
actual = getter(req)
|
actual = getter(req)
|
||||||
} else if getter, ok := staticRespVarSubsMap[name]; ok {
|
} else if getter, ok := staticRespVarSubsMap[name]; ok { // post response
|
||||||
actual = getter(w)
|
actual = getter(w)
|
||||||
|
phase |= PhasePost
|
||||||
} else {
|
} else {
|
||||||
return ErrUnexpectedVar.Subject(name)
|
return phase, ErrUnexpectedVar.Subject(name)
|
||||||
}
|
}
|
||||||
if _, err := dst.WriteString(actual); err != nil {
|
if _, err := dst.WriteString(actual); err != nil {
|
||||||
return err
|
return phase, err
|
||||||
}
|
}
|
||||||
if isStatic {
|
if isStatic {
|
||||||
i = k - 1
|
i = k - 1
|
||||||
@@ -128,10 +174,10 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
|
|||||||
}
|
}
|
||||||
|
|
||||||
// No valid construct after '$'
|
// No valid construct after '$'
|
||||||
return ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
return phase, ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return phase, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
|
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
|
||||||
|
|||||||
@@ -11,36 +11,62 @@ import (
|
|||||||
var (
|
var (
|
||||||
VarHeader = "header"
|
VarHeader = "header"
|
||||||
VarResponseHeader = "resp_header"
|
VarResponseHeader = "resp_header"
|
||||||
|
VarCookie = "cookie"
|
||||||
VarQuery = "arg"
|
VarQuery = "arg"
|
||||||
VarForm = "form"
|
VarForm = "form"
|
||||||
VarPostForm = "postform"
|
VarPostForm = "postform"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
type dynamicVarGetter struct {
|
||||||
|
phase PhaseFlag
|
||||||
|
get func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
||||||
VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
VarHeader: {
|
||||||
|
phase: PhaseNone,
|
||||||
|
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(req.Header, key, index)
|
return getValueByKeyAtIndex(req.Header, key, index)
|
||||||
},
|
},
|
||||||
VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
},
|
||||||
|
VarResponseHeader: {
|
||||||
|
phase: PhasePost,
|
||||||
|
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||||
},
|
},
|
||||||
VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
},
|
||||||
|
VarCookie: {
|
||||||
|
phase: PhaseNone,
|
||||||
|
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
|
key, index, err := getKeyAndIndex(args)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sharedData := httputils.GetSharedData(w)
|
||||||
|
return getValueByKeyAtIndex(sharedData.GetCookiesMap(req), key, index)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
VarQuery: {
|
||||||
|
phase: PhaseNone,
|
||||||
|
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
||||||
},
|
},
|
||||||
VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
},
|
||||||
|
VarForm: {
|
||||||
|
phase: PhaseNone,
|
||||||
|
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -52,7 +78,10 @@ var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
|||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(req.Form, key, index)
|
return getValueByKeyAtIndex(req.Form, key, index)
|
||||||
},
|
},
|
||||||
VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
},
|
||||||
|
VarPostForm: {
|
||||||
|
phase: PhaseNone,
|
||||||
|
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||||
key, index, err := getKeyAndIndex(args)
|
key, index, err := getKeyAndIndex(args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -64,6 +93,7 @@ var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
|||||||
}
|
}
|
||||||
return getValueByKeyAtIndex(req.PostForm, key, index)
|
return getValueByKeyAtIndex(req.PostForm, key, index)
|
||||||
},
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func getValueByKeyAtIndex[Values http.Header | url.Values](values Values, key string, index int) (string, error) {
|
func getValueByKeyAtIndex[Values http.Header | url.Values](values Values, key string, index int) (string, error) {
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ func TestExpandVars(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "req_method",
|
name: "req_method",
|
||||||
input: "$req_method",
|
input: "$req_method",
|
||||||
want: "POST",
|
want: http.MethodPost,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "req_path",
|
name: "req_path",
|
||||||
@@ -484,7 +484,7 @@ func TestExpandVars(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
@@ -506,7 +506,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
|||||||
testResponseModifier.WriteHeader(http.StatusOK)
|
testResponseModifier.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest,
|
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||||
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
||||||
&out)
|
&out)
|
||||||
|
|
||||||
@@ -520,7 +520,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
|||||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest,
|
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||||
"Query: $arg(q), Page: $arg(page)",
|
"Query: $arg(q), Page: $arg(page)",
|
||||||
&out)
|
&out)
|
||||||
|
|
||||||
@@ -537,7 +537,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
|||||||
testResponseModifier.WriteHeader(http.StatusOK)
|
testResponseModifier.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest,
|
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||||
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
||||||
&out)
|
&out)
|
||||||
|
|
||||||
@@ -560,7 +560,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "https scheme",
|
name: "https scheme",
|
||||||
request: &http.Request{
|
request: &http.Request{
|
||||||
Method: "GET",
|
Method: http.MethodGet,
|
||||||
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
|
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
|
||||||
TLS: &tls.ConnectionState{}, // Simulate TLS connection
|
TLS: &tls.ConnectionState{}, // Simulate TLS connection
|
||||||
},
|
},
|
||||||
@@ -572,7 +572,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
_, err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.expected, out.String())
|
require.Equal(t, tt.expected, out.String())
|
||||||
})
|
})
|
||||||
@@ -598,7 +598,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
|
|||||||
for _, varExpr := range upstreamVars {
|
for _, varExpr := range upstreamVars {
|
||||||
t.Run(varExpr, func(t *testing.T) {
|
t.Run(varExpr, func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
||||||
// Should not error, may return empty string
|
// Should not error, may return empty string
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
@@ -614,16 +614,16 @@ func TestExpandVars_NoHostPort(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("req_host without port", func(t *testing.T) {
|
t.Run("req_host without port", func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "example.com", out.String())
|
require.Equal(t, "example.com", out.String())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("req_port without port", func(t *testing.T) {
|
t.Run("req_port without port", func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Empty(t, out.String())
|
require.Equal(t, "", out.String())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -636,16 +636,16 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("remote_host without port", func(t *testing.T) {
|
t.Run("remote_host without port", func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Empty(t, out.String())
|
require.Equal(t, "", out.String())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("remote_port without port", func(t *testing.T) {
|
t.Run("remote_port without port", func(t *testing.T) {
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Empty(t, out.String())
|
require.Equal(t, "", out.String())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -654,7 +654,7 @@ func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
|||||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||||
|
|
||||||
var out strings.Builder
|
var out strings.Builder
|
||||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "GET /test", out.String())
|
require.Equal(t, "GET /test", out.String())
|
||||||
}
|
}
|
||||||
@@ -699,7 +699,7 @@ func TestValidateVars(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
err := ValidateVars(tt.input)
|
_, err := ValidateVars(tt.input)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user