mirror of
https://github.com/yusing/godoxy.git
synced 2026-02-24 17:54:57 +01:00
Compare commits
9 Commits
feat/rules
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6576b7640a | ||
|
|
d4e552754e | ||
|
|
9ca2983a52 | ||
|
|
ed2ca236b0 | ||
|
|
0eba045104 | ||
|
|
77f2779114 | ||
|
|
743eb03b27 | ||
|
|
d2d686b4d1 | ||
|
|
169358659a |
4
Makefile
4
Makefile
@@ -6,8 +6,8 @@ export GOOS = linux
|
||||
|
||||
REPO_URL ?= https://github.com/yusing/godoxy
|
||||
|
||||
WEBUI_DIR ?= ../godoxy-webui
|
||||
DOCS_DIR ?= wiki
|
||||
WEBUI_DIR ?= $(shell pwd)/../godoxy-webui
|
||||
DOCS_DIR ?= ${WEBUI_DIR}/wiki
|
||||
|
||||
ifneq ($(BRANCH), compat)
|
||||
GO_TAGS = sonic
|
||||
|
||||
@@ -46,7 +46,7 @@ require (
|
||||
github.com/docker/cli v29.2.1+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.9.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
@@ -91,8 +91,8 @@ require (
|
||||
github.com/valyala/fasthttp v1.69.0 // indirect
|
||||
github.com/yusing/ds v0.4.1 // indirect
|
||||
github.com/yusing/gointernals v0.2.0 // indirect
|
||||
github.com/yusing/goutils/http/reverseproxy v0.0.0-20260218062549-0b0fa3a059ec // indirect
|
||||
github.com/yusing/goutils/http/websocket v0.0.0-20260218062549-0b0fa3a059ec // indirect
|
||||
github.com/yusing/goutils/http/reverseproxy v0.0.0-20260223150038-3be815cb6e3b // indirect
|
||||
github.com/yusing/goutils/http/websocket v0.0.0-20260223150038-3be815cb6e3b // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
|
||||
|
||||
@@ -43,8 +43,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
|
||||
github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
|
||||
16
go.mod
16
go.mod
@@ -58,13 +58,13 @@ require (
|
||||
github.com/stretchr/testify v1.11.1 // testing framework
|
||||
github.com/valyala/fasthttp v1.69.0 // fast http for health check
|
||||
github.com/yusing/ds v0.4.1 // data structures and algorithms
|
||||
github.com/yusing/godoxy/agent v0.0.0-20260218101334-add7884a365e
|
||||
github.com/yusing/godoxy/internal/dnsproviders v0.0.0-20260218101334-add7884a365e
|
||||
github.com/yusing/godoxy/agent v0.0.0-20260224071728-0eba04510480
|
||||
github.com/yusing/godoxy/internal/dnsproviders v0.0.0-20260224071728-0eba04510480
|
||||
github.com/yusing/gointernals v0.2.0
|
||||
github.com/yusing/goutils v0.7.0
|
||||
github.com/yusing/goutils/http/reverseproxy v0.0.0-20260218062549-0b0fa3a059ec
|
||||
github.com/yusing/goutils/http/websocket v0.0.0-20260218062549-0b0fa3a059ec
|
||||
github.com/yusing/goutils/server v0.0.0-20260218062549-0b0fa3a059ec
|
||||
github.com/yusing/goutils/http/reverseproxy v0.0.0-20260223150038-3be815cb6e3b
|
||||
github.com/yusing/goutils/http/websocket v0.0.0-20260223150038-3be815cb6e3b
|
||||
github.com/yusing/goutils/server v0.0.0-20260223150038-3be815cb6e3b
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -88,7 +88,7 @@ require (
|
||||
github.com/djherbis/times v1.6.0 // indirect
|
||||
github.com/docker/go-connections v0.6.0
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/ebitengine/purego v0.9.1 // indirect
|
||||
github.com/ebitengine/purego v0.10.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
@@ -142,8 +142,8 @@ require (
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
google.golang.org/api v0.267.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d // indirect
|
||||
google.golang.org/api v0.268.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260223185530-2f722ef697dc // indirect
|
||||
google.golang.org/grpc v1.79.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/ini.v1 v1.67.1 // indirect
|
||||
|
||||
12
go.sum
12
go.sum
@@ -82,8 +82,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM
|
||||
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A=
|
||||
github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/elliotwutingfeng/asciiset v0.0.0-20230602022725-51bbb787efab h1:h1UgjJdAAhj+uPL68n7XASS6bU+07ZX1WJvVS2eyoeY=
|
||||
github.com/elliotwutingfeng/asciiset v0.0.0-20230602022725-51bbb787efab/go.mod h1:GLo/8fDswSAniFG+BFIaiSPcK610jyzgEhWYPQwuQdw=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
@@ -447,14 +447,14 @@ golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.267.0 h1:w+vfWPMPYeRs8qH1aYYsFX68jMls5acWl/jocfLomwE=
|
||||
google.golang.org/api v0.267.0/go.mod h1:Jzc0+ZfLnyvXma3UtaTl023TdhZu6OMBP9tJ+0EmFD0=
|
||||
google.golang.org/api v0.268.0 h1:hgA3aS4lt9rpF5RCCkX0Q2l7DvHgvlb53y4T4u6iKkA=
|
||||
google.golang.org/api v0.268.0/go.mod h1:HXMyMH496wz+dAJwD/GkAPLd3ZL33Kh0zEG32eNvy9w=
|
||||
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM=
|
||||
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d h1:t/LOSXPJ9R0B6fnZNyALBRfZBH0Uy0gT+uR+SJ6syqQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260223185530-2f722ef697dc h1:51Wupg8spF+5FC6D+iMKbOddFjMckETnNnEiZ+HX37s=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260223185530-2f722ef697dc/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY=
|
||||
google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
|
||||
2
goutils
2
goutils
Submodule goutils updated: 482b5bca9f...3be815cb6e
@@ -5093,11 +5093,6 @@
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"isResponseRule": {
|
||||
"type": "boolean",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
@@ -5130,10 +5125,7 @@
|
||||
"$ref": "#/definitions/MockResponse"
|
||||
},
|
||||
"rules": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/routeApi.RawRule"
|
||||
},
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
}
|
||||
@@ -6931,28 +6923,6 @@
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"routeApi.RawRule": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"do": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"on": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
}
|
||||
},
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"routeApi.RoutesByProvider": {
|
||||
"type": "object",
|
||||
"additionalProperties": {
|
||||
|
||||
@@ -891,8 +891,6 @@ definitions:
|
||||
properties:
|
||||
do:
|
||||
type: string
|
||||
isResponseRule:
|
||||
type: boolean
|
||||
name:
|
||||
type: string
|
||||
"on":
|
||||
@@ -907,9 +905,7 @@ definitions:
|
||||
mockResponse:
|
||||
$ref: '#/definitions/MockResponse'
|
||||
rules:
|
||||
items:
|
||||
$ref: '#/definitions/routeApi.RawRule'
|
||||
type: array
|
||||
type: string
|
||||
required:
|
||||
- rules
|
||||
type: object
|
||||
@@ -1837,12 +1833,12 @@ definitions:
|
||||
type: string
|
||||
kernel_version:
|
||||
type: string
|
||||
load_avg_5m:
|
||||
type: string
|
||||
load_avg_15m:
|
||||
type: string
|
||||
load_avg_1m:
|
||||
type: string
|
||||
load_avg_5m:
|
||||
type: string
|
||||
mem_pct:
|
||||
type: string
|
||||
mem_total:
|
||||
@@ -1860,15 +1856,6 @@ definitions:
|
||||
uptime:
|
||||
type: string
|
||||
type: object
|
||||
routeApi.RawRule:
|
||||
properties:
|
||||
do:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
"on":
|
||||
type: string
|
||||
type: object
|
||||
routeApi.RoutesByProvider:
|
||||
additionalProperties:
|
||||
items:
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
apitypes "github.com/yusing/goutils/apitypes"
|
||||
@@ -23,7 +24,7 @@ type RawRule struct {
|
||||
}
|
||||
|
||||
type PlaygroundRequest struct {
|
||||
Rules []RawRule `json:"rules" binding:"required"`
|
||||
Rules string `json:"rules" binding:"required"`
|
||||
MockRequest MockRequest `json:"mockRequest"`
|
||||
MockResponse MockResponse `json:"mockResponse"`
|
||||
} // @name PlaygroundRequest
|
||||
@@ -64,7 +65,6 @@ type ParsedRule struct {
|
||||
On string `json:"on"`
|
||||
Do string `json:"do"`
|
||||
ValidationError error `json:"validationError,omitempty"` // we need the structured error, not the plain string
|
||||
IsResponseRule bool `json:"isResponseRule"`
|
||||
} // @name ParsedRule
|
||||
|
||||
type FinalRequest struct {
|
||||
@@ -256,7 +256,35 @@ func handlerWithRecover(w http.ResponseWriter, r *http.Request, h http.HandlerFu
|
||||
h(w, r)
|
||||
}
|
||||
|
||||
func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
|
||||
func parseRules(config string) ([]ParsedRule, rules.Rules, error) {
|
||||
config = strings.TrimSpace(config)
|
||||
if config == "" {
|
||||
return []ParsedRule{}, nil, nil
|
||||
}
|
||||
|
||||
var rawRules []RawRule
|
||||
if err := yaml.Unmarshal([]byte(config), &rawRules); err == nil && len(rawRules) > 0 {
|
||||
return parseRawRules(rawRules)
|
||||
}
|
||||
|
||||
var rulesList rules.Rules
|
||||
if err := rulesList.Parse(config); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
parsedRules := make([]ParsedRule, 0, len(rulesList))
|
||||
for _, rule := range rulesList {
|
||||
parsedRules = append(parsedRules, ParsedRule{
|
||||
Name: rule.Name,
|
||||
On: rule.On.String(),
|
||||
Do: rule.Do.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return parsedRules, rulesList, nil
|
||||
}
|
||||
|
||||
func parseRawRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
|
||||
parsedRules := make([]ParsedRule, 0, len(rawRules))
|
||||
rulesList := make(rules.Rules, 0, len(rawRules))
|
||||
|
||||
@@ -298,7 +326,6 @@ func parseRules(rawRules []RawRule) ([]ParsedRule, rules.Rules, error) {
|
||||
On: onStr,
|
||||
Do: doStr,
|
||||
ValidationError: validationErr,
|
||||
IsResponseRule: rule.IsResponseRule(),
|
||||
})
|
||||
|
||||
// Only add valid rules to execution list
|
||||
|
||||
@@ -22,13 +22,10 @@ func TestPlayground(t *testing.T) {
|
||||
{
|
||||
name: "simple path matching rule",
|
||||
request: PlaygroundRequest{
|
||||
Rules: []RawRule{
|
||||
{
|
||||
Name: "test rule",
|
||||
On: "path /api",
|
||||
Do: "pass",
|
||||
},
|
||||
},
|
||||
Rules: `- name: test rule
|
||||
on: path /api
|
||||
do: pass
|
||||
`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "GET",
|
||||
Path: "/api",
|
||||
@@ -53,13 +50,10 @@ func TestPlayground(t *testing.T) {
|
||||
{
|
||||
name: "header matching rule",
|
||||
request: PlaygroundRequest{
|
||||
Rules: []RawRule{
|
||||
{
|
||||
Name: "check user agent",
|
||||
On: "header User-Agent Chrome",
|
||||
Do: "error 403 Forbidden",
|
||||
},
|
||||
},
|
||||
Rules: `- name: check user agent
|
||||
on: header User-Agent Chrome
|
||||
do: error 403 Forbidden
|
||||
`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "GET",
|
||||
Path: "/",
|
||||
@@ -90,13 +84,10 @@ func TestPlayground(t *testing.T) {
|
||||
{
|
||||
name: "invalid rule syntax",
|
||||
request: PlaygroundRequest{
|
||||
Rules: []RawRule{
|
||||
{
|
||||
Name: "bad rule",
|
||||
On: "invalid_checker something",
|
||||
Do: "pass",
|
||||
},
|
||||
},
|
||||
Rules: `- name: bad rule
|
||||
on: invalid_checker something
|
||||
do: pass
|
||||
`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "GET",
|
||||
Path: "/",
|
||||
@@ -115,13 +106,10 @@ func TestPlayground(t *testing.T) {
|
||||
{
|
||||
name: "rewrite path rule",
|
||||
request: PlaygroundRequest{
|
||||
Rules: []RawRule{
|
||||
{
|
||||
Name: "rewrite rule",
|
||||
On: "path glob(/api/*)",
|
||||
Do: "rewrite /api/ /v1/",
|
||||
},
|
||||
},
|
||||
Rules: `- name: rewrite rule
|
||||
on: path glob(/api/*)
|
||||
do: rewrite /api/ /v1/
|
||||
`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "GET",
|
||||
Path: "/api/users",
|
||||
@@ -148,13 +136,10 @@ func TestPlayground(t *testing.T) {
|
||||
{
|
||||
name: "method matching rule",
|
||||
request: PlaygroundRequest{
|
||||
Rules: []RawRule{
|
||||
{
|
||||
Name: "block POST",
|
||||
On: "method POST",
|
||||
Do: `error "405" "Method Not Allowed"`,
|
||||
},
|
||||
},
|
||||
Rules: `- name: block POST
|
||||
on: method POST
|
||||
do: error "405" "Method Not Allowed"
|
||||
`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "POST",
|
||||
Path: "/api",
|
||||
@@ -173,6 +158,63 @@ func TestPlayground(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block syntax default rule",
|
||||
request: PlaygroundRequest{
|
||||
Rules: `default {
|
||||
pass
|
||||
}`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "GET",
|
||||
Path: "/",
|
||||
},
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
checkResponse: func(t *testing.T, resp PlaygroundResponse) {
|
||||
if len(resp.ParsedRules) != 1 {
|
||||
t.Errorf("expected 1 parsed rule, got %d", len(resp.ParsedRules))
|
||||
}
|
||||
if resp.ParsedRules[0].ValidationError != nil {
|
||||
t.Errorf("expected rule to be valid, got error: %v", resp.ParsedRules[0].ValidationError)
|
||||
}
|
||||
if !resp.UpstreamCalled {
|
||||
t.Error("expected upstream to be called")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "block syntax conditional rule",
|
||||
request: PlaygroundRequest{
|
||||
Rules: `header User-Agent Chrome {
|
||||
error 403 Forbidden
|
||||
}`,
|
||||
MockRequest: MockRequest{
|
||||
Method: "GET",
|
||||
Path: "/",
|
||||
Headers: map[string][]string{
|
||||
"User-Agent": {"Chrome"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
checkResponse: func(t *testing.T, resp PlaygroundResponse) {
|
||||
if len(resp.ParsedRules) != 1 {
|
||||
t.Errorf("expected 1 parsed rule, got %d", len(resp.ParsedRules))
|
||||
}
|
||||
if resp.ParsedRules[0].ValidationError != nil {
|
||||
t.Errorf("expected rule to be valid, got error: %v", resp.ParsedRules[0].ValidationError)
|
||||
}
|
||||
if len(resp.MatchedRules) != 1 {
|
||||
t.Errorf("expected 1 matched rule, got %d", len(resp.MatchedRules))
|
||||
}
|
||||
if resp.FinalResponse.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status 403, got %d", resp.FinalResponse.StatusCode)
|
||||
}
|
||||
if resp.UpstreamCalled {
|
||||
t.Error("expected upstream not to be called")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -98,8 +98,8 @@ require (
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
google.golang.org/api v0.267.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d // indirect
|
||||
google.golang.org/api v0.268.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260223185530-2f722ef697dc // indirect
|
||||
google.golang.org/grpc v1.79.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/ini.v1 v1.67.1 // indirect
|
||||
|
||||
@@ -249,14 +249,14 @@ golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.267.0 h1:w+vfWPMPYeRs8qH1aYYsFX68jMls5acWl/jocfLomwE=
|
||||
google.golang.org/api v0.267.0/go.mod h1:Jzc0+ZfLnyvXma3UtaTl023TdhZu6OMBP9tJ+0EmFD0=
|
||||
google.golang.org/api v0.268.0 h1:hgA3aS4lt9rpF5RCCkX0Q2l7DvHgvlb53y4T4u6iKkA=
|
||||
google.golang.org/api v0.268.0/go.mod h1:HXMyMH496wz+dAJwD/GkAPLd3ZL33Kh0zEG32eNvy9w=
|
||||
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM=
|
||||
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d h1:t/LOSXPJ9R0B6fnZNyALBRfZBH0Uy0gT+uR+SJ6syqQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260223185530-2f722ef697dc h1:51Wupg8spF+5FC6D+iMKbOddFjMckETnNnEiZ+HX37s=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260223185530-2f722ef697dc/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY=
|
||||
google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
|
||||
@@ -13,6 +13,8 @@ This package implements a flexible HTTP middleware system for GoDoxy. Middleware
|
||||
- **Bypass Rules**: Skip middleware based on request properties
|
||||
- **Dynamic Loading**: Load middleware definitions from files at runtime
|
||||
|
||||
Response body rewriting is only applied to unencoded, text-like content types (for example `text/*`, JSON, YAML, XML). Response status and headers can always be modified.
|
||||
|
||||
## Architecture
|
||||
|
||||
```mermaid
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -195,21 +196,15 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
||||
}
|
||||
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
lrm := httputils.NewLazyResponseModifier(w, needsBuffering)
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
_, err := lrm.FlushRelease()
|
||||
_, err := rm.FlushRelease()
|
||||
if err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to flush response")
|
||||
}
|
||||
}()
|
||||
next(lrm, r)
|
||||
next(rm, r)
|
||||
|
||||
// Skip modification if response wasn't buffered (non-HTML content)
|
||||
if !lrm.IsBuffered() {
|
||||
return
|
||||
}
|
||||
|
||||
rm := lrm.ResponseModifier()
|
||||
currentBody := rm.BodyReader()
|
||||
currentResp := &http.Response{
|
||||
StatusCode: rm.StatusCode(),
|
||||
@@ -218,20 +213,31 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
||||
Body: currentBody,
|
||||
Request: r,
|
||||
}
|
||||
if err := exec.modifyResponse(currentResp); err != nil {
|
||||
allowBodyModification := canModifyResponseBody(currentResp)
|
||||
respToModify := currentResp
|
||||
if !allowBodyModification {
|
||||
shadow := *currentResp
|
||||
shadow.Body = eofReader{}
|
||||
respToModify = &shadow
|
||||
}
|
||||
if err := exec.modifyResponse(respToModify); err != nil {
|
||||
log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response")
|
||||
}
|
||||
|
||||
// override the response status code
|
||||
rm.WriteHeader(currentResp.StatusCode)
|
||||
rm.WriteHeader(respToModify.StatusCode)
|
||||
|
||||
// overriding the response header
|
||||
maps.Copy(rm.Header(), currentResp.Header)
|
||||
maps.Copy(rm.Header(), respToModify.Header)
|
||||
|
||||
// override the content length and body if changed
|
||||
if currentResp.Body != currentBody {
|
||||
if err := rm.SetBody(currentResp.Body); err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to set response body")
|
||||
if respToModify.Body != currentBody {
|
||||
if allowBodyModification {
|
||||
if err := rm.SetBody(respToModify.Body); err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to set response body")
|
||||
}
|
||||
} else {
|
||||
respToModify.Body.Close()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -239,10 +245,55 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
||||
}
|
||||
}
|
||||
|
||||
// needsBuffering determines if a response should be buffered for modification.
|
||||
// Only HTML responses need buffering; streaming content (video, audio, etc.) should pass through.
|
||||
func needsBuffering(header http.Header) bool {
|
||||
return httputils.GetContentType(header).IsHTML()
|
||||
func canModifyResponseBody(resp *http.Response) bool {
|
||||
if hasNonIdentityEncoding(resp.TransferEncoding) {
|
||||
return false
|
||||
}
|
||||
if hasNonIdentityEncoding(resp.Header.Values("Transfer-Encoding")) {
|
||||
return false
|
||||
}
|
||||
if hasNonIdentityEncoding(resp.Header.Values("Content-Encoding")) {
|
||||
return false
|
||||
}
|
||||
return isTextLikeMediaType(string(httputils.GetContentType(resp.Header)))
|
||||
}
|
||||
|
||||
func hasNonIdentityEncoding(values []string) bool {
|
||||
for _, value := range values {
|
||||
for _, token := range strings.Split(value, ",") {
|
||||
if strings.TrimSpace(token) == "" || strings.EqualFold(strings.TrimSpace(token), "identity") {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isTextLikeMediaType(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return false
|
||||
}
|
||||
contentType = strings.ToLower(contentType)
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
return true
|
||||
}
|
||||
if contentType == "application/json" || strings.HasSuffix(contentType, "+json") {
|
||||
return true
|
||||
}
|
||||
if contentType == "application/xml" || strings.HasSuffix(contentType, "+xml") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(contentType, "yaml") || strings.Contains(contentType, "toml") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(contentType, "javascript") || strings.Contains(contentType, "ecmascript") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(contentType, "csv") {
|
||||
return true
|
||||
}
|
||||
return contentType == "application/x-www-form-urlencoded"
|
||||
}
|
||||
|
||||
func (m *Middleware) LogWarn(req *http.Request) *zerolog.Event {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -46,10 +47,23 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||
if len(m.modResps) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowBodyModification := canModifyResponseBody(resp)
|
||||
for i, mr := range m.modResps {
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
respToModify := resp
|
||||
if !allowBodyModification {
|
||||
shadow := *resp
|
||||
shadow.Body = eofReader{}
|
||||
respToModify = &shadow
|
||||
}
|
||||
if err := mr.modifyResponse(respToModify); err != nil {
|
||||
return gperr.PrependSubject(err, strconv.Itoa(i))
|
||||
}
|
||||
if !allowBodyModification {
|
||||
resp.StatusCode = respToModify.StatusCode
|
||||
if respToModify.Header != nil {
|
||||
maps.Copy(resp.Header, respToModify.Header)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -14,12 +15,27 @@ type testPriority struct {
|
||||
}
|
||||
|
||||
var test = NewMiddleware[testPriority]()
|
||||
var responseRewrite = NewMiddleware[testResponseRewrite]()
|
||||
|
||||
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
|
||||
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
|
||||
return true
|
||||
}
|
||||
|
||||
type testResponseRewrite struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
HeaderKey string `json:"header_key"`
|
||||
HeaderVal string `json:"header_val"`
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
func (t testResponseRewrite) modifyResponse(resp *http.Response) error {
|
||||
resp.StatusCode = t.StatusCode
|
||||
resp.Header.Set(t.HeaderKey, t.HeaderVal)
|
||||
resp.Body = io.NopCloser(strings.NewReader(t.Body))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMiddlewarePriority(t *testing.T) {
|
||||
priorities := []int{4, 7, 9, 0}
|
||||
chain := make([]*Middleware, len(priorities))
|
||||
@@ -35,3 +51,85 @@ func TestMiddlewarePriority(t *testing.T) {
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
|
||||
}
|
||||
|
||||
func TestMiddlewareResponseRewriteGate(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"status_code": 418,
|
||||
"header_key": "X-Rewrite",
|
||||
"header_val": "1",
|
||||
"body": "rewritten-body",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respHeaders http.Header
|
||||
respBody []byte
|
||||
expectBody string
|
||||
}{
|
||||
{
|
||||
name: "allow_body_rewrite_for_html",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"text/html; charset=utf-8"},
|
||||
},
|
||||
respBody: []byte("<html><body>original</body></html>"),
|
||||
expectBody: "rewritten-body",
|
||||
},
|
||||
{
|
||||
name: "allow_body_rewrite_for_json",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
respBody: []byte(`{"message":"original"}`),
|
||||
expectBody: "rewritten-body",
|
||||
},
|
||||
{
|
||||
name: "allow_body_rewrite_for_yaml",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/yaml"},
|
||||
},
|
||||
respBody: []byte("k: v"),
|
||||
expectBody: "rewritten-body",
|
||||
},
|
||||
{
|
||||
name: "block_body_rewrite_for_binary_content",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
},
|
||||
respBody: []byte("binary"),
|
||||
expectBody: "binary",
|
||||
},
|
||||
{
|
||||
name: "block_body_rewrite_for_transfer_encoded_html",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"text/html"},
|
||||
"Transfer-Encoding": []string{"chunked"},
|
||||
},
|
||||
respBody: []byte("<html><body>original</body></html>"),
|
||||
expectBody: "<html><body>original</body></html>",
|
||||
},
|
||||
{
|
||||
name: "block_body_rewrite_for_content_encoded_html",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"text/html"},
|
||||
"Content-Encoding": []string{"gzip"},
|
||||
},
|
||||
respBody: []byte("<html><body>original</body></html>"),
|
||||
expectBody: "<html><body>original</body></html>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := newMiddlewareTest(responseRewrite, &testArgs{
|
||||
middlewareOpt: opts,
|
||||
respHeaders: tc.respHeaders,
|
||||
respBody: tc.respBody,
|
||||
respStatus: http.StatusOK,
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, result.ResponseStatus, 418)
|
||||
expect.Equal(t, result.ResponseHeaders.Get("X-Rewrite"), "1")
|
||||
expect.Equal(t, string(result.Data), tc.expectBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
@@ -54,7 +55,7 @@ func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, er
|
||||
resp = &http.Response{
|
||||
Status: http.StatusText(rt.args.respStatus),
|
||||
StatusCode: rt.args.respStatus,
|
||||
Header: testHeaders,
|
||||
Header: maps.Clone(testHeaders),
|
||||
Body: io.NopCloser(bytes.NewReader(rt.args.respBody)),
|
||||
ContentLength: int64(len(rt.args.respBody)),
|
||||
Request: req,
|
||||
@@ -65,9 +66,27 @@ func (rt *requestRecorder) RoundTrip(req *http.Request) (resp *http.Response, er
|
||||
return nil, err
|
||||
}
|
||||
maps.Copy(resp.Header, rt.args.respHeaders)
|
||||
if transferEncoding := resp.Header.Values("Transfer-Encoding"); len(transferEncoding) > 0 {
|
||||
resp.TransferEncoding = parseHeaderTokens(transferEncoding)
|
||||
resp.ContentLength = -1
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseHeaderTokens(values []string) []string {
|
||||
var tokens []string
|
||||
for _, value := range values {
|
||||
for token := range strings.SplitSeq(value, ",") {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
RequestHeaders http.Header
|
||||
ResponseHeaders http.Header
|
||||
|
||||
@@ -36,15 +36,15 @@ type Rule struct {
|
||||
}
|
||||
|
||||
type RuleOn struct {
|
||||
raw string
|
||||
checker Checker
|
||||
isResponseChecker bool
|
||||
raw string
|
||||
checker Checker
|
||||
phase PhaseFlag
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
raw string
|
||||
exec CommandHandler
|
||||
isResponseHandler bool
|
||||
raw string
|
||||
pre Commands
|
||||
post Commands
|
||||
}
|
||||
```
|
||||
|
||||
@@ -59,6 +59,9 @@ func ParseRules(config string) (Rules, error)
|
||||
|
||||
// ValidateRules validates rule syntax
|
||||
func ValidateRules(config string) error
|
||||
|
||||
// Validate validates rule semantics (e.g., prevents multiple default rules)
|
||||
func (rules Rules) Validate() gperr.Error
|
||||
```
|
||||
|
||||
## Architecture
|
||||
@@ -122,16 +125,52 @@ sequenceDiagram
|
||||
Pre->>Pre: Execute handler
|
||||
alt Terminating action
|
||||
Pre-->>Req: Response
|
||||
Return-->>Req: Return immediately
|
||||
Note right of Pre: Stop remaining pre commands
|
||||
end
|
||||
end
|
||||
Req->>Proxy: Forward request
|
||||
Proxy-->>Req: Response
|
||||
Req->>Post: Check post-rules
|
||||
Post->>Post: Execute handlers
|
||||
Post-->>Req: Modified response
|
||||
opt No pre termination
|
||||
Req->>Proxy: Forward request
|
||||
Proxy-->>Req: Response
|
||||
end
|
||||
Req->>Post: Run scheduled post commands
|
||||
Req->>Post: Evaluate response matchers
|
||||
Post->>Post: Execute matched post handlers
|
||||
Post-->>Req: Final response
|
||||
```
|
||||
|
||||
### Execution Model (Authoritative)
|
||||
|
||||
Rules run in two phases:
|
||||
|
||||
1. **Pre phase**
|
||||
- Evaluate only request-based matchers (`path`, `method`, `header`, `remote`, etc.) in declaration order.
|
||||
- Execute matched rule `do` pre-commands in order.
|
||||
- If a default rule exists (`name: default` or `on: default`), it is a fallback and runs only when no non-default pre rule matches.
|
||||
- If a terminating action runs, stop:
|
||||
- remaining commands in that rule
|
||||
- all later pre-phase commands.
|
||||
- Exception: rules that only contain post commands (no pre commands) are still scheduled for post phase.
|
||||
|
||||
2. **Upstream phase**
|
||||
- Upstream is called only if pre phase did not terminate.
|
||||
|
||||
3. **Post phase**
|
||||
- Run post-commands for rules whose pre phase executed, except rules that terminated in pre.
|
||||
- Then evaluate response-based matchers (`status`, `resp_header`) and execute their `do` commands.
|
||||
- Response-based rules run even when the response was produced in pre phase.
|
||||
|
||||
**Important:** termination is explicit by command semantics, not inferred from status-code mutation.
|
||||
|
||||
### Phase Flags
|
||||
|
||||
Rule and command parsing tracks phase requirements via `PhaseFlag`:
|
||||
|
||||
- `PhasePre`
|
||||
- `PhasePost`
|
||||
- `PhasePre | PhasePost` (combined)
|
||||
|
||||
Combined flags are expected for nested/compound commands and variable templates that may need both request and response context.
|
||||
|
||||
### Condition Matchers
|
||||
|
||||
| Matcher | Type | Description |
|
||||
@@ -166,22 +205,22 @@ path regex("/api/v[0-9]+/.*") // regex pattern
|
||||
|
||||
**Terminating Actions** (stop processing):
|
||||
|
||||
| Command | Description |
|
||||
| ------------------------ | ---------------------- |
|
||||
| `error <code> <message>` | Return HTTP error |
|
||||
| `redirect <url>` | Redirect to URL |
|
||||
| `serve <path>` | Serve local files |
|
||||
| `route <name>` | Route to another route |
|
||||
| `proxy <url>` | Proxy to upstream |
|
||||
| Command | Description |
|
||||
| ------------------------------ | ------------------------------------- |
|
||||
| `upstream` / `bypass` / `pass` | Call upstream and terminate pre-phase |
|
||||
| `error <code> <message>` | Return HTTP error |
|
||||
| `redirect <url>` | Redirect to URL |
|
||||
| `serve <path>` | Serve local files |
|
||||
| `route <name>` | Route to another route |
|
||||
| `proxy <url>` | Proxy to upstream |
|
||||
| `require_basic_auth <realm>` | Return 401 challenge |
|
||||
|
||||
**Non-Terminating Actions** (modify and continue):
|
||||
|
||||
| Command | Description |
|
||||
| ------------------------------ | ---------------------- |
|
||||
| `pass` / `bypass` | Pass through unchanged |
|
||||
| `rewrite <from> <to>` | Rewrite request path |
|
||||
| `require_auth` | Require authentication |
|
||||
| `require_basic_auth <realm>` | Basic auth challenge |
|
||||
| `set <target> <field> <value>` | Set header/variable |
|
||||
| `add <target> <field> <value>` | Add header/variable |
|
||||
| `remove <target> <field>` | Remove header/variable |
|
||||
@@ -195,54 +234,226 @@ path regex("/api/v[0-9]+/.*") // regex pattern
|
||||
|
||||
## Configuration Surface
|
||||
|
||||
### Rule Configuration (YAML)
|
||||
### Rule Configuration (Block Syntax)
|
||||
|
||||
```yaml
|
||||
rules:
|
||||
- name: rule name
|
||||
on: |
|
||||
condition1
|
||||
& condition2
|
||||
do: |
|
||||
action1
|
||||
action2
|
||||
```bash
|
||||
default {
|
||||
action1
|
||||
action2
|
||||
}
|
||||
|
||||
condition1 &
|
||||
condition2 {
|
||||
action1
|
||||
action2
|
||||
}
|
||||
```
|
||||
|
||||
This is the primary syntax for rules and avoids YAML wrappers.
|
||||
It keeps the **inner** `on` and `do` DSLs exactly the same (same matchers, same commands, same optional quotes), but wraps each rule in a `{ ... }` block.
|
||||
|
||||
#### Key ideas
|
||||
|
||||
- A rule is:
|
||||
- `default { <do...> }`,
|
||||
- `{ <do...> }`, or
|
||||
- `<on-expr> { <do...> }`
|
||||
- Comments are supported:
|
||||
- line comment: `// ...` (to end of line)
|
||||
- line comment: `# ...` (to end of line, for YAML familiarity)
|
||||
- block comment: `/* ... */` (may span multiple lines)
|
||||
- Comments are ignored **only when outside quotes** (`"`, `'` or backticks).
|
||||
- Environment variable syntax: `${NAME}` is supported by the inner DSL parser in [`parse()`](internal/route/rules/parser.go:34).
|
||||
Block-syntax rule:
|
||||
- In `on` (rule header): `${...}` must be inside quotes/backticks.
|
||||
- In `do` (rule body): `${...}` may be unquoted; the outer parser must treat `${...}` as an opaque token so braces inside it are not structural.
|
||||
|
||||
#### Grammar sketch (EBNF-ish)
|
||||
|
||||
```text
|
||||
file := { ws | comment | rule }
|
||||
rule := default_rule | unconditional_rule | conditional_rule
|
||||
|
||||
default_rule := 'default' ws* block
|
||||
unconditional_rule := ws* block
|
||||
conditional_rule := on_expr ws* block
|
||||
|
||||
block := '{' do_body '}'
|
||||
|
||||
// on_expr and do_body are raw text regions.
|
||||
// The outer parser only needs to:
|
||||
// - find the top-level '{' to start a rule block
|
||||
// - find the matching top-level '}' to end it
|
||||
// while respecting quotes and comments.
|
||||
```
|
||||
|
||||
#### Elif/Else Chain Grammar
|
||||
|
||||
```text
|
||||
// Elif/Else chains can appear in do_body
|
||||
do_stmt := command_line | nested_block | elif_else_chain
|
||||
elif_else_chain := nested_block { elif_clause } [else_clause]
|
||||
elif_clause := 'elif' ws* on_expr ws* '{' do_body '}'
|
||||
else_clause := 'else' ws* '{' do_body '}'
|
||||
```
|
||||
|
||||
#### Nested blocks (inline conditionals inside `do`)
|
||||
|
||||
Inside a rule body (`do_body`), you can write **nested blocks**:
|
||||
|
||||
```text
|
||||
do_stmt := command_line | nested_block | elif_else_chain
|
||||
|
||||
nested_block := on_expr ws* '{' do_body '}'
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- A nested block is recognized when a line ends with an unquoted `{` (ignoring trailing whitespace).
|
||||
- `on_expr` uses the same syntax as rule `on` (supports `|`, `&`, quoting/backticks, matcher functions, etc.).
|
||||
- The nested block executes **in sequence**, at the point where it appears in the parent `do` list.
|
||||
- Nested blocks are evaluated in the same phase the parent rule runs (no special phase promotion).
|
||||
- Nested blocks can be chained with `elif`/`else` for conditional execution (see Elif/Else Chains section).
|
||||
|
||||
Example:
|
||||
|
||||
```go
|
||||
default {
|
||||
remove resp_header X-Secret
|
||||
add resp_header X-Custom-Header custom-value
|
||||
}
|
||||
|
||||
header X-Test-Header {
|
||||
set header X-Remote-Type public
|
||||
remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||
set header X-Remote-Type private
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Elif/Else Chains
|
||||
|
||||
You can chain multiple conditions using `elif` and provide a fallback with `else`.
|
||||
The `elif`/`else` keywords must appear on the same line as the preceding closing brace (`}`).
|
||||
|
||||
```go
|
||||
header X-Test-Header {
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
} elif method POST {
|
||||
set header X-Mode post
|
||||
} else {
|
||||
set header X-Mode other
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- `elif` and `else` must be on the same line as the preceding `}`.
|
||||
- Multiple `elif` branches are allowed; only one `else` is allowed.
|
||||
- The entire chain is evaluated in sequence; the first matching branch executes.
|
||||
- Elif/else chains can only be used within nested blocks.
|
||||
- Each `elif` clause must have its own condition expression and block.
|
||||
- The `else` clause is optional and provides a default action when no conditions match.
|
||||
|
||||
#### Examples
|
||||
|
||||
Basic default rule:
|
||||
|
||||
```go
|
||||
default {
|
||||
bypass
|
||||
}
|
||||
```
|
||||
|
||||
WebSocket upgrade routing:
|
||||
|
||||
```bash
|
||||
# WebSocket requests
|
||||
header Connection Upgrade &
|
||||
header Upgrade websocket {
|
||||
route ws-api
|
||||
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
|
||||
}
|
||||
```
|
||||
|
||||
Block comments:
|
||||
|
||||
```go
|
||||
/* protect admin area */
|
||||
path glob("/admin/*") {
|
||||
require_auth
|
||||
}
|
||||
```
|
||||
|
||||
Always log the request
|
||||
|
||||
```bash
|
||||
{
|
||||
log info /dev/stdout "Request $req_method $req_path"
|
||||
}
|
||||
```
|
||||
|
||||
#### Notes and constraints
|
||||
|
||||
- The block syntax uses `{` and `}` as structure delimiters at **top-level** (outside quotes/comments).
|
||||
- Braces inside quoted strings (including backticks) are not structural.
|
||||
- `${...}` handling:
|
||||
- `on`: must be quoted/backticked
|
||||
- `do`: may be unquoted
|
||||
Preferred style: always write env vars as `${NAME}` rather than a bare `$NAME`.
|
||||
- If you need literal `{` or `}` outside quotes/backticks (for example unquoted templates like `{{ ... }}`), wrap that argument in quotes/backticks so the outer parser does not treat it as structure.
|
||||
- Rule naming remains minimal: if no explicit name is provided by the syntax, it will behave like the current YAML behavior (empty name becomes `rule[index]` in [`Rules.BuildHandler()`](internal/route/rules/rules.go:75)).
|
||||
- YAML remains supported as a fallback for backward compatibility.
|
||||
|
||||
### Condition Syntax
|
||||
|
||||
```yaml
|
||||
```bash
|
||||
# Simple condition
|
||||
on: path /api/users
|
||||
path /api/users
|
||||
|
||||
# Multiple conditions (AND)
|
||||
on: |
|
||||
header Authorization Bearer
|
||||
& path /api/admin/*
|
||||
header Authorization Bearer & path glob("/api/admin/*")
|
||||
|
||||
# Negation
|
||||
on: !path /public/*
|
||||
!path glob("/public/*")
|
||||
|
||||
# Negation on matcher
|
||||
path !glob("/public/*")
|
||||
|
||||
# OR within a line
|
||||
on: method GET | method POST
|
||||
method GET | method POST
|
||||
```
|
||||
|
||||
### Variable Substitution
|
||||
|
||||
```go
|
||||
// Static variables
|
||||
$req_method // Request method
|
||||
$req_host // Request host
|
||||
$req_path // Request path
|
||||
$status_code // Response status
|
||||
$remote_host // Client IP
|
||||
```bash
|
||||
# Static variables
|
||||
$req_method # Request method
|
||||
$req_host # Request host
|
||||
$req_path # Request path
|
||||
$status_code # Response status
|
||||
$remote_host # Client IP
|
||||
|
||||
// Dynamic variables
|
||||
$header(Name) // Request header
|
||||
$header(Name, index) // Header at index
|
||||
$arg(Name) // Query argument
|
||||
$form(Name) // Form field
|
||||
# Dynamic variables
|
||||
$header(Name) # Request header
|
||||
$header(Name, index) # Header at index
|
||||
$resp_header(Name) # Response header
|
||||
$arg(Name) # Query argument
|
||||
$form(Name) # Form field
|
||||
$postform(Name) # POST form field
|
||||
$cookie(Name) # Cookie value
|
||||
|
||||
// Environment variables
|
||||
# Function composition: pass result of one function to another
|
||||
$redacted($header(Authorization)) # Redact the Authorization header value
|
||||
$redacted($arg(token)) # Redact a query parameter value
|
||||
$redacted($cookie(session)) # Redact a cookie value
|
||||
|
||||
# $redacted: masks a value, showing only first 2 and last 2 characters
|
||||
$redacted(value) # Redact a plain string
|
||||
|
||||
# Environment variables
|
||||
${ENV_VAR}
|
||||
```
|
||||
|
||||
@@ -277,84 +488,104 @@ Log context includes: `rule`, `alias`, `match_result`
|
||||
|
||||
## Failure Modes and Recovery
|
||||
|
||||
| Failure | Behavior | Recovery |
|
||||
| ------------------- | ------------------------- | ---------------------------------- |
|
||||
| Invalid rule syntax | Route validation fails | Fix YAML syntax |
|
||||
| Missing variables | Variable renders as empty | Check variable sources |
|
||||
| Rule timeout | Request times out | Increase timeout or simplify rules |
|
||||
| Auth failure | Returns 401/403 | Fix credentials |
|
||||
| Failure | Behavior | Recovery |
|
||||
| ---------------------- | ------------------------- | ---------------------------------- |
|
||||
| Invalid rule syntax | Route validation fails | Fix block rule syntax |
|
||||
| Multiple default rules | Route validation fails | Remove duplicate default rules |
|
||||
| Missing variables | Variable renders as empty | Check variable sources |
|
||||
| Rule timeout | Request times out | Increase timeout or simplify rules |
|
||||
| Auth failure | Returns 401/403 | Fix credentials |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Pass-Through
|
||||
|
||||
```yaml
|
||||
- name: default
|
||||
do: pass
|
||||
```bash
|
||||
default {
|
||||
pass
|
||||
}
|
||||
```
|
||||
|
||||
### Path-Based Routing
|
||||
|
||||
```yaml
|
||||
- name: api proxy
|
||||
on: path /api/*
|
||||
do: proxy http://api-backend:8080
|
||||
```bash
|
||||
path glob("/api/*") {
|
||||
proxy http://api-backend:8080
|
||||
}
|
||||
|
||||
- name: static files
|
||||
on: path /static/*
|
||||
do: serve /var/www/static
|
||||
path glob("/static/*") {
|
||||
serve /var/www/static
|
||||
}
|
||||
```
|
||||
|
||||
### Authentication
|
||||
|
||||
```yaml
|
||||
- name: admin protection
|
||||
on: path /admin/*
|
||||
do: require_auth
|
||||
```bash
|
||||
path glob("/admin/*") {
|
||||
require_auth
|
||||
}
|
||||
|
||||
- name: basic auth for API
|
||||
on: path /api/*
|
||||
do: require_basic_auth "API Access"
|
||||
path glob("/api/*") {
|
||||
require_basic_auth "API Access"
|
||||
}
|
||||
```
|
||||
|
||||
### Path Rewriting
|
||||
|
||||
```yaml
|
||||
- name: rewrite API v1
|
||||
on: path /v1/*
|
||||
do: |
|
||||
rewrite /v1 /api/v1
|
||||
proxy http://backend:8080
|
||||
```bash
|
||||
path glob("/v1/*") {
|
||||
rewrite /v1 /api/v1
|
||||
proxy http://backend:8080
|
||||
}
|
||||
```
|
||||
|
||||
### IP-Based Access Control
|
||||
|
||||
```yaml
|
||||
- name: allow internal
|
||||
on: remote 10.0.0.0/8
|
||||
do: pass
|
||||
```bash
|
||||
remote 10.0.0.0/8 {
|
||||
pass
|
||||
}
|
||||
|
||||
- name: block external
|
||||
on: |
|
||||
!remote 10.0.0.0/8
|
||||
!remote 192.168.0.0/16
|
||||
do: error 403 "Access Denied"
|
||||
!remote 10.0.0.0/8 &
|
||||
!remote 192.168.0.0/16 {
|
||||
error 403 "Access Denied"
|
||||
}
|
||||
```
|
||||
|
||||
### WebSocket Support
|
||||
|
||||
```yaml
|
||||
- name: websocket upgrade
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
```bash
|
||||
header Connection Upgrade &
|
||||
header Upgrade websocket {
|
||||
bypass
|
||||
}
|
||||
```
|
||||
|
||||
### Default Rule (Fallback)
|
||||
|
||||
```bash
|
||||
# Default runs only if no non-default pre rule matches
|
||||
default {
|
||||
remove resp_header X-Internal
|
||||
add resp_header X-Powered-By godoxy
|
||||
}
|
||||
|
||||
# Matching rules suppress default
|
||||
path glob("/api/*") {
|
||||
proxy http://api:8080
|
||||
}
|
||||
|
||||
path glob("/api/*") {
|
||||
set resp_header X-API true
|
||||
}
|
||||
```
|
||||
|
||||
Only one default rule is allowed per route. `name: default` and `on: default` are equivalent selectors and both behave as fallback-only.
|
||||
|
||||
## Testing Notes
|
||||
|
||||
- Unit tests for all matchers and actions
|
||||
- Integration tests with real HTTP requests
|
||||
- Parser tests for YAML syntax
|
||||
- Parser tests for block syntax
|
||||
- Variable substitution tests
|
||||
- Performance benchmarks for hot paths
|
||||
|
||||
409
internal/route/rules/block_parser.go
Normal file
409
internal/route/rules/block_parser.go
Normal file
@@ -0,0 +1,409 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/yusing/goutils/env"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
func getStringBuffer(size int) *strings.Builder {
|
||||
var buf strings.Builder
|
||||
if size > 0 {
|
||||
buf.Grow(size)
|
||||
}
|
||||
return &buf
|
||||
}
|
||||
|
||||
// expandEnvVarsRaw expands ${NAME} in-place using env.LookupEnv (prefix-aware).
|
||||
func expandEnvVarsRaw(v string) (string, gperr.Error) {
|
||||
buf := getStringBuffer(len(v))
|
||||
envVar := getStringBuffer(0)
|
||||
|
||||
var missingEnvVars []string
|
||||
inEnvVar := false
|
||||
expectingBrace := false
|
||||
|
||||
for _, r := range v {
|
||||
if expectingBrace && r != '{' && r != '$' {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
}
|
||||
switch r {
|
||||
case '$':
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
} else {
|
||||
expectingBrace = true
|
||||
}
|
||||
case '{':
|
||||
if expectingBrace {
|
||||
inEnvVar = true
|
||||
expectingBrace = false
|
||||
envVar.Reset()
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
case '}':
|
||||
if inEnvVar {
|
||||
envValue, ok := env.LookupEnv(envVar.String())
|
||||
if !ok {
|
||||
missingEnvVars = append(missingEnvVars, envVar.String())
|
||||
} else {
|
||||
buf.WriteString(envValue)
|
||||
}
|
||||
inEnvVar = false
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
default:
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
}
|
||||
if inEnvVar {
|
||||
envVar.WriteRune(r)
|
||||
} else {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
}
|
||||
|
||||
var err gperr.Error
|
||||
if inEnvVar {
|
||||
// Write back the unterminated ${...} so the output matches the input.
|
||||
buf.WriteString("${")
|
||||
buf.WriteString(envVar.String())
|
||||
err = ErrUnterminatedEnvVar
|
||||
}
|
||||
if len(missingEnvVars) > 0 {
|
||||
err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...)))
|
||||
}
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
// parseBlockRules parses the block-syntax rule format.
|
||||
// Grammar:
|
||||
//
|
||||
// file := { ws | comment | rule }
|
||||
// rule := default_rule | conditional_rule
|
||||
// default_rule := 'default' ws* block
|
||||
// conditional_rule := on_expr ws* block
|
||||
// block := '{' do_body '}'
|
||||
//
|
||||
// Where:
|
||||
// - on_expr is passed verbatim to RuleOn.Parse()
|
||||
// - do_body is passed verbatim to Command.Parse()
|
||||
//
|
||||
// Comments (ignored outside quotes/backticks):
|
||||
// - line comment: // ... or # ...
|
||||
// - block comment: /* ... */
|
||||
//
|
||||
// Brace handling:
|
||||
// - Braces inside quotes/backticks are ignored
|
||||
// - Braces inside ${...} (env vars) are ignored in do_body
|
||||
// - Braces in on_expr are not ignored (env vars must be quoted in on_expr)
|
||||
//
|
||||
//nolint:dupword
|
||||
func parseBlockRules(src string) (Rules, gperr.Error) {
|
||||
var rules Rules
|
||||
var errs gperr.Builder
|
||||
|
||||
pos := 0
|
||||
length := len(src)
|
||||
t := newTokenizer(src)
|
||||
|
||||
for pos < length {
|
||||
// Skip whitespace/comments between rules.
|
||||
newPos, skipErr := t.skipComments(pos, true, true)
|
||||
if skipErr != nil {
|
||||
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
|
||||
}
|
||||
pos = newPos
|
||||
if pos >= length {
|
||||
break
|
||||
}
|
||||
|
||||
// Stray closing brace at top-level: keep parsing but mark invalid so Rules.Validate() fails.
|
||||
if src[pos] == '}' {
|
||||
return nil, ErrInvalidBlockSyntax.Withf("unmatched '}' at position %d", pos)
|
||||
}
|
||||
|
||||
// Parse rule header (default, unconditional, or on_expr)
|
||||
headerStart := pos
|
||||
header := parseRuleHeader(&t, src, &pos, length)
|
||||
headerStr := src[headerStart:pos]
|
||||
|
||||
// Skip whitespace/comments before '{' (default header may end before '{').
|
||||
newPos, skipErr = t.skipComments(pos, false, true)
|
||||
if skipErr != nil {
|
||||
return nil, ErrInvalidBlockSyntax.Withf("at position %d", pos)
|
||||
}
|
||||
pos = newPos
|
||||
|
||||
if pos >= length || src[pos] != '{' {
|
||||
errs.AddSubjectf(ErrInvalidBlockSyntax, "expected '{' after rule header %q", headerStr)
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
// Find matching '}' (respecting quotes and env vars in do_body)
|
||||
bodyStart := pos + 1
|
||||
bodyEnd, err := t.findMatchingBrace(bodyStart)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "rule header %q", headerStr)
|
||||
return nil, errs.Error()
|
||||
}
|
||||
pos = bodyEnd + 1
|
||||
|
||||
onExpr := header
|
||||
|
||||
doBody := ""
|
||||
if bodyStart < bodyEnd {
|
||||
doBody = src[bodyStart:bodyEnd]
|
||||
}
|
||||
// Normalize do body for the inner DSL parser:
|
||||
// - strip comments (outside quotes/backticks)
|
||||
// - trim block whitespace/indentation
|
||||
// - expand ${ENV} in-place so cmd.raw is usable/debuggable
|
||||
doBody, err = preprocessDoBody(doBody)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "rule header %q", headerStr)
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Name: "", // auto-generate if empty
|
||||
On: RuleOn{},
|
||||
Do: Command{},
|
||||
}
|
||||
|
||||
// Header semantics:
|
||||
// - "default" => default rule (matched when no other rules are matched)
|
||||
// - "" => unconditional rule (always matches)
|
||||
// - otherwise => conditional rule (on expression)
|
||||
switch onExpr {
|
||||
case "default":
|
||||
rule.On.raw = OnDefault
|
||||
case "":
|
||||
// leave rule.On as zero value => checker=nil => always matches
|
||||
default:
|
||||
if parseErr := rule.On.Parse(onExpr); parseErr != nil {
|
||||
errs.AddSubjectf(parseErr, "on")
|
||||
}
|
||||
}
|
||||
|
||||
if doBody != "" {
|
||||
if parseErr := rule.Do.Parse(doBody); parseErr != nil {
|
||||
errs.AddSubjectf(parseErr, "do")
|
||||
}
|
||||
}
|
||||
|
||||
if errs.HasError() {
|
||||
return nil, errs.Error()
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func preprocessDoBody(doBody string) (string, gperr.Error) {
|
||||
doBody = strings.TrimSpace(doBody)
|
||||
if doBody == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
normalized := doBody
|
||||
// If comments are possible, strip them first while preserving line breaks.
|
||||
if strings.ContainsAny(normalized, "#/") {
|
||||
stripped, err := stripCommentsPreserveNewlines(normalized)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
normalized = stripped
|
||||
}
|
||||
|
||||
// Drop lines that are empty after trimming, while preserving indentation of non-empty lines.
|
||||
out := getStringBuffer(len(normalized))
|
||||
|
||||
lineStart := 0
|
||||
wroteLine := false
|
||||
for i := 0; i <= len(normalized); i++ {
|
||||
if i < len(normalized) && normalized[i] != '\n' {
|
||||
continue
|
||||
}
|
||||
line := normalized[lineStart:i]
|
||||
if strings.TrimSpace(line) != "" {
|
||||
if wroteLine {
|
||||
out.WriteByte('\n')
|
||||
}
|
||||
out.WriteString(line)
|
||||
wroteLine = true
|
||||
}
|
||||
lineStart = i + 1
|
||||
}
|
||||
|
||||
if !wroteLine {
|
||||
return "", nil
|
||||
}
|
||||
normalized = out.String()
|
||||
|
||||
// Expand env vars to keep Command.raw consistent with parsed semantics.
|
||||
if !strings.Contains(normalized, "${") {
|
||||
return normalized, nil
|
||||
}
|
||||
expanded, err := expandEnvVarsRaw(normalized)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return expanded, nil
|
||||
}
|
||||
|
||||
// stripCommentsPreserveNewlines removes //, #, and /* */ comments outside quotes/backticks.
|
||||
// It preserves newlines so command line boundaries remain intact.
|
||||
func stripCommentsPreserveNewlines(src string) (string, gperr.Error) {
|
||||
if !strings.ContainsAny(src, "#/") {
|
||||
return src, nil
|
||||
}
|
||||
|
||||
out := getStringBuffer(len(src))
|
||||
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
atLineStart := true
|
||||
prevIsSpace := true
|
||||
|
||||
for i := 0; i < len(src); {
|
||||
c := src[i]
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
out.WriteByte('\n')
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '\n' {
|
||||
out.WriteByte('\n')
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if c == '*' && i+1 < len(src) && src[i+1] == '/' {
|
||||
inBlock = false
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
out.WriteByte(c)
|
||||
if c == '\\' && i+1 < len(src) {
|
||||
// Write next char and skip it (escape sequence)
|
||||
i++
|
||||
out.WriteByte(src[i])
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if rune(c) == quote {
|
||||
quote = 0
|
||||
}
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(rune(c))
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// Not in quote/comment.
|
||||
switch c {
|
||||
case '\'', '"', '`':
|
||||
quote = rune(c)
|
||||
out.WriteByte(c)
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
i++
|
||||
continue
|
||||
case '#':
|
||||
if atLineStart || prevIsSpace {
|
||||
inLine = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
case '/':
|
||||
if i+1 < len(src) {
|
||||
n := src[i+1]
|
||||
if (atLineStart || prevIsSpace) && n == '/' {
|
||||
inLine = true
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if (atLineStart || prevIsSpace) && n == '*' {
|
||||
inBlock = true
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out.WriteByte(c)
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(rune(c))
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
if inBlock {
|
||||
return "", ErrInvalidBlockSyntax.Withf("unterminated block comment")
|
||||
}
|
||||
return out.String(), nil
|
||||
}
|
||||
|
||||
// parseRuleHeader parses the rule header (default or on expression).
|
||||
// Returns the header string, or "" if parsing failed.
|
||||
func parseRuleHeader(t *Tokenizer, src string, pos *int, length int) string {
|
||||
start := *pos
|
||||
|
||||
// Check for 'default' keyword
|
||||
if *pos+7 <= length && src[*pos:*pos+7] == "default" {
|
||||
next := *pos + 7
|
||||
if next >= length || unicode.IsSpace(rune(src[next])) {
|
||||
*pos = next
|
||||
return "default"
|
||||
}
|
||||
}
|
||||
|
||||
// Parse on expression until we hit '{' outside quotes.
|
||||
bracePos, err := t.scanToBrace(*pos)
|
||||
if err != nil {
|
||||
*pos = length
|
||||
return strings.TrimSpace(src[start:*pos])
|
||||
}
|
||||
*pos = bracePos
|
||||
return strings.TrimSpace(src[start:*pos])
|
||||
}
|
||||
48
internal/route/rules/block_parser_bench_test.go
Normal file
48
internal/route/rules/block_parser_bench_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rules
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkParseBlockRules(b *testing.B) {
|
||||
const rulesString = `
|
||||
default {
|
||||
remove resp_header X-Secret
|
||||
add resp_header X-Custom-Header custom-value
|
||||
}
|
||||
|
||||
header X-Test-Header {
|
||||
set header X-Remote-Type public
|
||||
remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||
set header X-Remote-Type private
|
||||
}
|
||||
}
|
||||
|
||||
path glob(/api/admin/*) {
|
||||
cookie session-id {
|
||||
set header X-Session-ID $cookie(session-id)
|
||||
}
|
||||
}
|
||||
|
||||
!remote 192.168.0.0/16 {
|
||||
!header X-User-Role admin & !header X-User-Role user {
|
||||
error 403 "Access denied"
|
||||
} elif remote 127.0.0.1 {
|
||||
header X-User-Role staff {
|
||||
set header X-User-Role staff
|
||||
}
|
||||
} else {
|
||||
error 403 "Access denied"
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
var rules Rules
|
||||
err := rules.Parse(rulesString)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for b.Loop() {
|
||||
var rules Rules
|
||||
_ = rules.Parse(rulesString)
|
||||
}
|
||||
}
|
||||
390
internal/route/rules/block_parser_test.go
Normal file
390
internal/route/rules/block_parser_test.go
Normal file
@@ -0,0 +1,390 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
func testParseRules(t *testing.T, data string) Rules {
|
||||
t.Helper()
|
||||
|
||||
var rules Rules
|
||||
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
|
||||
require.True(t, convertible)
|
||||
require.NoError(t, err)
|
||||
return rules
|
||||
}
|
||||
|
||||
func testParseRulesError(t *testing.T, data string) error {
|
||||
t.Helper()
|
||||
|
||||
var rules Rules
|
||||
convertible, err := serialization.ConvertString(data, reflect.ValueOf(&rules))
|
||||
require.True(t, convertible)
|
||||
return err
|
||||
}
|
||||
|
||||
func TestParseBlockRules_DefaultRule(t *testing.T) {
|
||||
rules := testParseRules(t, `default {
|
||||
upstream
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "upstream", rules[0].Do.raw)
|
||||
assert.True(t, rules[0].Do.raw == CommandUpstream)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_ConditionalRule(t *testing.T) {
|
||||
rules := testParseRules(t, `path glob(/api/*) {
|
||||
proxy http://localhost:8080
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, "path glob(/api/*)", rules[0].On.raw)
|
||||
assert.Equal(t, "proxy http://localhost:8080", rules[0].Do.raw)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[0].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_MultipleRules(t *testing.T) {
|
||||
rules := testParseRules(t, `default {
|
||||
bypass
|
||||
}
|
||||
|
||||
path /api/* {
|
||||
proxy http://localhost:8080
|
||||
}
|
||||
|
||||
header Connection Upgrade &
|
||||
header Upgrade websocket {
|
||||
route ws-api
|
||||
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"
|
||||
}`)
|
||||
require.Len(t, rules, 3)
|
||||
|
||||
// Default rule
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
|
||||
// API rule
|
||||
assert.Equal(t, "path /api/*", rules[1].On.raw)
|
||||
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
|
||||
|
||||
// WebSocket rule
|
||||
assert.Equal(t, "header Connection Upgrade &\nheader Upgrade websocket", rules[2].On.raw)
|
||||
assert.Equal(t, `route ws-api
|
||||
log info /dev/stdout "Websocket request $req_path from $remote_host to $upstream_name"`, rules[2].Do.raw)
|
||||
require.Len(t, rules[2].Do.pre, 2)
|
||||
_, ok := rules[2].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
_, ok = rules[2].Do.pre[1].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[2].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_Comments(t *testing.T) {
|
||||
rules := testParseRules(t, `// This is a comment
|
||||
default {
|
||||
bypass // inline comment
|
||||
}
|
||||
|
||||
/* Block comment
|
||||
spanning multiple lines */
|
||||
path /admin/* {
|
||||
require_auth
|
||||
}`)
|
||||
require.Len(t, rules, 2)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "path /admin/*", rules[1].On.raw)
|
||||
assert.Equal(t, "require_auth", rules[1].Do.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_HashComment(t *testing.T) {
|
||||
rules := testParseRules(t, `# YAML-style comment
|
||||
default {
|
||||
bypass
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_EnvVars(t *testing.T) {
|
||||
t.Setenv("CUSTOM_HEADER", "test-header")
|
||||
|
||||
rules := testParseRules(t, `path /api/* {
|
||||
set header X-Custom "${CUSTOM_HEADER}"
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, "path /api/*", rules[0].On.raw)
|
||||
assert.Equal(t, `set header X-Custom "test-header"`, rules[0].Do.raw)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[0].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_YAMLFallback(t *testing.T) {
|
||||
rules := testParseRules(t, `- name: default
|
||||
do: bypass
|
||||
- name: api
|
||||
on: path glob(/api/*)
|
||||
do: proxy http://localhost:8080`)
|
||||
require.Len(t, rules, 2)
|
||||
assert.Equal(t, "default", rules[0].Name)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
assert.Equal(t, "api", rules[1].Name)
|
||||
assert.Equal(t, "path glob(/api/*)", rules[1].On.raw)
|
||||
assert.Equal(t, "proxy http://localhost:8080", rules[1].Do.raw)
|
||||
require.Len(t, rules[1].Do.pre, 1)
|
||||
_, ok := rules[1].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[1].Do.post, 0)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_UnmatchedBrace(t *testing.T) {
|
||||
t.Run("unquoted", func(t *testing.T) {
|
||||
err := testParseRulesError(t, `path /api/* {
|
||||
proxy http://localhost:8080}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("quoted", func(t *testing.T) {
|
||||
_ = testParseRules(t, `path /api/* {
|
||||
error 403 "some message}"
|
||||
}`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseBlockRules_UnterminatedBlockComment(t *testing.T) {
|
||||
err := testParseRulesError(t, `/* unterminated block comment
|
||||
default {
|
||||
bypass
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks(t *testing.T) {
|
||||
rules := testParseRules(t, `
|
||||
header X-Test-Header {
|
||||
set header X-Remote-Type public
|
||||
remote 127.0.0.1 | remote 192.168.0.0/16 {
|
||||
set header X-Remote-Type private
|
||||
}
|
||||
}`)
|
||||
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, "header X-Test-Header", rules[0].On.raw)
|
||||
|
||||
require.Len(t, rules[0].Do.pre, 2)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
require.Len(t, rules[0].Do.post, 0)
|
||||
|
||||
ifCmd, ok := rules[0].Do.pre[1].(IfBlockCommand)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "remote 127.0.0.1 | remote 192.168.0.0/16", ifCmd.On.raw)
|
||||
|
||||
require.Len(t, ifCmd.Do, 1)
|
||||
|
||||
upstream := func(http.ResponseWriter, *http.Request) {}
|
||||
|
||||
t.Run("condition matched executes nested content", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Test-Header", "1")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "private", req.Header.Get("X-Remote-Type"))
|
||||
})
|
||||
|
||||
t.Run("condition not matched skips nested content", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Test-Header", "1")
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "public", req.Header.Get("X-Remote-Type"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElifElse(t *testing.T) {
|
||||
rules := testParseRules(t, `
|
||||
header X-Test-Header {
|
||||
set header X-Mode outer
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
} elif method POST {
|
||||
set header X-Mode post
|
||||
} else {
|
||||
set header X-Mode other
|
||||
}
|
||||
}`)
|
||||
|
||||
require.Len(t, rules, 1)
|
||||
|
||||
require.Len(t, rules[0].Do.pre, 2)
|
||||
|
||||
ifCmd, ok := rules[0].Do.pre[1].(IfElseBlockCommand)
|
||||
require.True(t, ok)
|
||||
require.Len(t, ifCmd.Ifs, 2)
|
||||
assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw)
|
||||
assert.Equal(t, "method POST", ifCmd.Ifs[1].On.raw)
|
||||
require.NotNil(t, ifCmd.Else)
|
||||
|
||||
upstream := func(http.ResponseWriter, *http.Request) {}
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
want string
|
||||
}{
|
||||
{name: "get branch", method: http.MethodGet, want: "get"},
|
||||
{name: "post branch", method: http.MethodPost, want: "post"},
|
||||
{name: "else branch", method: http.MethodPut, want: "other"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.method, "/", nil)
|
||||
req.Header.Set("X-Test-Header", "1")
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := rules[0].Do.pre.ServeHTTP(rm, req, upstream)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.want, req.Header.Get("X-Mode"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBlockRules_DefaultRule_CommentBeforeBrace(t *testing.T) {
|
||||
rules := testParseRules(t, `default /* comment between header and brace */ {
|
||||
bypass
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
assert.Equal(t, OnDefault, rules[0].On.raw)
|
||||
assert.Equal(t, "bypass", rules[0].Do.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_StrayClosingBraceAtTopLevel(t *testing.T) {
|
||||
err := testParseRulesError(t, `}
|
||||
default {
|
||||
bypass
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElifMustBeSameLine(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
}
|
||||
elif method POST {
|
||||
set header X-Mode post
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElseMustBeLastOnLine(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
} else {
|
||||
set header X-Mode other
|
||||
} set header X-After else
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unexpected token after else block")
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_MultipleElse(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
} else {
|
||||
set header X-Mode other
|
||||
} else {
|
||||
set header X-Mode other2
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
// assert.Contains(t, err.Error(), "multiple 'else' branches")
|
||||
assert.Contains(t, err.Error(), "unexpected token after else block")
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_ElifMissingOnExpr(t *testing.T) {
|
||||
err := testParseRulesError(t, `header X-Test-Header {
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
} elif {
|
||||
set header X-Mode post
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expected on-expr after 'elif'")
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_LineEndingBraceHeuristic(t *testing.T) {
|
||||
rules := testParseRules(t, `{
|
||||
set header X-Literal "{"
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
_, ok := rules[0].Do.pre[0].(Handler)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_LineEndingBraceWithTrailingSpaces(t *testing.T) {
|
||||
rules := testParseRules(t, `header X-Test-Header {
|
||||
method GET {
|
||||
set header X-Mode get
|
||||
}
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
ifCmd, ok := rules[0].Do.pre[0].(IfBlockCommand)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "method GET", ifCmd.On.raw)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_LineEndingBraceWithTrailingComment(t *testing.T) {
|
||||
rules := testParseRules(t, `header X-Test-Header {
|
||||
method GET { // GET branch
|
||||
set header X-Mode get
|
||||
} else { # fallback branch
|
||||
set header X-Mode other
|
||||
}
|
||||
}`)
|
||||
require.Len(t, rules, 1)
|
||||
require.Len(t, rules[0].Do.pre, 1)
|
||||
|
||||
ifCmd, ok := rules[0].Do.pre[0].(IfElseBlockCommand)
|
||||
require.True(t, ok)
|
||||
require.Len(t, ifCmd.Ifs, 1)
|
||||
assert.Equal(t, "method GET", ifCmd.Ifs[0].On.raw)
|
||||
require.NotNil(t, ifCmd.Else)
|
||||
}
|
||||
|
||||
func TestParseBlockRules_NestedBlocks_LineEndingBraceInterpretsAsBlock(t *testing.T) {
|
||||
err := testParseRulesError(t, `{
|
||||
set header X-Bad {
|
||||
set header X-Test fail
|
||||
}
|
||||
}`)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid `rule.on` target")
|
||||
}
|
||||
@@ -1,21 +1,25 @@
|
||||
package rules
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
type (
|
||||
CheckFunc func(w http.ResponseWriter, r *http.Request) bool
|
||||
CheckFunc func(w *httputils.ResponseModifier, r *http.Request) bool
|
||||
Checker interface {
|
||||
Check(w http.ResponseWriter, r *http.Request) bool
|
||||
Check(w *httputils.ResponseModifier, r *http.Request) bool
|
||||
}
|
||||
CheckMatchSingle []Checker
|
||||
CheckMatchAll []Checker
|
||||
)
|
||||
|
||||
func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (checker CheckFunc) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return checker(w, r)
|
||||
}
|
||||
|
||||
func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (checkers CheckMatchSingle) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
for _, check := range checkers {
|
||||
if check.Check(w, r) {
|
||||
return true
|
||||
@@ -24,7 +28,7 @@ func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) b
|
||||
return false
|
||||
}
|
||||
|
||||
func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (checkers CheckMatchAll) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
for _, check := range checkers {
|
||||
if !check.Check(w, r) {
|
||||
return false
|
||||
|
||||
@@ -1,79 +1,62 @@
|
||||
package rules
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
var errTerminateRule = errors.New("terminate rule")
|
||||
|
||||
type (
|
||||
handlerFunc func(w http.ResponseWriter, r *http.Request) error
|
||||
HandlerFunc func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
|
||||
Handler struct {
|
||||
fn HandlerFunc
|
||||
phase PhaseFlag
|
||||
terminate bool
|
||||
}
|
||||
|
||||
CommandHandler interface {
|
||||
// CommandHandler can read and modify the values
|
||||
// then handle the request
|
||||
// finally proceed to next command (or return) base on situation
|
||||
Handle(w http.ResponseWriter, r *http.Request) error
|
||||
IsResponseHandler() bool
|
||||
ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error
|
||||
Phase() PhaseFlag
|
||||
}
|
||||
// NonTerminatingCommand will run then proceed to next command or reverse proxy.
|
||||
NonTerminatingCommand handlerFunc
|
||||
// TerminatingCommand will run then return immediately.
|
||||
TerminatingCommand handlerFunc
|
||||
// OnResponseCommand will run then return based on the response.
|
||||
OnResponseCommand handlerFunc
|
||||
// BypassCommand will skip all the following commands
|
||||
// and directly return to reverse proxy.
|
||||
BypassCommand struct{}
|
||||
|
||||
// Commands is a slice of CommandHandler.
|
||||
Commands []CommandHandler
|
||||
)
|
||||
|
||||
func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return c(w, r)
|
||||
func (h Handler) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
return h.fn(w, r, upstream)
|
||||
}
|
||||
|
||||
func (c NonTerminatingCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
func (h Handler) Phase() PhaseFlag {
|
||||
return h.phase
|
||||
}
|
||||
|
||||
func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
if err := c(w, r); err != nil {
|
||||
return err
|
||||
}
|
||||
return errTerminated
|
||||
func (h Handler) Terminates() bool {
|
||||
return h.terminate
|
||||
}
|
||||
|
||||
func (c TerminatingCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return c(w, r)
|
||||
}
|
||||
|
||||
func (c OnResponseCommand) IsResponseHandler() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return errTerminated
|
||||
}
|
||||
|
||||
func (c BypassCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
func (c Commands) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
for _, cmd := range c {
|
||||
if err := cmd.Handle(w, r); err != nil {
|
||||
err := cmd.ServeHTTP(w, r, upstream)
|
||||
if err != nil {
|
||||
// Terminating actions stop the command chain immediately.
|
||||
// Will be handled by the caller.
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Commands) IsResponseHandler() bool {
|
||||
func (c Commands) Phase() PhaseFlag {
|
||||
req := PhaseNone
|
||||
for _, cmd := range c {
|
||||
if cmd.IsResponseHandler() {
|
||||
return true
|
||||
}
|
||||
req |= cmd.Phase()
|
||||
}
|
||||
return false
|
||||
return req
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -24,17 +23,17 @@ import (
|
||||
|
||||
type (
|
||||
Command struct {
|
||||
raw string
|
||||
exec CommandHandler
|
||||
isResponseHandler bool
|
||||
raw string
|
||||
pre Commands // runs before w.WriteHeader
|
||||
post Commands
|
||||
}
|
||||
)
|
||||
|
||||
func (cmd *Command) IsResponseHandler() bool {
|
||||
return cmd.isResponseHandler
|
||||
}
|
||||
|
||||
const (
|
||||
CommandUpstream = "upstream"
|
||||
CommandUpstreamOld = "bypass"
|
||||
CommandUpstreamOld2 = "pass"
|
||||
|
||||
CommandRequireAuth = "require_auth"
|
||||
CommandRewrite = "rewrite"
|
||||
CommandServe = "serve"
|
||||
@@ -48,8 +47,6 @@ const (
|
||||
CommandRemove = "remove"
|
||||
CommandLog = "log"
|
||||
CommandNotify = "notify"
|
||||
CommandPass = "pass"
|
||||
CommandPassAlt = "bypass"
|
||||
)
|
||||
|
||||
type AuthHandler func(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
@@ -60,36 +57,60 @@ func InitAuthHandler(handler AuthHandler) {
|
||||
authHandler = handler
|
||||
}
|
||||
|
||||
func init() {
|
||||
commands[CommandUpstreamOld] = commands[CommandUpstream]
|
||||
commands[CommandUpstreamOld2] = commands[CommandUpstream]
|
||||
}
|
||||
|
||||
var commands = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
build func(args any) CommandHandler
|
||||
isResponseHandler bool
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
build func(args any) HandlerFunc
|
||||
terminate bool
|
||||
}{
|
||||
CommandUpstream: {
|
||||
help: Help{
|
||||
command: CommandUpstream,
|
||||
description: makeLines("Pass the request to the upstream"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 0 {
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
return phase, nil, nil
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
upstream(w, r)
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRequireAuth: {
|
||||
help: Help{
|
||||
command: CommandRequireAuth,
|
||||
description: makeLines("Require HTTP authentication for incoming requests"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
return phase, nil, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if authHandler == nil {
|
||||
http.Error(w, "Auth handler not initialized", http.StatusInternalServerError)
|
||||
return errTerminated
|
||||
build: func(args any) HandlerFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if authHandler == nil { // no auth handler configured, allow request to proceed
|
||||
return nil
|
||||
}
|
||||
if !authHandler(w, r) {
|
||||
return errTerminated
|
||||
if proceed := authHandler(w, r); !proceed {
|
||||
return errTerminateRule
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
CommandRewrite: {
|
||||
@@ -104,26 +125,27 @@ var commands = map[string]struct {
|
||||
"to": "the path to rewrite to, must start with /",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
path1, err1 := validateURLPath(args[:1])
|
||||
path2, err2 := validateURLPath(args[1:])
|
||||
if err1 != nil {
|
||||
err1 = gperr.PrependSubject(err1, "from")
|
||||
err1 = gperr.Errorf("from: %w", err1)
|
||||
}
|
||||
if err2 != nil {
|
||||
err2 = gperr.PrependSubject(err2, "to")
|
||||
err2 = gperr.Errorf("to: %w", err2)
|
||||
}
|
||||
if err1 != nil || err2 != nil {
|
||||
return nil, gperr.Join(err1, err2)
|
||||
return phase, nil, gperr.Join(err1, err2)
|
||||
}
|
||||
return &StrTuple{path1.(string), path2.(string)}, nil
|
||||
return phase, &StrTuple{path1.(string), path2.(string)}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
orig, repl := args.(*StrTuple).Unpack()
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
path := r.URL.Path
|
||||
if len(path) > 0 && path[0] != '/' {
|
||||
path = "/" + path
|
||||
@@ -133,10 +155,10 @@ var commands = map[string]struct {
|
||||
}
|
||||
path = repl + path[len(orig):]
|
||||
r.URL.Path = path
|
||||
r.URL.RawPath = r.URL.EscapedPath()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.URL.RawPath = ""
|
||||
r.RequestURI = ""
|
||||
return nil
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
CommandServe: {
|
||||
@@ -150,14 +172,19 @@ var commands = map[string]struct {
|
||||
"root": "the file system path to serve, must be an existing directory",
|
||||
},
|
||||
},
|
||||
validate: validateFSPath,
|
||||
build: func(args any) CommandHandler {
|
||||
root := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||
return nil
|
||||
})
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
parsedArgs, err = validateFSPath(args)
|
||||
return
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
root := args.(string)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRedirect: {
|
||||
help: Help{
|
||||
@@ -170,14 +197,19 @@ var commands = map[string]struct {
|
||||
"to": "the url to redirect to, can be relative or absolute URL",
|
||||
},
|
||||
},
|
||||
validate: validateURL,
|
||||
build: func(args any) CommandHandler {
|
||||
target := args.(*nettypes.URL).String()
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||
return nil
|
||||
})
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
parsedArgs, err = validateURL(args)
|
||||
return
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
target := args.(*nettypes.URL).String()
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRoute: {
|
||||
help: Help{
|
||||
@@ -190,15 +222,16 @@ var commands = map[string]struct {
|
||||
"route": "the route to route to",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
return args[0], nil
|
||||
return phase, args[0], nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
route := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, req *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, req *http.Request, upstream http.HandlerFunc) error {
|
||||
ep := entrypoint.FromCtx(req.Context())
|
||||
r, ok := ep.HTTPRoutes().Get(route)
|
||||
if !ok {
|
||||
@@ -212,9 +245,10 @@ var commands = map[string]struct {
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("Route %q not found", route), http.StatusNotFound)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandError: {
|
||||
help: Help{
|
||||
@@ -228,34 +262,40 @@ var commands = map[string]struct {
|
||||
"text": "the error message to return",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
codeStr, text := args[0], args[1]
|
||||
code, err := strconv.Atoi(codeStr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return phase, nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if !httputils.IsStatusCodeValid(code) {
|
||||
return nil, ErrInvalidArguments.Subject(codeStr)
|
||||
return phase, nil, ErrInvalidArguments.Subject(codeStr)
|
||||
}
|
||||
textTmpl, err := validateTemplate(text, true)
|
||||
tmplReq, textTmpl, err := validateTemplate(text, true)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return phase, nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return &Tuple[int, templateString]{code, textTmpl}, nil
|
||||
phase |= tmplReq
|
||||
return phase, &Tuple[int, templateString]{code, textTmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
code, textTmpl := args.(*Tuple[int, templateString]).Unpack()
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
// error command should overwrite the response body
|
||||
httputils.GetInitResponseModifier(w).ResetBody()
|
||||
w.ResetBody()
|
||||
w.WriteHeader(code)
|
||||
err := textTmpl.ExpandVars(w, r, w)
|
||||
return err
|
||||
})
|
||||
_, err := textTmpl.ExpandVars(w, r, w.BodyBuffer())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandRequireBasicAuth: {
|
||||
help: Help{
|
||||
@@ -268,20 +308,22 @@ var commands = map[string]struct {
|
||||
"realm": "the authentication realm",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
if len(args) == 1 {
|
||||
return args[0], nil
|
||||
return phase, args[0], nil
|
||||
}
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
realm := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm=%q`, realm))
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandProxy: {
|
||||
help: Help{
|
||||
@@ -294,14 +336,19 @@ var commands = map[string]struct {
|
||||
"to": "the url to proxy to, must be an absolute URL",
|
||||
},
|
||||
},
|
||||
validate: validateURL,
|
||||
build: func(args any) CommandHandler {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePre
|
||||
parsedArgs, err = validateURL(args)
|
||||
return
|
||||
},
|
||||
build: func(args any) HandlerFunc {
|
||||
target := args.(*nettypes.URL)
|
||||
if target.Scheme == "" {
|
||||
target.Scheme = "http"
|
||||
}
|
||||
if target.Host == "" {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rawPath := target.EscapedPath()
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
url := target.URL
|
||||
url.Host = routes.TryGetUpstreamHostPort(r)
|
||||
if url.Host == "" {
|
||||
@@ -309,18 +356,19 @@ var commands = map[string]struct {
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
|
||||
r.URL.Path = target.Path
|
||||
r.URL.RawPath = r.URL.EscapedPath()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.URL.RawPath = rawPath
|
||||
r.RequestURI = ""
|
||||
rp.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
rp.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
return errTerminateRule
|
||||
}
|
||||
},
|
||||
terminate: true,
|
||||
},
|
||||
CommandSet: {
|
||||
help: Help{
|
||||
@@ -335,11 +383,11 @@ var commands = map[string]struct {
|
||||
"value": "the value to set",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
return validateModField(ModFieldSet, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return args.(CommandHandler)
|
||||
build: func(args any) HandlerFunc {
|
||||
return args.(HandlerFunc)
|
||||
},
|
||||
},
|
||||
CommandAdd: {
|
||||
@@ -355,11 +403,11 @@ var commands = map[string]struct {
|
||||
"value": "the value to add",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
return validateModField(ModFieldAdd, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return args.(CommandHandler)
|
||||
build: func(args any) HandlerFunc {
|
||||
return args.(HandlerFunc)
|
||||
},
|
||||
},
|
||||
CommandRemove: {
|
||||
@@ -374,15 +422,14 @@ var commands = map[string]struct {
|
||||
"field": "the field to remove",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
return validateModField(ModFieldRemove, args)
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return args.(CommandHandler)
|
||||
build: func(args any) HandlerFunc {
|
||||
return args.(HandlerFunc)
|
||||
},
|
||||
},
|
||||
CommandLog: {
|
||||
isResponseHandler: true,
|
||||
help: Help{
|
||||
command: CommandLog,
|
||||
description: makeLines(
|
||||
@@ -399,46 +446,57 @@ var commands = map[string]struct {
|
||||
"template": "the template to log",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrExpectThreeArgs
|
||||
return phase, nil, ErrExpectThreeArgs
|
||||
}
|
||||
tmpl, err := validateTemplate(args[2], true)
|
||||
phase, tmpl, err := validateTemplate(args[2], true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
level, err := validateLevel(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
// NOTE: file will stay opened forever
|
||||
// it leverages accesslog.NewFileIO so
|
||||
// it will be opened only once for the same path
|
||||
f, err := openFile(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
return &onLogArgs{level, f, tmpl}, nil
|
||||
return phase, &onLogArgs{level, f, tmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
level, f, tmpl := args.(*onLogArgs).Unpack()
|
||||
var logger io.Writer
|
||||
if f == stdout || f == stderr {
|
||||
isStdLogger := f == stdout || f == stderr
|
||||
if isStdLogger {
|
||||
logger = logging.NewLoggerWithFixedLevel(level, f)
|
||||
} else {
|
||||
logger = f
|
||||
}
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
err := tmpl.ExpandVars(w, r, logger)
|
||||
if err != nil {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if isStdLogger {
|
||||
bufPool := w.BufPool()
|
||||
buf := bufPool.GetBuffer()
|
||||
defer bufPool.PutBuffer(buf)
|
||||
|
||||
if _, err := tmpl.ExpandVars(w, r, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if buf.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := logger.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
_, err := tmpl.ExpandVars(w, r, logger)
|
||||
return err
|
||||
}
|
||||
},
|
||||
},
|
||||
CommandNotify: {
|
||||
isResponseHandler: true,
|
||||
help: Help{
|
||||
command: CommandNotify,
|
||||
description: makeLines(
|
||||
@@ -456,22 +514,24 @@ var commands = map[string]struct {
|
||||
"body": "the body of the notification",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 4 {
|
||||
return nil, ErrExpectFourArgs
|
||||
return phase, nil, ErrExpectFourArgs
|
||||
}
|
||||
titleTmpl, err := validateTemplate(args[2], false)
|
||||
req1, titleTmpl, err := validateTemplate(args[2], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
bodyTmpl, err := validateTemplate(args[3], false)
|
||||
req2, bodyTmpl, err := validateTemplate(args[3], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
level, err := validateLevel(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return phase, nil, err
|
||||
}
|
||||
|
||||
phase |= req1 | req2
|
||||
// TODO: validate provider
|
||||
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
|
||||
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
|
||||
@@ -480,34 +540,34 @@ var commands = map[string]struct {
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
||||
return phase, &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
build: func(args any) HandlerFunc {
|
||||
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
|
||||
to := []string{provider}
|
||||
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
respBuf := bytes.NewBuffer(make([]byte, 0, titleTmpl.Len()+bodyTmpl.Len()))
|
||||
return func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
var respBuf strings.Builder
|
||||
|
||||
err := titleTmpl.ExpandVars(w, r, respBuf)
|
||||
_, err := titleTmpl.ExpandVars(w, r, &respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
titleLen := respBuf.Len()
|
||||
err = bodyTmpl.ExpandVars(w, r, respBuf)
|
||||
_, err = bodyTmpl.ExpandVars(w, r, &respBuf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b := respBuf.Bytes()
|
||||
s := respBuf.String()
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: level,
|
||||
Title: string(b[:titleLen]),
|
||||
Body: notif.MessageBodyBytes(b[titleLen:]),
|
||||
Title: s[:titleLen],
|
||||
Body: notif.MessageBodyBytes(s[titleLen:]),
|
||||
To: to,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -519,121 +579,29 @@ type (
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (cmd *Command) Parse(v string) error {
|
||||
executors := make([]CommandHandler, 0)
|
||||
isResponseHandler := false
|
||||
for line := range strings.SplitSeq(v, "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
directive, args, err := parse(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if directive == CommandPass || directive == CommandPassAlt {
|
||||
if len(args) != 0 {
|
||||
return ErrExpectNoArg
|
||||
}
|
||||
executors = append(executors, BypassCommand{})
|
||||
continue
|
||||
}
|
||||
|
||||
builder, ok := commands[directive]
|
||||
if !ok {
|
||||
return ErrUnknownDirective.Subject(directive)
|
||||
}
|
||||
validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
|
||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
handler := builder.build(validArgs)
|
||||
executors = append(executors, handler)
|
||||
if builder.isResponseHandler || handler.IsResponseHandler() {
|
||||
isResponseHandler = true
|
||||
}
|
||||
executors, parseErr := parseDoWithBlocks(v)
|
||||
if parseErr != nil {
|
||||
return parseErr
|
||||
}
|
||||
|
||||
if len(executors) == 0 {
|
||||
cmd.raw = v
|
||||
cmd.exec = nil
|
||||
cmd.isResponseHandler = false
|
||||
cmd.pre = nil
|
||||
cmd.post = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
exec, err := buildCmd(executors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.raw = v
|
||||
cmd.exec = exec
|
||||
if exec.IsResponseHandler() {
|
||||
isResponseHandler = true
|
||||
for _, executor := range executors {
|
||||
if executor.Phase().IsPostRule() {
|
||||
cmd.post = append(cmd.post, executor)
|
||||
} else {
|
||||
cmd.pre = append(cmd.pre, executor)
|
||||
}
|
||||
}
|
||||
cmd.isResponseHandler = isResponseHandler
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
|
||||
// Validate the execution order.
|
||||
//
|
||||
// This allows sequences like:
|
||||
// route ws-api
|
||||
// log info /dev/stdout "..."
|
||||
// where the first command is request-phase and the last is response-phase.
|
||||
lastNonResp := -1
|
||||
seenResp := false
|
||||
for i, exec := range executors {
|
||||
if exec.IsResponseHandler() {
|
||||
seenResp = true
|
||||
continue
|
||||
}
|
||||
if seenResp {
|
||||
return nil, ErrInvalidCommandSequence.Withf("response handlers must be the last commands")
|
||||
}
|
||||
lastNonResp = i
|
||||
}
|
||||
|
||||
for i, exec := range executors {
|
||||
if i > lastNonResp {
|
||||
break // response-handler tail
|
||||
}
|
||||
switch exec.(type) {
|
||||
case TerminatingCommand, BypassCommand:
|
||||
if i != lastNonResp {
|
||||
return nil, ErrInvalidCommandSequence.
|
||||
Withf("a response handler or terminating/bypass command must be the last command")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Commands(executors), nil
|
||||
}
|
||||
|
||||
// Command is purely "bypass" or empty.
|
||||
func (cmd *Command) isBypass() bool {
|
||||
if cmd == nil {
|
||||
return true
|
||||
}
|
||||
switch cmd := cmd.exec.(type) {
|
||||
case BypassCommand:
|
||||
return true
|
||||
case Commands:
|
||||
// bypass command is always the last one
|
||||
_, ok := cmd[len(cmd)-1].(BypassCommand)
|
||||
return ok
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
return cmd.exec.Handle(w, r)
|
||||
}
|
||||
|
||||
func (cmd *Command) String() string {
|
||||
return cmd.raw
|
||||
}
|
||||
|
||||
436
internal/route/rules/do_blocks.go
Normal file
436
internal/route/rules/do_blocks.go
Normal file
@@ -0,0 +1,436 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
// IfBlockCommand is an inline conditional block inside a do-body.
|
||||
//
|
||||
// Syntax (within a rule do block):
|
||||
//
|
||||
// <on-expr> { <do...> }
|
||||
//
|
||||
// Semantics:
|
||||
// - Evaluated in the same phase the parent rule runs.
|
||||
// - If <on-expr> matches, run the nested commands in-order.
|
||||
// - Otherwise do nothing.
|
||||
//
|
||||
// NOTE: Per current design decision, we keep this permissive:
|
||||
// nested blocks may use response matchers and response commands; no extra phase validation is performed.
|
||||
type IfBlockCommand struct {
|
||||
On RuleOn
|
||||
Do []CommandHandler
|
||||
}
|
||||
|
||||
func (c IfBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if c.Do == nil {
|
||||
return nil
|
||||
}
|
||||
// If On.checker is nil, treat as unconditional (should not happen if parsed).
|
||||
if c.On.checker == nil {
|
||||
return Commands(c.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
if c.On.checker.Check(w, r) {
|
||||
return Commands(c.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c IfBlockCommand) Phase() PhaseFlag {
|
||||
phase := c.On.phase
|
||||
for _, cmd := range c.Do {
|
||||
phase |= cmd.Phase()
|
||||
}
|
||||
return phase
|
||||
}
|
||||
|
||||
// IfElseBlockCommand is a chained conditional block inside a do-body.
|
||||
//
|
||||
// Syntax (within a rule do block):
|
||||
//
|
||||
// <on-expr> { <do...> } elif <on-expr> { <do...> } ... else { <do...> }
|
||||
//
|
||||
// NOTE: `elif`/`else` must appear on the same line as the preceding closing brace (`}`),
|
||||
// e.g. `} elif ... {` and `} else {`.
|
||||
type IfElseBlockCommand struct {
|
||||
Ifs []IfBlockCommand
|
||||
Else []CommandHandler
|
||||
}
|
||||
|
||||
func (c IfElseBlockCommand) ServeHTTP(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
for _, br := range c.Ifs {
|
||||
// If On.checker is nil, treat as unconditional.
|
||||
if br.On.checker == nil {
|
||||
if br.Do == nil {
|
||||
return nil
|
||||
}
|
||||
return Commands(br.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
if br.On.checker.Check(w, r) {
|
||||
if br.Do == nil {
|
||||
return nil
|
||||
}
|
||||
return Commands(br.Do).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
}
|
||||
if len(c.Else) > 0 {
|
||||
return Commands(c.Else).ServeHTTP(w, r, upstream)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c IfElseBlockCommand) Phase() PhaseFlag {
|
||||
phase := PhaseNone
|
||||
for _, br := range c.Ifs {
|
||||
phase |= br.Phase()
|
||||
}
|
||||
if len(c.Else) > 0 {
|
||||
phase |= Commands(c.Else).Phase()
|
||||
}
|
||||
return phase
|
||||
}
|
||||
|
||||
func skipSameLineSpace(src string, pos int) int {
|
||||
for pos < len(src) {
|
||||
switch src[pos] {
|
||||
case '\n':
|
||||
return pos
|
||||
case '\r':
|
||||
pos++
|
||||
continue
|
||||
case ' ', '\t':
|
||||
pos++
|
||||
continue
|
||||
default:
|
||||
return pos
|
||||
}
|
||||
}
|
||||
return pos
|
||||
}
|
||||
|
||||
func parseAtBlockChain(src string, blockPos int) (CommandHandler, int, error) {
|
||||
length := len(src)
|
||||
headerStart := blockPos
|
||||
|
||||
parseBranch := func(onExpr string, bodyStart int, bodyEnd int) (RuleOn, []CommandHandler, error) {
|
||||
var on RuleOn
|
||||
if err := on.Parse(onExpr); err != nil {
|
||||
return RuleOn{}, nil, err
|
||||
}
|
||||
innerSrc := ""
|
||||
if bodyStart < bodyEnd {
|
||||
innerSrc = src[bodyStart:bodyEnd]
|
||||
}
|
||||
inner, err := parseDoWithBlocks(innerSrc)
|
||||
if err != nil {
|
||||
return RuleOn{}, nil, err
|
||||
}
|
||||
if len(inner) == 0 {
|
||||
return on, nil, nil
|
||||
}
|
||||
return on, inner, nil
|
||||
}
|
||||
|
||||
onExpr, bracePos, herr := parseHeaderToBrace(src, headerStart)
|
||||
if herr != nil {
|
||||
return nil, 0, herr
|
||||
}
|
||||
if onExpr == "" {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr before '{'")
|
||||
}
|
||||
if bracePos >= length || src[bracePos] != '{' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after nested block header")
|
||||
}
|
||||
|
||||
// Parse first <on-expr> { ... }
|
||||
p := bracePos
|
||||
bodyStart := p + 1
|
||||
bodyEnd, ferr := findMatchingBrace(src, &p, bodyStart)
|
||||
if ferr != nil {
|
||||
return nil, 0, ferr
|
||||
}
|
||||
firstOn, firstDo, berr := parseBranch(onExpr, bodyStart, bodyEnd)
|
||||
if berr != nil {
|
||||
return nil, 0, berr
|
||||
}
|
||||
|
||||
ifs := []IfBlockCommand{{On: firstOn, Do: firstDo}}
|
||||
var elseDo []CommandHandler
|
||||
hasChain := false
|
||||
hasElse := false
|
||||
|
||||
for {
|
||||
q := skipSameLineSpace(src, p)
|
||||
if q >= length || src[q] == '\n' {
|
||||
break
|
||||
}
|
||||
|
||||
// elif <on-expr> { ... }
|
||||
if strings.HasPrefix(src[q:], "elif") {
|
||||
next := q + len("elif")
|
||||
if next >= length {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
if src[next] == '\n' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
if !unicode.IsSpace(rune(src[next])) {
|
||||
if src[next] == '{' || src[next] == '}' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected whitespace after 'elif'")
|
||||
}
|
||||
next++
|
||||
for next < length {
|
||||
c := src[next]
|
||||
if c == '\n' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
|
||||
}
|
||||
if c == '\r' {
|
||||
next++
|
||||
continue
|
||||
}
|
||||
if !unicode.IsSpace(rune(c)) {
|
||||
break
|
||||
}
|
||||
next++
|
||||
}
|
||||
|
||||
p2 := next
|
||||
elifOnExpr, bracePos, herr := parseHeaderToBrace(src, p2)
|
||||
if herr != nil {
|
||||
return nil, 0, herr
|
||||
}
|
||||
if elifOnExpr == "" {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected on-expr after 'elif'")
|
||||
}
|
||||
if bracePos >= length || src[bracePos] != '{' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after elif condition")
|
||||
}
|
||||
p2 = bracePos
|
||||
elifBodyStart := p2 + 1
|
||||
elifBodyEnd, ferr := findMatchingBrace(src, &p2, elifBodyStart)
|
||||
if ferr != nil {
|
||||
return nil, 0, ferr
|
||||
}
|
||||
elifOn, elifDo, berr := parseBranch(elifOnExpr, elifBodyStart, elifBodyEnd)
|
||||
if berr != nil {
|
||||
return nil, 0, berr
|
||||
}
|
||||
ifs = append(ifs, IfBlockCommand{On: elifOn, Do: elifDo})
|
||||
hasChain = true
|
||||
p = p2
|
||||
continue
|
||||
}
|
||||
|
||||
// else { ... }
|
||||
if strings.HasPrefix(src[q:], "else") {
|
||||
if hasElse {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("multiple 'else' branches")
|
||||
}
|
||||
next := q + len("else")
|
||||
for next < length {
|
||||
c := src[next]
|
||||
if c == '\n' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
|
||||
}
|
||||
if c == '\r' {
|
||||
next++
|
||||
continue
|
||||
}
|
||||
if !unicode.IsSpace(rune(c)) {
|
||||
break
|
||||
}
|
||||
next++
|
||||
}
|
||||
if next >= length || src[next] != '{' {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("expected '{' after 'else'")
|
||||
}
|
||||
|
||||
elseBodyStart := next + 1
|
||||
p2 := next
|
||||
elseBodyEnd, ferr := findMatchingBrace(src, &p2, elseBodyStart)
|
||||
if ferr != nil {
|
||||
return nil, 0, ferr
|
||||
}
|
||||
innerSrc := ""
|
||||
if elseBodyStart < elseBodyEnd {
|
||||
innerSrc = src[elseBodyStart:elseBodyEnd]
|
||||
}
|
||||
inner, ierr := parseDoWithBlocks(innerSrc)
|
||||
if ierr != nil {
|
||||
return nil, 0, ierr
|
||||
}
|
||||
if len(inner) == 0 {
|
||||
elseDo = nil
|
||||
} else {
|
||||
elseDo = inner
|
||||
}
|
||||
hasChain = true
|
||||
hasElse = true
|
||||
p = p2
|
||||
|
||||
// else must be the last branch on that line.
|
||||
for q2 := skipSameLineSpace(src, p); q2 < length && src[q2] != '\n'; q2 = skipSameLineSpace(src, q2) {
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after else block")
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return nil, 0, ErrInvalidBlockSyntax.Withf("unexpected token after nested block; expected 'elif'/'else' or newline")
|
||||
}
|
||||
|
||||
if hasChain {
|
||||
return IfElseBlockCommand{Ifs: ifs, Else: elseDo}, p, nil
|
||||
}
|
||||
return IfBlockCommand{On: ifs[0].On, Do: ifs[0].Do}, p, nil
|
||||
}
|
||||
|
||||
func lineEndsWithUnquotedOpenBrace(src string, lineStart int, lineEnd int) bool {
|
||||
quote := byte(0)
|
||||
lastSignificant := byte(0)
|
||||
atLineStart := true
|
||||
prevIsSpace := true
|
||||
|
||||
for i := lineStart; i < lineEnd; i++ {
|
||||
c := src[i]
|
||||
if quote != 0 {
|
||||
if c == '\\' && i+1 < lineEnd {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if c == quote {
|
||||
quote = 0
|
||||
}
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
continue
|
||||
}
|
||||
if quoteChars[c] {
|
||||
quote = c
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
continue
|
||||
}
|
||||
if c == '#' && (atLineStart || prevIsSpace) {
|
||||
break
|
||||
}
|
||||
if c == '/' && i+1 < lineEnd {
|
||||
n := rune(src[i+1])
|
||||
if (atLineStart || prevIsSpace) && (n == '/' || n == '*') {
|
||||
break
|
||||
}
|
||||
}
|
||||
if unicode.IsSpace(rune(c)) {
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
lastSignificant = c
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
}
|
||||
return quote == 0 && lastSignificant == '{'
|
||||
}
|
||||
|
||||
// parseDoWithBlocks parses a do-body containing plain command lines and nested blocks.
|
||||
// It returns the outer command handlers and the require phase.
|
||||
//
|
||||
// A nested block is recognized when a line ends with an unquoted '{' (ignoring trailing whitespace).
|
||||
func parseDoWithBlocks(src string) (handlers []CommandHandler, err error) {
|
||||
pos := 0
|
||||
length := len(src)
|
||||
lineStart := true
|
||||
handlers = make([]CommandHandler, 0, strings.Count(src, "\n")+1)
|
||||
|
||||
appendLineCommand := func(line string) error {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
directive, args, err := parse(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
builder, ok := commands[directive]
|
||||
if !ok {
|
||||
return ErrUnknownDirective.Subject(directive)
|
||||
}
|
||||
|
||||
phase, validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
return gperr.PrependSubject(err, directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
h := builder.build(validArgs)
|
||||
handlers = append(handlers, Handler{fn: h, phase: phase, terminate: builder.terminate})
|
||||
return nil
|
||||
}
|
||||
|
||||
for pos < length {
|
||||
// Handle newlines
|
||||
switch src[pos] {
|
||||
case '\n':
|
||||
pos++
|
||||
lineStart = true
|
||||
continue
|
||||
case '\r':
|
||||
// tolerate CRLF
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if lineStart {
|
||||
// Find first non-space on the line.
|
||||
linePos := pos
|
||||
for linePos < length {
|
||||
c := rune(src[linePos])
|
||||
if c == '\n' {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(c) {
|
||||
break
|
||||
}
|
||||
linePos++
|
||||
}
|
||||
|
||||
lineEnd := linePos
|
||||
for lineEnd < length && src[lineEnd] != '\n' {
|
||||
lineEnd++
|
||||
}
|
||||
|
||||
if linePos < length && lineEndsWithUnquotedOpenBrace(src, linePos, lineEnd) {
|
||||
h, next, err := parseAtBlockChain(src, linePos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handlers = append(handlers, h)
|
||||
pos = next
|
||||
lineStart = false
|
||||
continue
|
||||
}
|
||||
|
||||
// Not a nested block; parse the rest of this line as a command.
|
||||
if lerr := appendLineCommand(src[pos:lineEnd]); lerr != nil {
|
||||
return nil, lerr
|
||||
}
|
||||
pos = lineEnd
|
||||
lineStart = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Not at line start; advance to the next line boundary.
|
||||
for pos < length && src[pos] != '\n' {
|
||||
pos++
|
||||
}
|
||||
lineStart = true
|
||||
}
|
||||
|
||||
return handlers, nil
|
||||
}
|
||||
73
internal/route/rules/do_blocks_test.go
Normal file
73
internal/route/rules/do_blocks_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
func TestIfElseBlockCommandServeHTTP_UnconditionalNilDoNotFallsThrough(t *testing.T) {
|
||||
elseCalled := false
|
||||
cmd := IfElseBlockCommand{
|
||||
Ifs: []IfBlockCommand{
|
||||
{
|
||||
On: RuleOn{},
|
||||
Do: nil,
|
||||
},
|
||||
},
|
||||
Else: []CommandHandler{
|
||||
Handler{
|
||||
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
|
||||
elseCalled = true
|
||||
return nil
|
||||
},
|
||||
phase: PhaseNone,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := cmd.ServeHTTP(rm, req, nil)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, elseCalled)
|
||||
}
|
||||
|
||||
func TestIfElseBlockCommandServeHTTP_ConditionalMatchedNilDoNotFallsThrough(t *testing.T) {
|
||||
elseCalled := false
|
||||
cmd := IfElseBlockCommand{
|
||||
Ifs: []IfBlockCommand{
|
||||
{
|
||||
On: RuleOn{
|
||||
checker: CheckFunc(func(_ *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
return true
|
||||
}),
|
||||
},
|
||||
Do: nil,
|
||||
},
|
||||
},
|
||||
Else: []CommandHandler{
|
||||
Handler{
|
||||
fn: func(_ *httputils.ResponseModifier, _ *http.Request, _ http.HandlerFunc) error {
|
||||
elseCalled = true
|
||||
return nil
|
||||
},
|
||||
phase: PhaseNone,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
err := cmd.ServeHTTP(rm, req, nil)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, elseCalled)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -37,7 +39,7 @@ func parseRules(data string, target *Rules) error {
|
||||
}
|
||||
|
||||
func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(200, "success response", http.Header{
|
||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success response", http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
})
|
||||
|
||||
@@ -45,10 +47,9 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-request-response
|
||||
do: |
|
||||
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
||||
`, logFile), &rules)
|
||||
default {
|
||||
log info %q '$req_method $req_url $status_code $resp_header(Content-Type)'
|
||||
}`, logFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -59,7 +60,7 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "success response", w.Body.String())
|
||||
|
||||
// Read and verify log content
|
||||
@@ -70,16 +71,25 @@ func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
originalStdout := stdout
|
||||
originalStderr := stderr
|
||||
var stdoutBuf bytes.Buffer
|
||||
var stderrBuf bytes.Buffer
|
||||
stdout = noopWriteCloser{&stdoutBuf}
|
||||
stderr = noopWriteCloser{&stderrBuf}
|
||||
defer func() {
|
||||
stdout = originalStdout
|
||||
stderr = originalStderr
|
||||
}()
|
||||
|
||||
upstream := mockUpstream(http.StatusOK, "success")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: log-stdout
|
||||
do: |
|
||||
log info /dev/stdout "stdout: $req_method $status_code"
|
||||
- name: log-stderr
|
||||
do: |
|
||||
log error /dev/stderr "stderr: $req_path $status_code"
|
||||
default {
|
||||
log info /dev/stdout "stdout: $req_method $status_code"
|
||||
log error /dev/stderr "stderr: $req_path $status_code"
|
||||
}
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -90,9 +100,13 @@ func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
// Note: We can't easily capture stdout/stderr in unit tests,
|
||||
// but we can verify no errors occurred and the handler completed
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
require.Eventually(t, func() bool {
|
||||
return strings.Contains(stdoutBuf.String(), "stdout: GET 200") &&
|
||||
strings.Contains(stderrBuf.String(), "stderr: /test 200")
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
assert.Equal(t, 1, strings.Count(stdoutBuf.String(), "stdout: GET 200"))
|
||||
assert.Equal(t, 1, strings.Count(stderrBuf.String(), "stderr: /test 200"))
|
||||
}
|
||||
|
||||
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
@@ -104,26 +118,22 @@ func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-info
|
||||
do: |
|
||||
log info %s "INFO: $req_method $status_code"
|
||||
- name: log-warn
|
||||
do: |
|
||||
log warn %s "WARN: $req_path $status_code"
|
||||
- name: log-error
|
||||
do: |
|
||||
log error %s "ERROR: $req_method $req_path $status_code"
|
||||
default {
|
||||
log info %s "INFO: $req_method $status_code"
|
||||
log warn %s "WARN: $req_path $status_code"
|
||||
log error %s "ERROR: $req_method $req_path $status_code"
|
||||
}
|
||||
`, infoFile, warnFile, errorFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/resource/123", nil)
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/resource/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 404, w.Code)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
|
||||
// Verify each log file
|
||||
infoContent := TestFileContent(infoFile)
|
||||
@@ -148,22 +158,22 @@ func TestLogCommand_TemplateVariables(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-with-templates
|
||||
do: |
|
||||
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
||||
default {
|
||||
log info %s 'Request: $req_method $req_url Host: $req_host User-Agent: $header(User-Agent) Response: $status_code Custom-Header: $resp_header(X-Custom-Header) Content-Length: $resp_header(Content-Length)'
|
||||
}
|
||||
`, tempFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/api/resource", nil)
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/resource", nil)
|
||||
req.Header.Set("User-Agent", "test-client/1.0")
|
||||
req.Host = "example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 201, w.Code)
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
// Verify log content
|
||||
content := TestFileContent(tempFile)
|
||||
@@ -192,14 +202,12 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-success
|
||||
on: status 2xx
|
||||
do: |
|
||||
log info %q "SUCCESS: $req_method $req_path $status_code"
|
||||
- name: log-error
|
||||
on: status 4xx | status 5xx
|
||||
do: |
|
||||
log error %q "ERROR: $req_method $req_path $status_code"
|
||||
status 2xx {
|
||||
log info %q "SUCCESS: $req_method $req_path $status_code"
|
||||
}
|
||||
status 4xx | status 5xx {
|
||||
log error %q "ERROR: $req_method $req_path $status_code"
|
||||
}
|
||||
`, successFile, errorFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -244,9 +252,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-multiple
|
||||
do: |
|
||||
log info %q "$req_method $req_path $status_code"`, tempFile), &rules)
|
||||
default {
|
||||
log info %q "$req_method $req_path $status_code"
|
||||
}
|
||||
`, tempFile), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -256,10 +265,10 @@ func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"GET", "/users"},
|
||||
{"POST", "/users"},
|
||||
{"PUT", "/users/1"},
|
||||
{"DELETE", "/users/1"},
|
||||
{http.MethodGet, "/users"},
|
||||
{http.MethodPost, "/users"},
|
||||
{http.MethodPost, "/users/1"},
|
||||
{http.MethodDelete, "/users/1"},
|
||||
}
|
||||
|
||||
for _, reqInfo := range requests {
|
||||
@@ -287,8 +296,9 @@ func TestLogCommand_InvalidTemplate(t *testing.T) {
|
||||
|
||||
// Test with invalid template syntax
|
||||
err := parseRules(`
|
||||
- name: log-invalid
|
||||
do: |
|
||||
log info /dev/stdout "$invalid_var"`, &rules)
|
||||
assert.ErrorIs(t, err, ErrUnexpectedVar)
|
||||
default {
|
||||
log info /dev/stdout "$invalid_var"
|
||||
}
|
||||
`, &rules)
|
||||
require.ErrorIs(t, err, ErrUnexpectedVar)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
type (
|
||||
FieldHandler struct {
|
||||
set, add, remove CommandHandler
|
||||
set, add, remove HandlerFunc
|
||||
}
|
||||
FieldModifier string
|
||||
)
|
||||
@@ -49,30 +49,30 @@ var modFields = map[string]struct {
|
||||
"value": "the header template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePreRequestKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Header[k] = []string{v}
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Header[k] = append(r.Header[k], v)
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
delete(r.Header, k)
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -84,30 +84,30 @@ var modFields = map[string]struct {
|
||||
"value": "the response header template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePostResponseKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header()[k] = []string{v}
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header()[k] = append(w.Header()[k], v)
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
delete(w.Header(), k)
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -119,36 +119,36 @@ var modFields = map[string]struct {
|
||||
"value": "the query template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePreRequestKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Set(k, v)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Add(k, v)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
httputils.GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.SharedData().UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Del(k)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -160,16 +160,16 @@ var modFields = map[string]struct {
|
||||
"value": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
validate: validatePreRequestKVTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
cookies[i].Value = v
|
||||
@@ -179,19 +179,19 @@ var modFields = map[string]struct {
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := tmpl.ExpandVarsToString(w, r)
|
||||
},
|
||||
add: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
v, _, err := tmpl.ExpandVarsToString(w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
httputils.GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
},
|
||||
remove: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.SharedData().UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
index := -1
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
@@ -208,7 +208,7 @@ var modFields = map[string]struct {
|
||||
return cookies
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -227,24 +227,27 @@ var modFields = map[string]struct {
|
||||
"template": "the body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return 0, nil, ErrExpectOneArg
|
||||
}
|
||||
return validateTemplate(args[0], true)
|
||||
phase = PhasePre
|
||||
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
|
||||
phase |= tmplReq
|
||||
return
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateString)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
if r.Body != nil {
|
||||
r.Body.Close()
|
||||
r.Body = nil
|
||||
}
|
||||
|
||||
bufPool := httputils.GetInitResponseModifier(w).BufPool()
|
||||
bufPool := w.BufPool()
|
||||
b := bufPool.GetBuffer()
|
||||
err := tmpl.ExpandVars(w, r, b)
|
||||
_, err := tmpl.ExpandVars(w, r, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -252,7 +255,7 @@ var modFields = map[string]struct {
|
||||
bufPool.PutBuffer(b)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -272,20 +275,26 @@ var modFields = map[string]struct {
|
||||
"template": "the response body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return 0, nil, ErrExpectOneArg
|
||||
}
|
||||
return validateTemplate(args[0], true)
|
||||
phase = PhasePost
|
||||
tmplReq, parsedArgs, err := validateTemplate(args[0], true)
|
||||
phase |= tmplReq
|
||||
return
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateString)
|
||||
return &FieldHandler{
|
||||
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rm := httputils.GetInitResponseModifier(w)
|
||||
rm.ResetBody()
|
||||
return tmpl.ExpandVars(w, r, rm)
|
||||
}),
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.ResetBody()
|
||||
_, err := tmpl.ExpandVars(w, r, w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -300,26 +309,27 @@ var modFields = map[string]struct {
|
||||
"code": "the status code",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
phase = PhasePost
|
||||
status, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
return phase, nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if status < 100 || status > 599 {
|
||||
return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
||||
return phase, nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
||||
}
|
||||
return status, nil
|
||||
return phase, status, nil
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
status := args.(int)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
httputils.GetInitResponseModifier(w).WriteHeader(status)
|
||||
set: func(w *httputils.ResponseModifier, r *http.Request, upstream http.HandlerFunc) error {
|
||||
w.WriteHeader(status)
|
||||
return nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -72,12 +71,12 @@ func TestFieldHandler_Header(t *testing.T) {
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -87,7 +86,7 @@ func TestFieldHandler_Header(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -150,12 +149,12 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
tt.setup(w)
|
||||
}
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -165,7 +164,7 @@ func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -237,12 +236,12 @@ func TestFieldHandler_Query(t *testing.T) {
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -252,7 +251,7 @@ func TestFieldHandler_Query(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -335,12 +334,12 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
var cmd HandlerFunc
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
@@ -350,7 +349,7 @@ func TestFieldHandler_Cookie(t *testing.T) {
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
err := cmd(httputils.NewResponseModifier(w), req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -371,7 +370,7 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
name: "set body with template",
|
||||
template: "Hello $req_method $req_path",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "POST"
|
||||
r.Method = http.MethodPost
|
||||
r.URL.Path = "/test"
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
@@ -399,15 +398,15 @@ func TestFieldHandler_Body(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.template, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.template, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to parse template: %v", tErr)
|
||||
}
|
||||
|
||||
handler := modFields[FieldBody].builder(tmpl)
|
||||
err := handler.set.Handle(w, req)
|
||||
err := handler.set(w, req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
@@ -428,7 +427,7 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
name: "set response body with template",
|
||||
template: "Response: $req_method $req_path",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "GET"
|
||||
r.Method = http.MethodGet
|
||||
r.URL.Path = "/api/test"
|
||||
},
|
||||
verify: func(rm *httputils.ResponseModifier) {
|
||||
@@ -443,23 +442,20 @@ func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
// Create ResponseModifier wrapper
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.template, false)
|
||||
_, tmpl, tErr := validateTemplate(tt.template, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to parse template: %v", tErr)
|
||||
}
|
||||
|
||||
handler := modFields[FieldResponseBody].builder(tmpl)
|
||||
err := handler.set.Handle(rm, req)
|
||||
err := handler.set(w, req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(rm)
|
||||
tt.verify(w)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -472,23 +468,23 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "set status code 200",
|
||||
status: 200,
|
||||
status: http.StatusOK,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 200, w.Code, "Expected status code 200")
|
||||
assert.Equal(t, http.StatusOK, w.Code, "Expected status code 200")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set status code 404",
|
||||
status: 404,
|
||||
status: http.StatusNotFound,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 404, w.Code, "Expected status code 404")
|
||||
assert.Equal(t, http.StatusNotFound, w.Code, "Expected status code 404")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set status code 500",
|
||||
status: 500,
|
||||
status: http.StatusInternalServerError,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 500, w.Code, "Expected status code 500")
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code, "Expected status code 500")
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -503,12 +499,11 @@ func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
err = cmd.ServeHTTP(rm, req)
|
||||
err = cmd.post.ServeHTTP(rm, req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
rm.FlushRelease()
|
||||
|
||||
tt.verify(w)
|
||||
})
|
||||
}
|
||||
@@ -600,7 +595,7 @@ func TestFieldValidation(t *testing.T) {
|
||||
field, exists := modFields[tt.field]
|
||||
assert.True(t, exists, "Field %s does not exist", tt.field)
|
||||
|
||||
_, err := field.validate(tt.args)
|
||||
_, _, err := field.validate(tt.args)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err, "Expected error but got none")
|
||||
} else {
|
||||
@@ -610,25 +605,6 @@ func TestFieldValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllFields(t *testing.T) {
|
||||
expectedFields := []string{
|
||||
FieldHeader,
|
||||
FieldResponseHeader,
|
||||
FieldQuery,
|
||||
FieldCookie,
|
||||
FieldBody,
|
||||
FieldResponseBody,
|
||||
FieldStatusCode,
|
||||
}
|
||||
|
||||
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
|
||||
|
||||
for _, expected := range expectedFields {
|
||||
found := slices.Contains(AllFields, expected)
|
||||
assert.True(t, found, "Expected field %s not found in AllFields", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModFields(t *testing.T) {
|
||||
for fieldName, field := range modFields {
|
||||
// Test that each field has required components
|
||||
|
||||
@@ -14,8 +14,9 @@ var (
|
||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
||||
ErrMultipleDefaultRules = gperr.New("multiple default rules")
|
||||
|
||||
ErrMultipleDefaultRules = gperr.New("multiple default rules")
|
||||
ErrDeadRule = gperr.New("dead rule")
|
||||
|
||||
// vars errors
|
||||
ErrNoArgProvided = gperr.New("no argument provided")
|
||||
@@ -31,5 +32,5 @@ var (
|
||||
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
|
||||
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
|
||||
|
||||
errTerminated = gperr.New("terminated")
|
||||
ErrInvalidBlockSyntax = gperr.New("invalid block syntax") // TODO: struct this error
|
||||
)
|
||||
|
||||
@@ -131,12 +131,13 @@ Error generates help string as error, e.g.
|
||||
from: the path to rewrite, must start with /
|
||||
to: the path to rewrite to, must start with /
|
||||
*/
|
||||
func (h *Help) Error() error {
|
||||
var lines gperr.MultilineError
|
||||
func (h *Help) Error() gperr.Error {
|
||||
help := gperr.New(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
for _, line := range h.description {
|
||||
help = help.Withf("%s", line)
|
||||
}
|
||||
|
||||
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
lines.AddStrings(h.description...)
|
||||
lines.Adds(" args:")
|
||||
args := gperr.New("args")
|
||||
|
||||
argKeys := make([]string, 0, len(h.args))
|
||||
longestArg := 0
|
||||
@@ -151,7 +152,9 @@ func (h *Help) Error() error {
|
||||
slices.Sort(argKeys)
|
||||
for _, arg := range argKeys {
|
||||
desc := h.args[arg]
|
||||
lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc)
|
||||
paddedArg := fmt.Sprintf("%-"+strconv.Itoa(longestArg)+"s", arg)
|
||||
args = args.Withf("%s%s", ansi.WithANSI(paddedArg, ansi.HighlightCyan)+": ", desc)
|
||||
}
|
||||
return &lines
|
||||
|
||||
return help.With(args)
|
||||
}
|
||||
|
||||
1356
internal/route/rules/http_flow_block_test.go
Normal file
1356
internal/route/rules/http_flow_block_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -23,8 +23,9 @@ import (
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
func mockUpstream(body string) http.HandlerFunc {
|
||||
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
@@ -47,7 +48,7 @@ func parseRules(data string, target *Rules) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
func TestHTTPFlow_BasicPreRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -74,8 +75,8 @@ func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
upstream := mockUpstream("upstream response")
|
||||
func TestHTTPFlow_BypassRuleYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -99,8 +100,8 @@ func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
upstream := mockUpstream("should not be called")
|
||||
func TestHTTPFlow_TerminatingCommandYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -120,13 +121,13 @@ func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, "Forbidden\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
upstream := mockUpstream("should not be called")
|
||||
func TestHTTPFlow_RedirectFlowYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -143,11 +144,11 @@ func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, 307, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
func TestHTTPFlow_RewriteFlowYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("path: " + r.URL.Path))
|
||||
@@ -172,7 +173,7 @@ func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
assert.Equal(t, "path: /v1/users", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
func TestHTTPFlow_MultiplePreRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
||||
@@ -201,7 +202,7 @@ func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
func TestHTTPFlow_PostResponseRuleYAML(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(http.StatusOK, "success", http.Header{
|
||||
"X-Upstream": []string{"upstream-value"},
|
||||
})
|
||||
@@ -229,11 +230,10 @@ func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
|
||||
// Check log file
|
||||
content := TestFileContent(tempFile)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "GET 200\n", string(content))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseRuleWithStatusConditionYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/success" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -246,14 +246,15 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
|
||||
var rules Rules
|
||||
|
||||
// Create a temporary file for logging
|
||||
tempFile := TestRandomFileName()
|
||||
errorLog := TestRandomFileName()
|
||||
infoLog := TestRandomFileName()
|
||||
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: log-errors
|
||||
on: status 4xx
|
||||
- on: status 4xx
|
||||
do: log error %s "$req_url returned $status_code"
|
||||
`, tempFile), &rules)
|
||||
- on: status 200
|
||||
do: log info %s "$req_url returned $status_code"
|
||||
`, errorLog, infoLog), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
@@ -273,14 +274,18 @@ func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
|
||||
// Check log file
|
||||
content := TestFileContent(tempFile)
|
||||
require.NoError(t, err)
|
||||
content := TestFileContent(errorLog)
|
||||
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
|
||||
require.Len(t, lines, 1, "only 4xx requests should be logged")
|
||||
assert.Equal(t, "/notfound returned 404", lines[0])
|
||||
|
||||
infoContent := TestFileContent(infoLog)
|
||||
lines = strings.Split(strings.TrimSpace(string(infoContent)), "\n")
|
||||
require.Len(t, lines, 1, "only 200 requests should be logged")
|
||||
assert.Equal(t, "/success returned 200", lines[0])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
func TestHTTPFlow_ConditionalRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
||||
@@ -320,22 +325,21 @@ func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
func TestHTTPFlow_ComplexFlowWithPreAndPostRulesYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate different responses based on path
|
||||
if r.URL.Path == "/protected" {
|
||||
if r.Header.Get("X-Auth") != "valid" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte("unauthorized"))
|
||||
fmt.Fprint(w, "unauthorized")
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("X-Response-Time", "100ms")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
fmt.Fprint(w, "success")
|
||||
})
|
||||
|
||||
// Create temporary files for logging
|
||||
logFile := TestRandomFileName()
|
||||
errorLogFile := TestRandomFileName()
|
||||
|
||||
@@ -402,8 +406,8 @@ func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
upstream := mockUpstream("upstream response")
|
||||
func TestHTTPFlow_DefaultRuleYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
@@ -426,21 +430,57 @@ func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule + default rule
|
||||
// Test special rule (default should not run)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
func TestHTTPFlow_DefaultRuleWithOnDefaultYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: default-on-rule
|
||||
on: default
|
||||
do: set resp_header X-Default-Applied true
|
||||
- name: special-rule
|
||||
on: path /special
|
||||
do: set resp_header X-Special-Handled true
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test default rule on regular request
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/regular", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule on matching request (default should not run)
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
assert.Empty(t, w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_HeaderManipulationYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Echo back a header
|
||||
headerValue := r.Header.Get("X-Test-Header")
|
||||
w.Header().Set("X-Echoed-Header", headerValue)
|
||||
w.Header().Set("X-Secret", "sensitive-data")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("header echoed"))
|
||||
})
|
||||
@@ -460,7 +500,6 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Secret", "secret-value")
|
||||
req.Header.Set("X-Test-Header", "original-value")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@@ -469,11 +508,10 @@ func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
||||
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
||||
// Ensure the secret header was removed and not passed to upstream
|
||||
// (we can't directly test this, but the upstream shouldn't see it)
|
||||
assert.Empty(t, w.Header().Get("X-Secret"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
func TestHTTPFlow_QueryParameterHandlingYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -500,13 +538,15 @@ func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
assert.Equal(t, "query: added-value", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
func TestHTTPFlow_ServeCommandYAML(t *testing.T) {
|
||||
// Create a temporary directory with test files
|
||||
tempDir := t.TempDir()
|
||||
tempDir, err := os.MkdirTemp("", "test-serve-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create test files directly in the temp directory
|
||||
testFile := filepath.Join(tempDir, "index.html")
|
||||
err := os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var rules Rules
|
||||
@@ -517,7 +557,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
`, tempDir), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
|
||||
|
||||
// Test serving a file - serve command serves files relative to the root directory
|
||||
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
||||
@@ -546,7 +586,7 @@ func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
assert.Equal(t, http.StatusNotFound, w2.Code)
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
func TestHTTPFlow_ProxyCommandYAML(t *testing.T) {
|
||||
// Create a mock upstream server
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Header", "upstream-value")
|
||||
@@ -563,7 +603,7 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
`, upstreamServer.URL), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream("should not be called"))
|
||||
handler := rules.BuildHandler(mockUpstream(http.StatusOK, "should not be called"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@@ -576,11 +616,28 @@ func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_NotifyCommand(t *testing.T) {
|
||||
// TODO:
|
||||
func TestHTTPFlow_NotifyCommandYAML(t *testing.T) {
|
||||
upstream := mockUpstream(http.StatusOK, "ok")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: notify-rule
|
||||
on: path /notify
|
||||
do: notify info test-provider "title $req_method" "body $req_url $status_code"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/notify", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "ok", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
func TestHTTPFlow_FormConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("form processed"))
|
||||
@@ -620,7 +677,7 @@ func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
func TestHTTPFlow_RemoteConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("remote processed"))
|
||||
@@ -654,11 +711,11 @@ func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w2.Code)
|
||||
assert.Equal(t, 403, w2.Code)
|
||||
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
func TestHTTPFlow_BasicAuthConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("auth processed"))
|
||||
@@ -702,7 +759,7 @@ func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
func TestHTTPFlow_RouteConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("route processed"))
|
||||
@@ -742,10 +799,10 @@ func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseStatusConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
w.Write([]byte("method not allowed"))
|
||||
fmt.Fprint(w, "method not allowed")
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
@@ -767,11 +824,11 @@ func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseHeaderConditionsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Response-Header", "response header")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("processed"))
|
||||
fmt.Fprint(w, "processed")
|
||||
})
|
||||
|
||||
t.Run("any_value", func(t *testing.T) {
|
||||
@@ -831,7 +888,65 @@ func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
func TestHTTPFlow_PreTermination_SkipsLaterPreCommands_ButRunsPostOnlyAndPostMatchersYAML(t *testing.T) {
|
||||
upstreamCalled := false
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "upstream")
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: path /
|
||||
do: error 403 blocked
|
||||
- on: path /
|
||||
do: set resp_header X-Late should-run
|
||||
- on: status 4xx
|
||||
do: set resp_header X-Post true
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.False(t, upstreamCalled)
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, "blocked\n", w.Body.String())
|
||||
assert.Equal(t, "should-run", w.Header().Get("X-Late"))
|
||||
assert.Equal(t, "true", w.Header().Get("X-Post"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostRuleTermination_StopsRemainingCommandsInRuleYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: status 200
|
||||
do: |
|
||||
error 500 failed
|
||||
set resp_header X-After should-not-run
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
assert.Equal(t, "failed\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-After"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinationsYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("complex processed"))
|
||||
@@ -887,12 +1002,12 @@ func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
assert.Equal(t, 200, w3.Code)
|
||||
assert.Equal(t, http.StatusOK, w3.Code)
|
||||
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
|
||||
assert.Empty(t, w3.Header()["X-API-Version"])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
func TestHTTPFlow_ResponseModifierYAML(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("original response"))
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type noopWriteCloser struct {
|
||||
@@ -30,7 +31,7 @@ var (
|
||||
testFilesLock sync.Mutex
|
||||
)
|
||||
|
||||
func openFile(path string) (io.WriteCloser, error) {
|
||||
func openFile(path string) (io.WriteCloser, gperr.Error) {
|
||||
switch path {
|
||||
case "/dev/stdout":
|
||||
return stdout, nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,8 @@ type (
|
||||
MatcherType string
|
||||
)
|
||||
|
||||
var matcherCache = xsync.NewMap[string, Matcher]() // map[string]Matcher
|
||||
|
||||
const (
|
||||
MatcherTypeString MatcherType = "string"
|
||||
MatcherTypeGlob MatcherType = "glob"
|
||||
@@ -59,7 +62,12 @@ func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Erro
|
||||
}
|
||||
|
||||
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
if cached, ok := matcherCache.Load(expr); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
negate := false
|
||||
origExpr := expr
|
||||
if strings.HasPrefix(expr, "!") {
|
||||
negate = true
|
||||
expr = expr[1:]
|
||||
@@ -72,11 +80,23 @@ func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
|
||||
switch t {
|
||||
case MatcherTypeString:
|
||||
return StringMatcher(expr, negate)
|
||||
m, err := StringMatcher(expr, negate)
|
||||
if err == nil {
|
||||
matcherCache.Store(origExpr, m)
|
||||
}
|
||||
return m, err
|
||||
case MatcherTypeGlob:
|
||||
return GlobMatcher(expr, negate)
|
||||
m, err := GlobMatcher(expr, negate)
|
||||
if err == nil {
|
||||
matcherCache.Store(origExpr, m)
|
||||
}
|
||||
return m, err
|
||||
case MatcherTypeRegex:
|
||||
return RegexMatcher(expr, negate)
|
||||
m, err := RegexMatcher(expr, negate)
|
||||
if err == nil {
|
||||
matcherCache.Store(origExpr, m)
|
||||
}
|
||||
return m, err
|
||||
}
|
||||
// won't reach here
|
||||
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
||||
|
||||
@@ -12,19 +12,19 @@ import (
|
||||
)
|
||||
|
||||
type RuleOn struct {
|
||||
raw string
|
||||
checker Checker
|
||||
isResponseChecker bool
|
||||
}
|
||||
|
||||
func (on *RuleOn) IsResponseChecker() bool {
|
||||
return on.isResponseChecker
|
||||
raw string
|
||||
checker Checker
|
||||
phase PhaseFlag
|
||||
}
|
||||
|
||||
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
return on.checker.Check(w, r)
|
||||
if on.checker == nil {
|
||||
return true
|
||||
}
|
||||
return on.checker.Check(httputils.GetInitResponseModifier(w), r)
|
||||
}
|
||||
|
||||
// on request
|
||||
const (
|
||||
OnDefault = "default"
|
||||
OnHeader = "header"
|
||||
@@ -39,35 +39,36 @@ const (
|
||||
OnRemote = "remote"
|
||||
OnBasicAuth = "basic_auth"
|
||||
OnRoute = "route"
|
||||
)
|
||||
|
||||
// on response
|
||||
|
||||
// on response
|
||||
const (
|
||||
OnResponseHeader = "resp_header"
|
||||
OnStatus = "status"
|
||||
)
|
||||
|
||||
var checkers = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
isResponseChecker bool
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
}{
|
||||
OnDefault: {
|
||||
help: Help{
|
||||
command: OnDefault,
|
||||
description: makeLines(
|
||||
"The default rule is matched when no other rules are matched.",
|
||||
"Select the default (fallback) rule.",
|
||||
),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
return phase, nil, ErrExpectNoArg
|
||||
}
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
return phase, nil, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool { return true }
|
||||
},
|
||||
builder: func(args any) CheckFunc { return func(w http.ResponseWriter, r *http.Request) bool { return false } }, // this should never be called
|
||||
},
|
||||
OnHeader: {
|
||||
help: Help{
|
||||
@@ -83,21 +84,23 @@ var checkers = map[string]struct {
|
||||
"[value]": "the header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(r.Header[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(r.Header[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnResponseHeader: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnResponseHeader,
|
||||
description: makeLines(
|
||||
@@ -111,16 +114,20 @@ var checkers = map[string]struct {
|
||||
"[value]": "the response header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePost
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(httputils.GetInitResponseModifier(w).Header()[k]) > 0
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(w.Header()[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(httputils.GetInitResponseModifier(w).Header()[k], matcher)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(w.Header()[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -138,16 +145,19 @@ var checkers = map[string]struct {
|
||||
"[value]": "the query value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(httputils.GetSharedData(w).GetQueries(r)[k]) > 0
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return len(w.SharedData().GetQueries(r)[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(httputils.GetSharedData(w).GetQueries(r)[k], matcher)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return slices.ContainsFunc(w.SharedData().GetQueries(r)[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -165,12 +175,15 @@ var checkers = map[string]struct {
|
||||
"[value]": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
cookies := w.SharedData().GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
return true
|
||||
@@ -179,8 +192,8 @@ var checkers = map[string]struct {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := httputils.GetSharedData(w).GetCookies(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
cookies := w.SharedData().GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
if matcher(cookie.Value) {
|
||||
@@ -192,6 +205,7 @@ var checkers = map[string]struct {
|
||||
}
|
||||
},
|
||||
},
|
||||
//nolint:dupl
|
||||
OnForm: {
|
||||
help: Help{
|
||||
command: OnForm,
|
||||
@@ -206,15 +220,18 @@ var checkers = map[string]struct {
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.FormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.FormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -233,15 +250,18 @@ var checkers = map[string]struct {
|
||||
"[value]": "the form value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = toKVOptionalVMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.PostFormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.PostFormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -250,32 +270,46 @@ var checkers = map[string]struct {
|
||||
help: Help{
|
||||
command: OnProto,
|
||||
args: map[string]string{
|
||||
"proto": "the http protocol (http, https, h3)",
|
||||
"proto": "the http protocol (http, https, h1, h2, h2c, h3)",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, error) {
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
return phase, nil, ErrExpectOneArg
|
||||
}
|
||||
proto := args[0]
|
||||
if proto != "http" && proto != "https" && proto != "h3" {
|
||||
return nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||
switch proto {
|
||||
case "http", "https", "h1", "h2", "h2c", "h3":
|
||||
return phase, proto, nil
|
||||
default:
|
||||
return phase, nil, ErrInvalidArguments.Withf("proto: %q", proto)
|
||||
}
|
||||
return proto, nil
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
proto := args.(string)
|
||||
switch proto {
|
||||
case "http":
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil
|
||||
}
|
||||
case "https":
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil
|
||||
}
|
||||
case "h1":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil && r.ProtoMajor == 1
|
||||
}
|
||||
case "h2":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil && r.ProtoMajor == 2
|
||||
}
|
||||
case "h2c":
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS == nil && r.ProtoMajor == 2
|
||||
}
|
||||
default: // h3
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.TLS != nil && r.ProtoMajor == 3
|
||||
}
|
||||
}
|
||||
@@ -288,10 +322,13 @@ var checkers = map[string]struct {
|
||||
"method": "the http method",
|
||||
},
|
||||
},
|
||||
validate: validateMethod,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateMethod(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
method := args.(string)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return r.Method == method
|
||||
}
|
||||
},
|
||||
@@ -310,10 +347,13 @@ var checkers = map[string]struct {
|
||||
"host": "the host name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateSingleMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(r.Host)
|
||||
}
|
||||
},
|
||||
@@ -332,10 +372,13 @@ var checkers = map[string]struct {
|
||||
"path": "the request path",
|
||||
},
|
||||
},
|
||||
validate: validateURLPathMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateURLPathMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
reqPath := r.URL.Path
|
||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||
reqPath = "/" + reqPath
|
||||
@@ -351,22 +394,25 @@ var checkers = map[string]struct {
|
||||
"ip|cidr": "the remote ip or cidr",
|
||||
},
|
||||
},
|
||||
validate: validateCIDR,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateCIDR(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
ipnet := args.(*net.IPNet)
|
||||
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||
wantIP := ipnet.IP
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
ip := w.SharedData().GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(wantIP)
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := httputils.GetSharedData(w).GetRemoteIP(r)
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
ip := w.SharedData().GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
@@ -382,11 +428,14 @@ var checkers = map[string]struct {
|
||||
"password": "the password encrypted with bcrypt",
|
||||
},
|
||||
},
|
||||
validate: validateUserBCryptPassword,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateUserBCryptPassword(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
cred := args.(*HashedCrendentials)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return cred.Match(httputils.GetSharedData(w).GetBasicAuth(r))
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return cred.Match(w.SharedData().GetBasicAuth(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
@@ -403,16 +452,18 @@ var checkers = map[string]struct {
|
||||
"route": "the route name",
|
||||
},
|
||||
},
|
||||
validate: validateSingleMatcher,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
parsedArgs, err = validateSingleMatcher(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(_ http.ResponseWriter, r *http.Request) bool {
|
||||
return func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return matcher(routes.TryGetUpstreamName(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
OnStatus: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnStatus,
|
||||
description: makeLines(
|
||||
@@ -429,16 +480,20 @@ var checkers = map[string]struct {
|
||||
"status": "the status code range",
|
||||
},
|
||||
},
|
||||
validate: validateStatusRange,
|
||||
validate: func(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
phase = PhasePost
|
||||
parsedArgs, err = validateStatusRange(args)
|
||||
return
|
||||
},
|
||||
builder: func(args any) CheckFunc {
|
||||
beg, end := args.(*IntTuple).Unpack()
|
||||
if beg == end {
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
return httputils.GetInitResponseModifier(w).StatusCode() == beg
|
||||
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
return w.StatusCode() == beg
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
statusCode := httputils.GetInitResponseModifier(w).StatusCode()
|
||||
return func(w *httputils.ResponseModifier, _ *http.Request) bool {
|
||||
statusCode := w.StatusCode()
|
||||
return statusCode >= beg && statusCode <= end
|
||||
}
|
||||
},
|
||||
@@ -515,85 +570,90 @@ func splitPipe(s string) []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
brackets := 0
|
||||
result := []string{}
|
||||
forEachPipePart(s, func(part string) {
|
||||
result = append(result, part)
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
func forEachAndPart(s string, fn func(part string)) {
|
||||
start := 0
|
||||
for i := 0; i <= len(s); i++ {
|
||||
if i < len(s) && andSeps[s[i]] == 0 {
|
||||
continue
|
||||
}
|
||||
part := strings.TrimSpace(s[start:i])
|
||||
if part != "" {
|
||||
fn(part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
switch r {
|
||||
func forEachPipePart(s string, fn func(part string)) {
|
||||
quote := byte(0)
|
||||
brackets := 0
|
||||
start := 0
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '\\':
|
||||
escaped = true
|
||||
current.WriteRune(r)
|
||||
if i+1 < len(s) {
|
||||
i++
|
||||
}
|
||||
case '"', '\'', '`':
|
||||
if quote == 0 && brackets == 0 {
|
||||
quote = r
|
||||
} else if r == quote {
|
||||
quote = s[i]
|
||||
} else if s[i] == quote {
|
||||
quote = 0
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '(':
|
||||
brackets++
|
||||
current.WriteRune(r)
|
||||
case ')':
|
||||
if brackets > 0 {
|
||||
brackets--
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '|':
|
||||
if quote == 0 && brackets == 0 {
|
||||
// Found a pipe outside quotes/brackets, split here
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
if part := strings.TrimSpace(s[start:i]); part != "" {
|
||||
fn(part)
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last part
|
||||
if current.Len() > 0 {
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
if start < len(s) {
|
||||
if part := strings.TrimSpace(s[start:]); part != "" {
|
||||
fn(part)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (on *RuleOn) Parse(v string) error {
|
||||
on.raw = v
|
||||
|
||||
rules := splitAnd(v)
|
||||
checkAnd := make(CheckMatchAll, 0, len(rules))
|
||||
ruleCount := 0
|
||||
forEachAndPart(v, func(_ string) {
|
||||
ruleCount++
|
||||
})
|
||||
checkAnd := make(CheckMatchAll, 0, ruleCount)
|
||||
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
isResponseChecker := false
|
||||
for i, rule := range rules {
|
||||
if rule == "" {
|
||||
continue
|
||||
}
|
||||
parsed, isResp, err := parseOn(rule)
|
||||
i := 0
|
||||
forEachAndPart(v, func(rule string) {
|
||||
i++
|
||||
parsed, phase, err := parseOn(rule)
|
||||
if err != nil {
|
||||
errs.AddSubjectf(err, "line %d", i+1)
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
errs.AddSubjectf(err, "line %d", i)
|
||||
return
|
||||
}
|
||||
on.phase |= phase
|
||||
checkAnd = append(checkAnd, parsed)
|
||||
}
|
||||
})
|
||||
|
||||
on.checker = checkAnd
|
||||
on.isResponseChecker = isResponseChecker
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
@@ -605,33 +665,40 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||
return []byte(on.String()), nil
|
||||
}
|
||||
|
||||
func parseOn(line string) (Checker, bool, error) {
|
||||
ors := splitPipe(line)
|
||||
|
||||
if len(ors) > 1 {
|
||||
func parseOn(line string) (Checker, PhaseFlag, error) {
|
||||
orCount := 0
|
||||
forEachPipePart(line, func(_ string) {
|
||||
orCount++
|
||||
})
|
||||
if orCount > 1 {
|
||||
var phase PhaseFlag
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
checkOr := make(CheckMatchSingle, len(ors))
|
||||
isResponseChecker := false
|
||||
for i, or := range ors {
|
||||
curCheckers, isResp, err := parseOn(or)
|
||||
checkOr := make(CheckMatchSingle, orCount)
|
||||
i := 0
|
||||
forEachPipePart(line, func(or string) {
|
||||
i++
|
||||
checkFunc, req, err := parseOnAtom(or)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
errs.AddSubjectf(err, "or[%d]", i)
|
||||
return
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
}
|
||||
checkOr[i] = curCheckers.(CheckFunc)
|
||||
}
|
||||
checkOr[i-1] = checkFunc
|
||||
phase |= req
|
||||
})
|
||||
if err := errs.Error(); err != nil {
|
||||
return nil, false, err
|
||||
return nil, phase, err
|
||||
}
|
||||
return checkOr, isResponseChecker, nil
|
||||
return checkOr, phase, nil
|
||||
}
|
||||
|
||||
return parseOnAtom(line)
|
||||
}
|
||||
|
||||
func parseOnAtom(line string) (CheckFunc, PhaseFlag, error) {
|
||||
var phase PhaseFlag
|
||||
subject, args, err := parse(line)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, phase, err
|
||||
}
|
||||
|
||||
negate := false
|
||||
@@ -642,20 +709,21 @@ func parseOn(line string) (Checker, bool, error) {
|
||||
|
||||
checker, ok := checkers[subject]
|
||||
if !ok {
|
||||
return nil, false, ErrInvalidOnTarget.Subject(subject)
|
||||
return nil, phase, ErrInvalidOnTarget.Subject(subject)
|
||||
}
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
req, validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, false, gperr.Wrap(err).With(checker.help.Error())
|
||||
return nil, phase, gperr.Wrap(err).With(checker.help.Error())
|
||||
}
|
||||
phase |= req
|
||||
|
||||
checkFunc := checker.builder(validArgs)
|
||||
if negate {
|
||||
origCheckFunc := checkFunc
|
||||
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
|
||||
checkFunc = func(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
return !origCheckFunc(w, r)
|
||||
}
|
||||
}
|
||||
return checkFunc, checker.isResponseChecker, nil
|
||||
return checkFunc, phase, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func TestSplitPipe(t *testing.T) {
|
||||
{
|
||||
name: "empty_segments",
|
||||
input: "rule1 || rule2 | | rule3",
|
||||
want: []string{"rule1", "", "rule2", "", "rule3"},
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
. "github.com/yusing/godoxy/internal/route/rules"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
@@ -386,7 +387,7 @@ func TestOnCorrectness(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
w := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
var on RuleOn
|
||||
err := on.Parse(tt.checker)
|
||||
expect.NoError(t, err)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/yusing/goutils/env"
|
||||
@@ -25,6 +24,76 @@ var quoteChars = [256]bool{
|
||||
'`': true,
|
||||
}
|
||||
|
||||
func parseSimple(v string) (subject string, args []string, err error, ok bool) {
|
||||
brackets := 0
|
||||
for i := range len(v) {
|
||||
switch v[i] {
|
||||
case '\\', '$', '"', '\'', '`', '\t', '\r', '\n':
|
||||
return "", nil, nil, false
|
||||
case '(':
|
||||
brackets++
|
||||
case ')':
|
||||
if brackets == 0 {
|
||||
return "", nil, ErrUnterminatedBrackets, true
|
||||
}
|
||||
brackets--
|
||||
}
|
||||
}
|
||||
if brackets != 0 {
|
||||
return "", nil, ErrUnterminatedBrackets, true
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(v) && v[i] == ' ' {
|
||||
i++
|
||||
}
|
||||
if i >= len(v) {
|
||||
return "", nil, nil, true
|
||||
}
|
||||
|
||||
start := i
|
||||
for i < len(v) && v[i] != ' ' {
|
||||
i++
|
||||
}
|
||||
subject = v[start:i]
|
||||
|
||||
if i >= len(v) {
|
||||
return subject, nil, nil, true
|
||||
}
|
||||
|
||||
argCount := 0
|
||||
for j := i; j < len(v); {
|
||||
for j < len(v) && v[j] == ' ' {
|
||||
j++
|
||||
}
|
||||
if j >= len(v) {
|
||||
break
|
||||
}
|
||||
argCount++
|
||||
for j < len(v) && v[j] != ' ' {
|
||||
j++
|
||||
}
|
||||
}
|
||||
if argCount == 0 {
|
||||
return subject, nil, nil, true
|
||||
}
|
||||
args = make([]string, 0, argCount)
|
||||
for i < len(v) {
|
||||
for i < len(v) && v[i] == ' ' {
|
||||
i++
|
||||
}
|
||||
if i >= len(v) {
|
||||
break
|
||||
}
|
||||
start = i
|
||||
for i < len(v) && v[i] != ' ' {
|
||||
i++
|
||||
}
|
||||
args = append(args, v[start:i])
|
||||
}
|
||||
return subject, args, nil, true
|
||||
}
|
||||
|
||||
// parse expression to subject and args
|
||||
// with support for quotes, escaped chars, and env substitution, e.g.
|
||||
//
|
||||
@@ -32,14 +101,21 @@ var quoteChars = [256]bool{
|
||||
// error 403 Forbidden\ \"foo\"\ \"bar\".
|
||||
// error 403 "Message: ${CLOUDFLARE_API_KEY}"
|
||||
func parse(v string) (subject string, args []string, err error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(v)))
|
||||
if subject, args, err, ok := parseSimple(v); ok {
|
||||
return subject, args, err
|
||||
}
|
||||
|
||||
buf := getStringBuffer(len(v))
|
||||
args = make([]string, 0, 4)
|
||||
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
brackets := 0
|
||||
|
||||
var envVar bytes.Buffer
|
||||
var missingEnvVars []string
|
||||
var (
|
||||
envVar strings.Builder
|
||||
missingEnvVars []string
|
||||
)
|
||||
inEnvVar := false
|
||||
expectingBrace := false
|
||||
|
||||
@@ -71,7 +147,8 @@ func parse(v string) (subject string, args []string, err error) {
|
||||
if ch, ok := escapedChars[r]; ok {
|
||||
buf.WriteRune(ch)
|
||||
} else {
|
||||
fmt.Fprintf(buf, `\%c`, r)
|
||||
buf.WriteRune('\\')
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
escaped = false
|
||||
continue
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,7 @@ func TestParser(t *testing.T) {
|
||||
input string
|
||||
subject string
|
||||
args []string
|
||||
wantErr gperr.Error
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
@@ -90,6 +92,10 @@ func TestParser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
subject, args, err := parse(tt.input)
|
||||
if tt.wantErr != nil {
|
||||
expect.ErrorIs(t, tt.wantErr, err)
|
||||
return
|
||||
}
|
||||
// t.Log(subject, args, err)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, tt.subject)
|
||||
|
||||
29
internal/route/rules/phase.go
Normal file
29
internal/route/rules/phase.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package rules
|
||||
|
||||
import "strings"
|
||||
|
||||
type PhaseFlag uint8
|
||||
|
||||
const (
|
||||
PhaseNone PhaseFlag = 0
|
||||
PhasePre PhaseFlag = 1 << (iota - 1)
|
||||
PhasePost
|
||||
)
|
||||
|
||||
func (phase PhaseFlag) IsPostRule() bool {
|
||||
return phase&PhasePost != 0
|
||||
}
|
||||
|
||||
func (phase PhaseFlag) String() string {
|
||||
if phase == PhaseNone {
|
||||
return "none"
|
||||
}
|
||||
var flags []string
|
||||
if phase&PhasePre != 0 {
|
||||
flags = append(flags, "PhasePre")
|
||||
}
|
||||
if phase&PhasePost != 0 {
|
||||
flags = append(flags, "PhasePost")
|
||||
}
|
||||
return strings.Join(flags, ",")
|
||||
}
|
||||
@@ -4,9 +4,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
@@ -15,37 +22,36 @@ import (
|
||||
|
||||
type (
|
||||
/*
|
||||
Example:
|
||||
Rules is a list of rules.
|
||||
|
||||
proxy.app1.rules: |
|
||||
- name: default
|
||||
do: |
|
||||
rewrite / /index.html
|
||||
serve /var/www/goaccess
|
||||
- name: ws
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
Example:
|
||||
|
||||
proxy.app2.rules: |
|
||||
- name: default
|
||||
do: bypass
|
||||
- name: block POST and PUT
|
||||
on: method POST | method PUT
|
||||
do: error 403 Forbidden
|
||||
proxy.app1.rules: |
|
||||
- name: default
|
||||
do: |
|
||||
rewrite / /index.html
|
||||
serve /var/www/goaccess
|
||||
- name: ws
|
||||
on: |
|
||||
header Connection Upgrade
|
||||
header Upgrade websocket
|
||||
do: bypass
|
||||
|
||||
proxy.app2.rules: |
|
||||
- name: default
|
||||
do: bypass
|
||||
- name: block POST and PUT
|
||||
on: method POST | method PUT
|
||||
do: error 403 Forbidden
|
||||
*/
|
||||
//nolint:recvcheck
|
||||
Rules []Rule
|
||||
/*
|
||||
Rule is a rule for a reverse proxy.
|
||||
It do `Do` when `On` matches.
|
||||
|
||||
A rule can have multiple lines of on.
|
||||
|
||||
All lines of on must match,
|
||||
but each line can have multiple checks that
|
||||
one match means this line is matched.
|
||||
*/
|
||||
// Rule represents a reverse proxy rule.
|
||||
// The `Do` field is executed when `On` matches.
|
||||
//
|
||||
// - A rule may have multiple lines in the `On` section.
|
||||
// - All `On` lines must match for the rule to trigger.
|
||||
// - Each line can have several checks—one match per line is enough for that line.
|
||||
Rule struct {
|
||||
Name string `json:"name"`
|
||||
On RuleOn `json:"on" swaggertype:"string"`
|
||||
@@ -53,210 +59,395 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (rule *Rule) IsResponseRule() bool {
|
||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
||||
func isDefaultRule(rule Rule) bool {
|
||||
return rule.Name == "default" || rule.On.raw == OnDefault
|
||||
}
|
||||
|
||||
func (rules Rules) Validate() error {
|
||||
func (rules Rules) Validate() gperr.Error {
|
||||
var defaultRulesFound []int
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
if isDefaultRule(rule) {
|
||||
defaultRulesFound = append(defaultRulesFound, i)
|
||||
}
|
||||
if rules[i].Name == "" {
|
||||
// set name to index if name is empty
|
||||
rules[i].Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
}
|
||||
if len(defaultRulesFound) > 1 {
|
||||
return ErrMultipleDefaultRules.Withf("found %d", len(defaultRulesFound))
|
||||
}
|
||||
for i := range rules {
|
||||
r1 := rules[i]
|
||||
if isDefaultRule(r1) || r1.On.phase.IsPostRule() || !r1.doesTerminateInPre() {
|
||||
continue
|
||||
}
|
||||
sig1, ok := matcherSignature(r1.On.raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(rules); j++ {
|
||||
r2 := rules[j]
|
||||
if isDefaultRule(r2) || r2.On.phase.IsPostRule() {
|
||||
continue
|
||||
}
|
||||
sig2, ok := matcherSignature(r2.On.raw)
|
||||
if !ok || sig1 != sig2 {
|
||||
continue
|
||||
}
|
||||
return ErrDeadRule.Withf("rule[%d] shadows rule[%d] with same matcher", i, j)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rule Rule) doesTerminateInPre() bool {
|
||||
return commandsTerminateInPre(rule.Do.pre)
|
||||
}
|
||||
|
||||
func commandsTerminateInPre(cmds []CommandHandler) bool {
|
||||
return slices.ContainsFunc(cmds, commandTerminatesInPre)
|
||||
}
|
||||
|
||||
func commandTerminatesInPre(cmd CommandHandler) bool {
|
||||
switch c := cmd.(type) {
|
||||
case Handler:
|
||||
return c.Terminates()
|
||||
case *Handler:
|
||||
return c.Terminates()
|
||||
case IfBlockCommand:
|
||||
return ruleOnAlwaysTrue(c.On) && commandsTerminateInPre(c.Do)
|
||||
case *IfBlockCommand:
|
||||
return c != nil && ruleOnAlwaysTrue(c.On) && commandsTerminateInPre(c.Do)
|
||||
case IfElseBlockCommand:
|
||||
return ifElseBlockTerminatesInPre(c)
|
||||
case *IfElseBlockCommand:
|
||||
return c != nil && ifElseBlockTerminatesInPre(*c)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ifElseBlockTerminatesInPre(cmd IfElseBlockCommand) bool {
|
||||
hasFallback := len(cmd.Else) > 0
|
||||
for _, br := range cmd.Ifs {
|
||||
if !commandsTerminateInPre(br.Do) {
|
||||
return false
|
||||
}
|
||||
if ruleOnAlwaysTrue(br.On) {
|
||||
hasFallback = true
|
||||
}
|
||||
}
|
||||
if !hasFallback {
|
||||
return false
|
||||
}
|
||||
if len(cmd.Else) > 0 && !commandsTerminateInPre(cmd.Else) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ruleOnAlwaysTrue(on RuleOn) bool {
|
||||
return strings.TrimSpace(on.raw) == OnDefault || on.checker == nil
|
||||
}
|
||||
|
||||
func matcherSignature(raw string) (string, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "(any)", true // unconditional rule
|
||||
}
|
||||
|
||||
andParts := splitAnd(raw)
|
||||
if len(andParts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
canonAnd := make([]string, 0, len(andParts))
|
||||
for _, andPart := range andParts {
|
||||
orParts := splitPipe(andPart)
|
||||
if len(orParts) == 0 {
|
||||
continue
|
||||
}
|
||||
canonOr := make([]string, 0, len(orParts))
|
||||
for _, atom := range orParts {
|
||||
subject, args, err := parse(strings.TrimSpace(atom))
|
||||
if err != nil || subject == "" {
|
||||
return "", false
|
||||
}
|
||||
canonOr = append(canonOr, subject+" "+strings.Join(args, "\x00"))
|
||||
}
|
||||
slices.Sort(canonOr)
|
||||
canonOr = slices.Compact(canonOr)
|
||||
canonAnd = append(canonAnd, "("+strings.Join(canonOr, "|")+")")
|
||||
}
|
||||
|
||||
slices.Sort(canonAnd)
|
||||
canonAnd = slices.Compact(canonAnd)
|
||||
if len(canonAnd) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return strings.Join(canonAnd, "&"), true
|
||||
}
|
||||
|
||||
// Parse parses a rule configuration string.
|
||||
// It first tries the block syntax (if the string contains a top-level '{'),
|
||||
// then falls back to YAML syntax.
|
||||
func (rules *Rules) Parse(config string) error {
|
||||
config = strings.TrimSpace(config)
|
||||
if config == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
blockTried := false
|
||||
var blockErr gperr.Error
|
||||
|
||||
// Prefer block syntax if it looks like block syntax.
|
||||
if hasTopLevelLBrace(config) {
|
||||
blockTried = true
|
||||
blockRules, err := parseBlockRules(config)
|
||||
if err == nil {
|
||||
*rules = blockRules
|
||||
return nil
|
||||
}
|
||||
blockErr = err
|
||||
}
|
||||
|
||||
// YAML fallback
|
||||
var anySlice []any
|
||||
yamlErr := yaml.Unmarshal([]byte(config), &anySlice)
|
||||
if yamlErr == nil {
|
||||
return serialization.ConvertSlice(reflect.ValueOf(anySlice), reflect.ValueOf(rules), false)
|
||||
}
|
||||
|
||||
// If YAML fails and we haven't tried block syntax yet, try it now.
|
||||
if !blockTried {
|
||||
blockRules, err := parseBlockRules(config)
|
||||
if err == nil {
|
||||
*rules = blockRules
|
||||
return nil
|
||||
}
|
||||
blockErr = err
|
||||
}
|
||||
return blockErr
|
||||
}
|
||||
|
||||
// hasTopLevelLBrace reports whether s contains a '{' outside quotes/backticks and comments.
|
||||
// Used to decide whether to prioritize the block syntax.
|
||||
func hasTopLevelLBrace(s string) bool {
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '*' && i+1 < len(s) && s[i+1] == '/' {
|
||||
inBlock = false
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
if quote != '`' && c == '\\' && i+1 < len(s) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if rune(c) == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch c {
|
||||
case '\'', '"', '`':
|
||||
quote = rune(c)
|
||||
continue
|
||||
case '{':
|
||||
return true
|
||||
case '#':
|
||||
inLine = true
|
||||
continue
|
||||
case '/':
|
||||
if i+1 < len(s) && s[i+1] == '/' {
|
||||
inLine = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if i+1 < len(s) && s[i+1] == '*' {
|
||||
inBlock = true
|
||||
i++
|
||||
continue
|
||||
}
|
||||
default:
|
||||
if unicode.IsSpace(rune(c)) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
||||
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
||||
if len(rules) == 0 {
|
||||
return up
|
||||
}
|
||||
|
||||
defaultRule := Rule{
|
||||
Name: "default",
|
||||
Do: Command{
|
||||
raw: "pass",
|
||||
exec: BypassCommand{},
|
||||
},
|
||||
}
|
||||
var defaultRule *Rule
|
||||
|
||||
var nonDefaultRules Rules
|
||||
hasDefaultRule := false
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" || rule.On.raw == OnDefault {
|
||||
defaultRule = rule
|
||||
hasDefaultRule = true
|
||||
for _, rule := range rules {
|
||||
if isDefaultRule(rule) {
|
||||
r := rule
|
||||
defaultRule = &r
|
||||
} else {
|
||||
// set name to index if name is empty
|
||||
if rule.Name == "" {
|
||||
rule.Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
nonDefaultRules = append(nonDefaultRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonDefaultRules) == 0 {
|
||||
if defaultRule.Do.isBypass() {
|
||||
if defaultRule == nil || defaultRule.Do.raw == CommandUpstream {
|
||||
return up
|
||||
}
|
||||
if defaultRule.IsResponseRule() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
}
|
||||
}()
|
||||
w = rm
|
||||
up(w, r)
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
}
|
||||
}()
|
||||
w = rm
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err == nil {
|
||||
up(w, r)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preRules := make(Rules, 0, len(nonDefaultRules))
|
||||
postRules := make(Rules, 0, len(nonDefaultRules))
|
||||
for _, rule := range nonDefaultRules {
|
||||
if rule.IsResponseRule() {
|
||||
postRules = append(postRules, rule)
|
||||
} else {
|
||||
preRules = append(preRules, rule)
|
||||
}
|
||||
execPreCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||
return cmd.pre.ServeHTTP(w, r, up)
|
||||
}
|
||||
|
||||
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
|
||||
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
||||
execPostCommand := func(cmd Command, w *httputils.ResponseModifier, r *http.Request) error {
|
||||
return cmd.post.ServeHTTP(w, r, up)
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
logError(err, r)
|
||||
logFlushError(err, r)
|
||||
}
|
||||
}()
|
||||
|
||||
w = rm
|
||||
var hasError bool
|
||||
|
||||
shouldCallUpstream := true
|
||||
preMatched := false
|
||||
executedPre := make([]bool, len(nonDefaultRules))
|
||||
terminatedInPre := make([]bool, len(nonDefaultRules))
|
||||
matchedNonDefaultPre := false
|
||||
preTerminated := false
|
||||
for i, rule := range nonDefaultRules {
|
||||
if rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||
continue
|
||||
}
|
||||
matchedNonDefaultPre = true
|
||||
if preTerminated {
|
||||
// Preserve post-only commands (e.g. logging) even after
|
||||
// pre-phase termination.
|
||||
if len(rule.Do.pre) == 0 {
|
||||
executedPre[i] = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
executedPre[i] = true
|
||||
if err := execPreCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
terminatedInPre[i] = true
|
||||
preTerminated = true
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
hasError = true
|
||||
}
|
||||
}
|
||||
|
||||
// Default rule is a fallback: run only when no non-default pre rule matched.
|
||||
defaultExecutedPre := false
|
||||
defaultTerminatedInPre := false
|
||||
if defaultRule != nil && !matchedNonDefaultPre && !defaultRule.On.phase.IsPostRule() && defaultRule.On.Check(rm, r) {
|
||||
defaultExecutedPre = true
|
||||
if err := execPreCommand(defaultRule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
defaultTerminatedInPre = true
|
||||
} else {
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing pre rule (%s): %w", defaultRule.Do.raw, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
hasError = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
for _, rule := range preRules {
|
||||
if rule.Check(w, r) {
|
||||
preMatched = true
|
||||
if rule.Do.isBypass() {
|
||||
break // post rules should still execute
|
||||
}
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &rule, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
break
|
||||
}
|
||||
if !rm.HasStatus() {
|
||||
if hasError {
|
||||
http.Error(rm, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
} else { // call upstream if no WriteHeader or Write was called and no error occurred
|
||||
up(rm, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Run post commands for rules that actually executed in pre phase,
|
||||
// unless that same rule terminated in pre phase.
|
||||
for i, rule := range nonDefaultRules {
|
||||
if !executedPre[i] || terminatedInPre[i] {
|
||||
continue
|
||||
}
|
||||
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if defaultExecutedPre && !defaultTerminatedInPre {
|
||||
if err := execPostCommand(defaultRule.Do, rm, r); err != nil {
|
||||
if !errors.Is(err, errTerminateRule) && isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing post rule (%s): %w", defaultRule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
return
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
// Run true post-matcher rules after response is available.
|
||||
for _, rule := range nonDefaultRules {
|
||||
if !rule.On.phase.IsPostRule() || !rule.On.Check(rm, r) {
|
||||
continue
|
||||
}
|
||||
// Post-rule matchers are only evaluated after upstream, so commands parsed
|
||||
// as "pre" for requirement purposes still need to run in this phase.
|
||||
if err := rule.Do.pre.ServeHTTP(rm, r, up); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing pre rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
if err := execPostCommand(rule.Do, rm, r); err != nil {
|
||||
if errors.Is(err, errTerminateRule) {
|
||||
continue
|
||||
}
|
||||
if isUnexpectedError(err) {
|
||||
// will logged by logFlushError after FlushRelease
|
||||
rm.AppendError("executing post rule (%s): %w", rule.Do.raw, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
up(w, r)
|
||||
}
|
||||
|
||||
// if no post rules, we are done here
|
||||
if len(postRules) == 0 && !isDefaultRulePost {
|
||||
return
|
||||
}
|
||||
|
||||
for _, rule := range postRules {
|
||||
if rule.Check(w, r) {
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &rule, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isDefaultRulePost {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
appendRuleError(rm, &defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func appendRuleError(rm *httputils.ResponseModifier, rule *Rule, err error) {
|
||||
// rm.AppendError("rule: %s, error: %w", rule.Name, err)
|
||||
}
|
||||
|
||||
func isTerminatingHandler(handler CommandHandler) bool {
|
||||
switch h := handler.(type) {
|
||||
case TerminatingCommand:
|
||||
return true
|
||||
case Commands:
|
||||
if len(h) == 0 {
|
||||
return false
|
||||
}
|
||||
return isTerminatingHandler(h[len(h)-1])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,41 +455,41 @@ func (rule *Rule) String() string {
|
||||
return rule.Name
|
||||
}
|
||||
|
||||
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
func (rule *Rule) Check(w *httputils.ResponseModifier, r *http.Request) bool {
|
||||
if rule.On.checker == nil {
|
||||
return true
|
||||
}
|
||||
v := rule.On.checker.Check(w, r)
|
||||
return v
|
||||
}
|
||||
|
||||
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return rule.Do.exec.Handle(w, r)
|
||||
return rule.On.Check(w, r)
|
||||
}
|
||||
|
||||
//go:linkname errStreamClosed golang.org/x/net/http2.errStreamClosed
|
||||
var errStreamClosed error
|
||||
|
||||
func logError(err error, r *http.Request) {
|
||||
if errors.Is(err, errStreamClosed) {
|
||||
return
|
||||
//go:linkname errClientDisconnected golang.org/x/net/http2.errClientDisconnected
|
||||
var errClientDisconnected error
|
||||
|
||||
func isUnexpectedError(err error) bool {
|
||||
if errors.Is(err, errStreamClosed) || errors.Is(err, errClientDisconnected) {
|
||||
return false
|
||||
}
|
||||
var h2Err http2.StreamError
|
||||
if errors.As(err, &h2Err) {
|
||||
if h2Err, ok := errors.AsType[http2.StreamError](err); ok {
|
||||
// ignore these errors
|
||||
if h2Err.Code == http2.ErrCodeStreamClosed {
|
||||
return
|
||||
return false
|
||||
}
|
||||
}
|
||||
var h3Err *http3.Error
|
||||
if errors.As(err, &h3Err) {
|
||||
if h3Err, ok := errors.AsType[*http3.Error](err); ok {
|
||||
// ignore these errors
|
||||
switch h3Err.ErrorCode {
|
||||
case
|
||||
http3.ErrCodeNoError,
|
||||
http3.ErrCodeRequestCanceled:
|
||||
return
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func logFlushError(err error, r *http.Request) {
|
||||
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
|
||||
}
|
||||
|
||||
@@ -19,28 +19,133 @@ func TestRulesValidate(t *testing.T) {
|
||||
{
|
||||
name: "no default rule",
|
||||
rules: `
|
||||
- name: rule1
|
||||
on: header Host example.com
|
||||
do: pass
|
||||
`,
|
||||
header Host example.com {
|
||||
pass
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "multiple default rules",
|
||||
rules: `
|
||||
- name: default
|
||||
do: pass
|
||||
- name: rule1
|
||||
on: default
|
||||
do: pass
|
||||
`,
|
||||
default {
|
||||
pass
|
||||
}
|
||||
|
||||
default {
|
||||
pass
|
||||
}`,
|
||||
want: ErrMultipleDefaultRules,
|
||||
},
|
||||
{
|
||||
name: "multiple responses on same condition",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
error 404 "not found"
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
{
|
||||
name: "same condition different formatting error then proxy",
|
||||
rules: `
|
||||
header Host example.com & method GET {
|
||||
error 404 "not found"
|
||||
}
|
||||
|
||||
method GET
|
||||
header Host example.com {
|
||||
proxy http://127.0.0.1:8080
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
{
|
||||
name: "same condition with non terminating first rule",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
set resp_header X-Test first
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "same condition with terminating handler inside if block",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
default {
|
||||
error 404 "not found"
|
||||
}
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
{
|
||||
name: "same condition with terminating handler across if else block",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
method GET {
|
||||
error 404 "not found"
|
||||
} else {
|
||||
redirect https://example.com
|
||||
}
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
{
|
||||
name: "same condition with non terminating if branch in if else block",
|
||||
rules: `
|
||||
header Host example.com {
|
||||
method GET {
|
||||
set resp_header X-Test first
|
||||
} else {
|
||||
error 404 "not found"
|
||||
}
|
||||
}
|
||||
|
||||
header Host example.com {
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "unconditional terminating rule shadows later unconditional rule",
|
||||
rules: `
|
||||
{
|
||||
error 404 "not found"
|
||||
}
|
||||
|
||||
{
|
||||
error 403 "forbidden"
|
||||
}
|
||||
`,
|
||||
want: ErrDeadRule,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var rules Rules
|
||||
convertible, err := serialization.ConvertString(strings.TrimSpace(tt.rules), reflect.ValueOf(&rules))
|
||||
require.True(t, convertible)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = rules.Validate()
|
||||
|
||||
if tt.want == nil {
|
||||
assert.NoError(t, err)
|
||||
@@ -50,3 +155,50 @@ func TestRulesValidate(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasTopLevelLBrace(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "escaped quote inside double quoted string",
|
||||
in: `"test\"more{"`,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "escaped quote inside single quoted string",
|
||||
in: "'test\\'more{'",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "top-level brace outside quoted string",
|
||||
in: `"test\"more" {`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "backtick keeps existing behavior",
|
||||
in: "`test\\`more{`",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, hasTopLevelLBrace(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRulesParse_BlockTriedThenYAMLFails_ReturnsBlockError(t *testing.T) {
|
||||
input := `default {`
|
||||
|
||||
_, blockErr := parseBlockRules(input)
|
||||
require.Error(t, blockErr)
|
||||
|
||||
var rules Rules
|
||||
err := rules.Parse(input)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, blockErr.Error(), err.Error())
|
||||
}
|
||||
|
||||
273
internal/route/rules/scanner.go
Normal file
273
internal/route/rules/scanner.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
// Tokenizer provides utilities for parsing rule syntax with proper handling
|
||||
// of quotes, comments, and env vars.
|
||||
//
|
||||
// This is intentionally reusable by both the top-level rule block parser and
|
||||
// the nested do-block parser.
|
||||
type Tokenizer struct {
|
||||
src string
|
||||
length int
|
||||
}
|
||||
|
||||
// newTokenizer creates a tokenizer for the given source.
|
||||
func newTokenizer(src string) Tokenizer {
|
||||
return Tokenizer{src: src, length: len(src)}
|
||||
}
|
||||
|
||||
// skipComments skips whitespace, line comments, and block comments.
|
||||
// It returns the new position and an error if a block comment is unterminated.
|
||||
func (t *Tokenizer) skipComments(pos int, atLineStart bool, prevIsSpace bool) (int, gperr.Error) {
|
||||
for pos < t.length {
|
||||
c := t.src[pos]
|
||||
|
||||
// Skip whitespace
|
||||
if unicode.IsSpace(rune(c)) {
|
||||
pos++
|
||||
atLineStart = false
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for line comment: // or #
|
||||
if c == '/' {
|
||||
if pos+1 < t.length && t.src[pos+1] == '/' {
|
||||
// Skip to end of line
|
||||
for pos < t.length && t.src[pos] != '\n' {
|
||||
pos++
|
||||
}
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c == '#' && (atLineStart || prevIsSpace) {
|
||||
// Skip to end of line
|
||||
for pos < t.length && t.src[pos] != '\n' {
|
||||
pos++
|
||||
}
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for block comment: /*
|
||||
if c == '/' && pos+1 < t.length && t.src[pos+1] == '*' {
|
||||
pos += 2
|
||||
closed := false
|
||||
for pos+1 < t.length {
|
||||
if t.src[pos] == '*' && t.src[pos+1] == '/' {
|
||||
pos += 2
|
||||
closed = true
|
||||
break
|
||||
}
|
||||
pos++
|
||||
}
|
||||
if !closed {
|
||||
return 0, ErrInvalidBlockSyntax.Withf("unterminated block comment")
|
||||
}
|
||||
atLineStart = false
|
||||
prevIsSpace = true
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return pos, nil
|
||||
}
|
||||
|
||||
// scanToBrace scans from pos until it finds '{' outside quotes, or returns an error.
|
||||
func (t *Tokenizer) scanToBrace(pos int) (int, gperr.Error) {
|
||||
quote := rune(0)
|
||||
for pos < t.length {
|
||||
c := rune(t.src[pos])
|
||||
if quote != 0 {
|
||||
if c == quote {
|
||||
quote = 0
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if c == '"' || c == '\'' || c == '`' {
|
||||
quote = c
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if c == '{' {
|
||||
return pos, nil
|
||||
}
|
||||
if c == '}' {
|
||||
return 0, ErrInvalidBlockSyntax.Withf("unmatched '}' in block header")
|
||||
}
|
||||
pos++
|
||||
}
|
||||
return 0, ErrInvalidBlockSyntax.Withf("expected '{' after block header")
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the matching '}' for a '{' starting at startPos.
|
||||
// It respects quotes/backticks and ${...} env vars.
|
||||
func (t *Tokenizer) findMatchingBrace(startPos int) (int, gperr.Error) {
|
||||
pos := startPos
|
||||
braceDepth := 1
|
||||
quote := rune(0)
|
||||
inLine := false
|
||||
inBlock := false
|
||||
atLineStart := true
|
||||
prevIsSpace := true
|
||||
|
||||
for pos < t.length {
|
||||
c := rune(t.src[pos])
|
||||
|
||||
if inLine {
|
||||
if c == '\n' {
|
||||
inLine = false
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if inBlock {
|
||||
if c == '*' && pos+1 < t.length && t.src[pos+1] == '/' {
|
||||
pos += 2
|
||||
inBlock = false
|
||||
continue
|
||||
}
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
if c == quote {
|
||||
quote = 0
|
||||
}
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(c)
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '"' || c == '\'' || c == '`' {
|
||||
quote = c
|
||||
atLineStart = false
|
||||
prevIsSpace = false
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
|
||||
// Comments (only outside quotes) at token boundary
|
||||
if c == '#' && (atLineStart || prevIsSpace) {
|
||||
inLine = true
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if c == '/' && pos+1 < t.length {
|
||||
n := rune(t.src[pos+1])
|
||||
if (atLineStart || prevIsSpace) && n == '/' {
|
||||
inLine = true
|
||||
pos += 2
|
||||
continue
|
||||
}
|
||||
if (atLineStart || prevIsSpace) && n == '*' {
|
||||
inBlock = true
|
||||
pos += 2
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if c == '$' && pos+1 < t.length && t.src[pos+1] == '{' {
|
||||
// Skip env var ${...}
|
||||
pos += 2
|
||||
envBraceDepth := 1
|
||||
envQuote := rune(0)
|
||||
for pos < t.length {
|
||||
ec := rune(t.src[pos])
|
||||
if envQuote != 0 {
|
||||
if ec == envQuote {
|
||||
envQuote = 0
|
||||
}
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if ec == '"' || ec == '\'' || ec == '`' {
|
||||
envQuote = ec
|
||||
pos++
|
||||
continue
|
||||
}
|
||||
if ec == '{' {
|
||||
envBraceDepth++
|
||||
} else if ec == '}' {
|
||||
envBraceDepth--
|
||||
if envBraceDepth == 0 {
|
||||
pos++ // Move past the closing '}'
|
||||
break
|
||||
}
|
||||
}
|
||||
pos++
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch c {
|
||||
case '{':
|
||||
braceDepth++
|
||||
case '}':
|
||||
braceDepth--
|
||||
if braceDepth == 0 {
|
||||
return pos, nil
|
||||
}
|
||||
}
|
||||
|
||||
if c == '\n' {
|
||||
atLineStart = true
|
||||
prevIsSpace = true
|
||||
} else {
|
||||
atLineStart = false
|
||||
prevIsSpace = unicode.IsSpace(c)
|
||||
}
|
||||
pos++
|
||||
}
|
||||
|
||||
return 0, ErrInvalidBlockSyntax.Withf("unmatched '{' at position %d", startPos)
|
||||
}
|
||||
|
||||
// parseHeaderToBrace parses an expression/header starting at start and returns:
|
||||
// - header: trimmed src[start:bracePos]
|
||||
// - bracePos: position of '{' (outside quotes/backticks)
|
||||
func parseHeaderToBrace(src string, start int) (header string, bracePos int, err gperr.Error) {
|
||||
t := newTokenizer(src)
|
||||
bracePos, err = t.scanToBrace(start)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return strings.TrimSpace(src[start:bracePos]), bracePos, nil
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the matching '}' for a '{' at position startPos.
|
||||
// It respects quotes/backticks and ${...} env vars in do_body.
|
||||
func findMatchingBrace(src string, pos *int, startPos int) (int, gperr.Error) {
|
||||
t := newTokenizer(src)
|
||||
endPos, err := t.findMatchingBrace(startPos)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
*pos = endPos + 1
|
||||
return endPos, nil
|
||||
}
|
||||
39
internal/route/rules/scanner_test.go
Normal file
39
internal/route/rules/scanner_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTokenizer_skipComments_UnterminatedBlockComment(t *testing.T) {
|
||||
tok := newTokenizer("/* unterminated")
|
||||
pos, err := tok.skipComments(0, true, true)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 0, pos)
|
||||
}
|
||||
|
||||
func TestTokenizer_skipComments_SkipsLineAndBlockComments(t *testing.T) {
|
||||
src := " // line\n /* block */\n # hash\n default"
|
||||
tok := newTokenizer(src)
|
||||
pos, err := tok.skipComments(0, true, true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.Index(src, "default"), pos)
|
||||
}
|
||||
|
||||
func TestTokenizer_scanToBrace_IgnoresQuotedBraces(t *testing.T) {
|
||||
src := "cond \"{\" {" // the brace inside quotes must be ignored
|
||||
tok := newTokenizer(src)
|
||||
bracePos, err := tok.scanToBrace(0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.LastIndex(src, "{"), bracePos)
|
||||
}
|
||||
|
||||
func TestTokenizer_findMatchingBrace_IgnoresQuotedClosingBrace(t *testing.T) {
|
||||
src := `{ "}" }`
|
||||
tok := newTokenizer(src)
|
||||
endPos, err := tok.findMatchingBrace(1) // body starts after the first '{'
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, strings.LastIndex(src, "}"), endPos)
|
||||
}
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
)
|
||||
|
||||
type templateString struct {
|
||||
string
|
||||
|
||||
isTemplate bool
|
||||
}
|
||||
|
||||
@@ -23,32 +23,28 @@ func (tmpl *keyValueTemplate) Unpack() (string, templateString) {
|
||||
return tmpl.key, tmpl.tmpl
|
||||
}
|
||||
|
||||
func (tmpl *templateString) ExpandVars(w http.ResponseWriter, req *http.Request, dstW io.Writer) error {
|
||||
func (tmpl *templateString) ExpandVars(w *httputils.ResponseModifier, req *http.Request, dst io.Writer) (phase PhaseFlag, err error) {
|
||||
if !tmpl.isTemplate {
|
||||
_, err := dstW.Write(strtobNoCopy(tmpl.string))
|
||||
return err
|
||||
_, err := asBytesBufferLike(dst).WriteString(tmpl.string)
|
||||
return PhaseNone, err
|
||||
}
|
||||
|
||||
return ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, dstW)
|
||||
return ExpandVars(w, req, tmpl.string, dst)
|
||||
}
|
||||
|
||||
func (tmpl *templateString) ExpandVarsToString(w http.ResponseWriter, req *http.Request) (string, error) {
|
||||
func (tmpl *templateString) ExpandVarsToString(w *httputils.ResponseModifier, r *http.Request) (string, PhaseFlag, error) {
|
||||
if !tmpl.isTemplate {
|
||||
return tmpl.string, nil
|
||||
return tmpl.string, PhaseNone, nil
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
err := ExpandVars(httputils.GetInitResponseModifier(w), req, tmpl.string, &buf)
|
||||
phase, err := tmpl.ExpandVars(w, r, &buf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", PhaseNone, err
|
||||
}
|
||||
return buf.String(), nil
|
||||
return buf.String(), phase, nil
|
||||
}
|
||||
|
||||
func (tmpl *templateString) Len() int {
|
||||
return len(tmpl.string)
|
||||
}
|
||||
|
||||
func strtobNoCopy(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
ValidateFunc func(args []string) (any, error)
|
||||
ValidateFunc func(args []string) (phase PhaseFlag, parsedArgs any, err error)
|
||||
Tuple[T1, T2 any] struct {
|
||||
First T1
|
||||
Second T2
|
||||
@@ -37,6 +38,8 @@ type (
|
||||
MapValueMatcher = Tuple[string, Matcher]
|
||||
)
|
||||
|
||||
var cidrCache = xsync.NewMap[string, *net.IPNet]()
|
||||
|
||||
func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
||||
return t.First, t.Second
|
||||
}
|
||||
@@ -62,7 +65,7 @@ func (t *Tuple4[T1, T2, T3, T4]) String() string {
|
||||
}
|
||||
|
||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||
func validateSingleMatcher(args []string) (any, error) {
|
||||
func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -70,7 +73,7 @@ func validateSingleMatcher(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||
func toKVOptionalVMatcher(args []string) (any, error) {
|
||||
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
switch len(args) {
|
||||
case 1:
|
||||
return &MapValueMatcher{args[0], nil}, nil
|
||||
@@ -85,20 +88,8 @@ func toKVOptionalVMatcher(args []string) (any, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func toKeyValueTemplate(args []string) (any, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
isTemplate, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &keyValueTemplate{args[0], isTemplate}, nil
|
||||
}
|
||||
|
||||
// validateURL returns types.URL with the URL validated.
|
||||
func validateURL(args []string) (any, error) {
|
||||
func validateURL(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -116,22 +107,27 @@ func validateURL(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateCIDR returns types.CIDR with the CIDR validated.
|
||||
func validateCIDR(args []string) (any, error) {
|
||||
func validateCIDR(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
if !strings.Contains(args[0], "/") {
|
||||
args[0] += "/32"
|
||||
cidr := args[0]
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/32"
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(args[0])
|
||||
if cached, ok := cidrCache.Load(cidr); ok {
|
||||
return cached, nil
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
cidrCache.Store(cidr, ipnet)
|
||||
return ipnet, nil
|
||||
}
|
||||
|
||||
// validateURLPath returns string with the path validated.
|
||||
func validateURLPath(args []string) (any, error) {
|
||||
func validateURLPath(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -148,7 +144,7 @@ func validateURLPath(args []string) (any, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func validateURLPathMatcher(args []string) (any, error) {
|
||||
func validateURLPathMatcher(args []string) (any, gperr.Error) {
|
||||
path, err := validateURLPath(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -157,7 +153,7 @@ func validateURLPathMatcher(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateFSPath returns string with the path validated.
|
||||
func validateFSPath(args []string) (any, error) {
|
||||
func validateFSPath(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -169,7 +165,7 @@ func validateFSPath(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateMethod returns string with the method validated.
|
||||
func validateMethod(args []string) (any, error) {
|
||||
func validateMethod(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -200,7 +196,7 @@ func validateStatusCode(status string) (int, error) {
|
||||
// - 3xx
|
||||
// - 4xx
|
||||
// - 5xx
|
||||
func validateStatusRange(args []string) (any, error) {
|
||||
func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
@@ -232,7 +228,7 @@ func validateStatusRange(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||
func validateUserBCryptPassword(args []string) (any, error) {
|
||||
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
@@ -240,64 +236,93 @@ func validateUserBCryptPassword(args []string) (any, error) {
|
||||
}
|
||||
|
||||
// validateModField returns CommandHandler with the field validated.
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, error) {
|
||||
func validateModField(mod FieldModifier, args []string) (phase PhaseFlag, handler HandlerFunc, err error) {
|
||||
if len(args) == 0 {
|
||||
return nil, ErrExpectTwoOrThreeArgs
|
||||
return phase, nil, ErrExpectTwoOrThreeArgs
|
||||
}
|
||||
setField, ok := modFields[args[0]]
|
||||
if !ok {
|
||||
return nil, ErrUnknownModField.Subject(args[0])
|
||||
return phase, nil, ErrUnknownModField.Subject(args[0])
|
||||
}
|
||||
if mod == ModFieldRemove {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
// setField expect validateStrTuple
|
||||
args = append(args, "")
|
||||
}
|
||||
validArgs, err := setField.validate(args[1:])
|
||||
phase, validArgs, err := setField.validate(args[1:])
|
||||
if err != nil {
|
||||
return nil, gperr.Wrap(err).With(setField.help.Error())
|
||||
return phase, nil, gperr.Wrap(err).With(setField.help.Error())
|
||||
}
|
||||
|
||||
modder := setField.builder(validArgs)
|
||||
switch mod {
|
||||
case ModFieldAdd:
|
||||
add := modder.add
|
||||
if add == nil {
|
||||
return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod)
|
||||
return phase, nil, ErrInvalidArguments.Withf("add is not supported for field %s", args[0])
|
||||
}
|
||||
return add, nil
|
||||
return phase, add, nil
|
||||
case ModFieldRemove:
|
||||
remove := modder.remove
|
||||
if remove == nil {
|
||||
return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod)
|
||||
return phase, nil, ErrInvalidArguments.Withf("remove is not supported for field %s", args[0])
|
||||
}
|
||||
return remove, nil
|
||||
return phase, remove, nil
|
||||
}
|
||||
set := modder.set
|
||||
if set == nil {
|
||||
return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod)
|
||||
return phase, nil, ErrInvalidArguments.Withf("set is not supported for field %s", args[0])
|
||||
}
|
||||
return set, nil
|
||||
return phase, set, nil
|
||||
}
|
||||
|
||||
func validateTemplate(tmplStr string, newline bool) (templateString, error) {
|
||||
func validateTemplate(tmplStr string, newline bool) (phase PhaseFlag, tmpl templateString, err error) {
|
||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||
tmplStr += "\n"
|
||||
}
|
||||
|
||||
if !NeedExpandVars(tmplStr) {
|
||||
return templateString{tmplStr, false}, nil
|
||||
return phase, templateString{tmplStr, false}, nil
|
||||
}
|
||||
|
||||
err := ValidateVars(tmplStr)
|
||||
phase, err = ValidateVars(tmplStr)
|
||||
if err != nil {
|
||||
return templateString{}, err
|
||||
return phase, templateString{}, gperr.Wrap(err)
|
||||
}
|
||||
return templateString{tmplStr, true}, nil
|
||||
return phase, templateString{tmplStr, true}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, error) {
|
||||
func validatePreRequestKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 2 {
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
phase = PhasePre
|
||||
tmplReq, tmpl, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return phase, nil, err
|
||||
}
|
||||
phase |= tmplReq
|
||||
return phase, &keyValueTemplate{args[0], tmpl}, nil
|
||||
}
|
||||
|
||||
func validatePostResponseKVTemplate(args []string) (phase PhaseFlag, parsedArgs any, err error) {
|
||||
if len(args) != 2 {
|
||||
return phase, nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
phase = PhasePost
|
||||
tmplReq, tmpl, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return phase, nil, err
|
||||
}
|
||||
phase |= tmplReq
|
||||
return phase, &keyValueTemplate{args[0], tmpl}, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||
l, err := zerolog.ParseLevel(level)
|
||||
if err != nil {
|
||||
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
||||
|
||||
@@ -23,7 +23,7 @@ func BenchmarkExpandVars(b *testing.B) {
|
||||
testRequest.PostForm = url.Values{"param3": {"value3"}, "param4": {"value4"}}
|
||||
|
||||
for b.Loop() {
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path $req_query $req_url $req_uri $req_host $req_port $req_addr $req_content_type $req_content_length $remote_host $remote_port $remote_addr $status_code $resp_content_type $resp_content_length $header(User-Agent) $header(X-Custom, 0) $header(X-Custom, 1) $arg(param1) $arg(param2) $arg(param3) $arg(param4) $form(param1) $form(param2) $form(param3) $form(param4) $postform(param1) $postform(param2) $postform(param3) $postform(param4)", io.Discard)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
|
||||
// TODO: remove middleware/vars.go and use this instead
|
||||
@@ -45,41 +46,84 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
type bytesBufferLike interface {
|
||||
io.Writer
|
||||
WriteByte(c byte) error
|
||||
WriteString(s string) (int, error)
|
||||
}
|
||||
|
||||
type bytesBufferAdapter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (b bytesBufferAdapter) WriteByte(c byte) error {
|
||||
buf := [1]byte{c}
|
||||
_, err := b.Write(buf[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func (b bytesBufferAdapter) WriteString(s string) (int, error) {
|
||||
return b.Write(unsafe.Slice(unsafe.StringData(s), len(s))) // avoid copy
|
||||
}
|
||||
|
||||
func asBytesBufferLike(w io.Writer) bytesBufferLike {
|
||||
switch w := w.(type) {
|
||||
case *bytes.Buffer:
|
||||
return w
|
||||
case bytesBufferLike:
|
||||
return w
|
||||
default:
|
||||
return bytesBufferAdapter{w}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateVars validates the variables in the given string.
|
||||
// It returns ErrUnexpectedVar if any invalid variable is found.
|
||||
func ValidateVars(s string) error {
|
||||
// It returns the phase that the variables require and an error if any error occurs.
|
||||
//
|
||||
// Possible errors:
|
||||
// - ErrUnexpectedVar: if any invalid variable is found
|
||||
// - ErrUnterminatedEnvVar: missing closing }
|
||||
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
||||
// - ErrUnterminatedParenthesis: missing closing )
|
||||
func ValidateVars(s string) (phase PhaseFlag, err error) {
|
||||
return ExpandVars(voidResponseModifier, &dummyRequest, s, io.Discard)
|
||||
}
|
||||
|
||||
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) error {
|
||||
dst := ioutils.NewBufferedWriter(dstW, 1024)
|
||||
defer dst.Close()
|
||||
|
||||
// ExpandVars expands the variables in the given string and writes the result to the given writer.
|
||||
// It returns the phase that the variables require and an error if any error occurs.
|
||||
//
|
||||
// Possible errors:
|
||||
// - ErrUnexpectedVar: if any invalid variable is found
|
||||
// - ErrUnterminatedEnvVar: missing closing }
|
||||
// - ErrUnterminatedQuotes: missing closing " or ' or `
|
||||
// - ErrUnterminatedParenthesis: missing closing )
|
||||
func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, dstW io.Writer) (phase PhaseFlag, err error) {
|
||||
dst := asBytesBufferLike(dstW)
|
||||
for i := 0; i < len(src); i++ {
|
||||
ch := src[i]
|
||||
if ch != '$' {
|
||||
if err := dst.WriteByte(ch); err != nil {
|
||||
return err
|
||||
if err = dst.WriteByte(ch); err != nil {
|
||||
return phase, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Look ahead
|
||||
if i+1 >= len(src) {
|
||||
return ErrUnterminatedEnvVar
|
||||
return phase, ErrUnterminatedEnvVar
|
||||
}
|
||||
j := i + 1
|
||||
|
||||
switch src[j] {
|
||||
case '$': // $$ -> literal '$'
|
||||
if err := dst.WriteByte('$'); err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
i = j
|
||||
continue
|
||||
case '{': // ${...} pass through as-is
|
||||
if _, err := dst.WriteString("${"); err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
i = j // we've consumed the '{' too
|
||||
continue
|
||||
@@ -102,24 +146,32 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
|
||||
if getter, ok := dynamicVarSubsMap[name]; ok {
|
||||
// Function-like variables
|
||||
isStatic = false
|
||||
phase |= getter.phase
|
||||
args, nextIdx, err := extractArgs(src, j, name)
|
||||
if err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
i = nextIdx
|
||||
actual, err = getter(args, w, req)
|
||||
// Expand any nested $func(...) expressions in args
|
||||
args, argPhase, err := expandArgs(args, w, req)
|
||||
if err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
} else if getter, ok := staticReqVarSubsMap[name]; ok {
|
||||
phase |= argPhase
|
||||
actual, err = getter.get(args, w, req)
|
||||
if err != nil {
|
||||
return phase, err
|
||||
}
|
||||
} else if getter, ok := staticReqVarSubsMap[name]; ok { // always available
|
||||
actual = getter(req)
|
||||
} else if getter, ok := staticRespVarSubsMap[name]; ok {
|
||||
} else if getter, ok := staticRespVarSubsMap[name]; ok { // post response
|
||||
actual = getter(w)
|
||||
phase |= PhasePost
|
||||
} else {
|
||||
return ErrUnexpectedVar.Subject(name)
|
||||
return phase, ErrUnexpectedVar.Subject(name)
|
||||
}
|
||||
if _, err := dst.WriteString(actual); err != nil {
|
||||
return err
|
||||
return phase, err
|
||||
}
|
||||
if isStatic {
|
||||
i = k - 1
|
||||
@@ -128,10 +180,10 @@ func ExpandVars(w *httputils.ResponseModifier, req *http.Request, src string, ds
|
||||
}
|
||||
|
||||
// No valid construct after '$'
|
||||
return ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
||||
return phase, ErrUnterminatedEnvVar.Withf("around $ at position %d", j)
|
||||
}
|
||||
|
||||
return nil
|
||||
return phase, nil
|
||||
}
|
||||
|
||||
func extractArgs(src string, i int, funcName string) (args []string, nextIdx int, err error) {
|
||||
@@ -175,6 +227,18 @@ func extractArgs(src string, i int, funcName string) (args []string, nextIdx int
|
||||
continue
|
||||
}
|
||||
|
||||
// Nested function call: $func(...) as an argument
|
||||
if ch == '$' && arg.Len() == 0 {
|
||||
// Capture the entire $func(...) expression as a raw argument token
|
||||
nestedEnd, nestedErr := extractNestedFuncExpr(src, nextIdx)
|
||||
if nestedErr != nil {
|
||||
return nil, 0, nestedErr
|
||||
}
|
||||
args = append(args, src[nextIdx:nestedEnd+1])
|
||||
nextIdx = nestedEnd + 1
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == ')' {
|
||||
// End of arguments
|
||||
if arg.Len() > 0 {
|
||||
@@ -210,3 +274,70 @@ func extractArgs(src string, i int, funcName string) (args []string, nextIdx int
|
||||
}
|
||||
return nil, 0, ErrUnterminatedParenthesis.Withf("func %q", funcName)
|
||||
}
|
||||
|
||||
// extractNestedFuncExpr finds the end index (inclusive) of a $func(...) expression
|
||||
// starting at position start in src. It handles nested parentheses.
|
||||
func extractNestedFuncExpr(src string, start int) (endIdx int, err error) {
|
||||
// src[start] must be '$'
|
||||
i := start + 1
|
||||
// skip the function name (valid var name chars)
|
||||
for i < len(src) && validVarNameCharset[src[i]] {
|
||||
i++
|
||||
}
|
||||
if i >= len(src) || src[i] != '(' {
|
||||
return 0, ErrUnterminatedParenthesis.Withf("nested func at position %d", start)
|
||||
}
|
||||
// Now find the matching closing parenthesis, respecting quotes and nesting
|
||||
depth := 0
|
||||
var quote byte
|
||||
for i < len(src) {
|
||||
ch := src[i]
|
||||
if quote != 0 {
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if quoteChars[ch] {
|
||||
quote = ch
|
||||
i++
|
||||
continue
|
||||
}
|
||||
switch ch {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
if quote != 0 {
|
||||
return 0, ErrUnterminatedQuotes.Withf("nested func at position %d", start)
|
||||
}
|
||||
return 0, ErrUnterminatedParenthesis.Withf("nested func at position %d", start)
|
||||
}
|
||||
|
||||
// expandArgs expands any args that are nested dynamic var expressions (starting with '$').
|
||||
// It returns the expanded args and the combined phase flags.
|
||||
func expandArgs(args []string, w *httputils.ResponseModifier, req *http.Request) (expanded []string, phase PhaseFlag, err error) {
|
||||
expanded = make([]string, len(args))
|
||||
for i, arg := range args {
|
||||
if len(arg) > 0 && arg[0] == '$' {
|
||||
var buf strings.Builder
|
||||
var argPhase PhaseFlag
|
||||
argPhase, err = ExpandVars(w, req, arg, &buf)
|
||||
if err != nil {
|
||||
return nil, phase, err
|
||||
}
|
||||
phase |= argPhase
|
||||
expanded[i] = buf.String()
|
||||
} else {
|
||||
expanded[i] = arg
|
||||
}
|
||||
}
|
||||
return expanded, phase, nil
|
||||
}
|
||||
|
||||
@@ -6,63 +6,106 @@ import (
|
||||
"strconv"
|
||||
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
)
|
||||
|
||||
var (
|
||||
VarHeader = "header"
|
||||
VarResponseHeader = "resp_header"
|
||||
VarCookie = "cookie"
|
||||
VarQuery = "arg"
|
||||
VarForm = "form"
|
||||
VarPostForm = "postform"
|
||||
VarRedacted = "redacted"
|
||||
)
|
||||
|
||||
type dynamicVarGetter func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
||||
type dynamicVarGetter struct {
|
||||
phase PhaseFlag
|
||||
get func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error)
|
||||
}
|
||||
|
||||
var dynamicVarSubsMap = map[string]dynamicVarGetter{
|
||||
VarHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Header, key, index)
|
||||
},
|
||||
VarResponseHeader: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||
},
|
||||
VarQuery: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
||||
},
|
||||
VarForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
VarHeader: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Form, key, index)
|
||||
return getValueByKeyAtIndex(req.Header, key, index)
|
||||
},
|
||||
},
|
||||
VarPostForm: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
VarResponseHeader: {
|
||||
phase: PhasePost,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.PostForm, key, index)
|
||||
return getValueByKeyAtIndex(w.Header(), key, index)
|
||||
},
|
||||
},
|
||||
VarCookie: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sharedData := httputils.GetSharedData(w)
|
||||
return getValueByKeyAtIndex(sharedData.GetCookiesMap(req), key, index)
|
||||
},
|
||||
},
|
||||
VarQuery: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return getValueByKeyAtIndex(httputils.GetSharedData(w).GetQueries(req), key, index)
|
||||
},
|
||||
},
|
||||
VarForm: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.Form, key, index)
|
||||
},
|
||||
},
|
||||
VarPostForm: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
key, index, err := getKeyAndIndex(args)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if req.Form == nil {
|
||||
if err := req.ParseForm(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return getValueByKeyAtIndex(req.PostForm, key, index)
|
||||
},
|
||||
},
|
||||
// VarRedacted wraps the result of its single argument (which may be another dynamic var
|
||||
// expression, already expanded by expandArgs) with strutils.Redact.
|
||||
VarRedacted: {
|
||||
phase: PhaseNone,
|
||||
get: func(args []string, w *httputils.ResponseModifier, req *http.Request) (string, error) {
|
||||
if len(args) != 1 {
|
||||
return "", ErrExpectOneArg
|
||||
}
|
||||
return strutils.Redact(args[0]), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -189,6 +189,64 @@ func TestExtractArgs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractArgs_NestedFunc(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src string
|
||||
startPos int
|
||||
funcName string
|
||||
wantArgs []string
|
||||
wantNextIdx int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nested func as single arg",
|
||||
src: "redacted($header(Authorization))",
|
||||
startPos: 0,
|
||||
funcName: "redacted",
|
||||
wantArgs: []string{"$header(Authorization)"},
|
||||
wantNextIdx: 31,
|
||||
},
|
||||
{
|
||||
name: "nested func with quoted arg inside",
|
||||
src: `redacted($header("X-Secret"))`,
|
||||
startPos: 0,
|
||||
funcName: "redacted",
|
||||
wantArgs: []string{`$header("X-Secret")`},
|
||||
wantNextIdx: 28,
|
||||
},
|
||||
{
|
||||
name: "nested func with two args inside",
|
||||
src: "redacted($header(X-Multi, 1))",
|
||||
startPos: 0,
|
||||
funcName: "redacted",
|
||||
wantArgs: []string{"$header(X-Multi, 1)"},
|
||||
wantNextIdx: 28,
|
||||
},
|
||||
{
|
||||
name: "nested func missing closing paren",
|
||||
src: "redacted($header(Authorization)",
|
||||
startPos: 0,
|
||||
funcName: "redacted",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
args, nextIdx, err := extractArgs(tt.src, tt.startPos, tt.funcName)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantArgs, args)
|
||||
require.Equal(t, tt.wantNextIdx, nextIdx)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandVars(t *testing.T) {
|
||||
// Create a comprehensive test request with form data
|
||||
formData := url.Values{}
|
||||
@@ -232,7 +290,7 @@ func TestExpandVars(t *testing.T) {
|
||||
{
|
||||
name: "req_method",
|
||||
input: "$req_method",
|
||||
want: "POST",
|
||||
want: http.MethodPost,
|
||||
},
|
||||
{
|
||||
name: "req_path",
|
||||
@@ -446,6 +504,27 @@ func TestExpandVars(t *testing.T) {
|
||||
input: "Header: $header(User-Agent), Status: $status_code",
|
||||
want: "Header: test-agent/1.0, Status: 200",
|
||||
},
|
||||
// $redacted function
|
||||
{
|
||||
name: "redacted with plain string arg",
|
||||
input: "$redacted(secret)",
|
||||
want: "se**et",
|
||||
},
|
||||
{
|
||||
name: "redacted wrapping header",
|
||||
input: "$redacted($header(User-Agent))",
|
||||
want: "te**********.0",
|
||||
},
|
||||
{
|
||||
name: "redacted wrapping arg",
|
||||
input: "$redacted($arg(param1))",
|
||||
want: "va**e1",
|
||||
},
|
||||
{
|
||||
name: "redacted with no args",
|
||||
input: "$redacted()",
|
||||
wantErr: true,
|
||||
},
|
||||
// Escaped dollar signs
|
||||
{
|
||||
name: "escaped dollar",
|
||||
@@ -484,7 +563,7 @@ func TestExpandVars(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, tt.input, &out)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
@@ -506,7 +585,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||
"$req_method $req_url $status_code User-Agent=$header(User-Agent)",
|
||||
&out)
|
||||
|
||||
@@ -520,7 +599,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||
"Query: $arg(q), Page: $arg(page)",
|
||||
&out)
|
||||
|
||||
@@ -537,7 +616,7 @@ func TestExpandVars_Integration(t *testing.T) {
|
||||
testResponseModifier.WriteHeader(http.StatusOK)
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest,
|
||||
_, err := ExpandVars(testResponseModifier, testRequest,
|
||||
"Status: $status_code, Cache: $resp_header(Cache-Control), Limit: $resp_header(X-Rate-Limit)",
|
||||
&out)
|
||||
|
||||
@@ -560,7 +639,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
{
|
||||
name: "https scheme",
|
||||
request: &http.Request{
|
||||
Method: "GET",
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Scheme: "https", Host: "example.com", Path: "/"},
|
||||
TLS: &tls.ConnectionState{}, // Simulate TLS connection
|
||||
},
|
||||
@@ -572,7 +651,7 @@ func TestExpandVars_RequestSchemes(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||
_, err := ExpandVars(testResponseModifier, tt.request, "$req_scheme", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, out.String())
|
||||
})
|
||||
@@ -598,7 +677,7 @@ func TestExpandVars_UpstreamVariables(t *testing.T) {
|
||||
for _, varExpr := range upstreamVars {
|
||||
t.Run(varExpr, func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, varExpr, &out)
|
||||
// Should not error, may return empty string
|
||||
require.NoError(t, err)
|
||||
})
|
||||
@@ -614,16 +693,16 @@ func TestExpandVars_NoHostPort(t *testing.T) {
|
||||
|
||||
t.Run("req_host without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "example.com", out.String())
|
||||
})
|
||||
|
||||
t.Run("req_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, out.String())
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -636,16 +715,16 @@ func TestExpandVars_NoRemotePort(t *testing.T) {
|
||||
|
||||
t.Run("remote_host without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_host", &out)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, out.String())
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
|
||||
t.Run("remote_port without port", func(t *testing.T) {
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$remote_port", &out)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, out.String())
|
||||
require.Equal(t, "", out.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -654,7 +733,7 @@ func TestExpandVars_WhitespaceHandling(t *testing.T) {
|
||||
testResponseModifier := httputils.NewResponseModifier(httptest.NewRecorder())
|
||||
|
||||
var out strings.Builder
|
||||
err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||
_, err := ExpandVars(testResponseModifier, testRequest, "$req_method $req_path", &out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "GET /test", out.String())
|
||||
}
|
||||
@@ -699,7 +778,7 @@ func TestValidateVars(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateVars(tt.input)
|
||||
_, err := ValidateVars(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
|
||||
@@ -4,335 +4,336 @@ import { Glob } from "bun";
|
||||
import { md2mdx } from "./api-md2mdx";
|
||||
|
||||
type ImplDoc = {
|
||||
/** Directory path relative to this repo, e.g. "internal/health/check" */
|
||||
pkgPath: string;
|
||||
/** File name in wiki `src/impl/`, e.g. "internal-health-check.md" */
|
||||
docFileName: string;
|
||||
/** VitePress route path (extensionless), e.g. "/impl/internal-health-check" */
|
||||
docRoute: string;
|
||||
/** Absolute source README path */
|
||||
srcPathAbs: string;
|
||||
/** Absolute destination doc path */
|
||||
dstPathAbs: string;
|
||||
/** Directory path relative to this repo, e.g. "internal/health/check" */
|
||||
pkgPath: string;
|
||||
/** File name in wiki `src/impl/`, e.g. "internal-health-check.md" */
|
||||
docFileName: string;
|
||||
/** VitePress route path (extensionless), e.g. "/impl/internal-health-check" */
|
||||
docRoute: string;
|
||||
/** Absolute source README path */
|
||||
srcPathAbs: string;
|
||||
/** Absolute destination doc path */
|
||||
dstPathAbs: string;
|
||||
};
|
||||
|
||||
const skipSubmodules = [
|
||||
"internal/go-oidc/",
|
||||
"internal/gopsutil/",
|
||||
"internal/go-proxmox/",
|
||||
"internal/go-oidc/",
|
||||
"internal/gopsutil/",
|
||||
"internal/go-proxmox/",
|
||||
];
|
||||
|
||||
function normalizeRepoUrl(raw: string) {
|
||||
let url = (raw ?? "").trim();
|
||||
if (!url) return "";
|
||||
// Common typo: "https://https://github.com/..."
|
||||
url = url.replace(/^https?:\/\/https?:\/\//i, "https://");
|
||||
if (!/^https?:\/\//i.test(url)) url = `https://${url}`;
|
||||
url = url.replace(/\/+$/, "");
|
||||
return url;
|
||||
let url = (raw ?? "").trim();
|
||||
if (!url) return "";
|
||||
// Common typo: "https://https://github.com/..."
|
||||
url = url.replace(/^https?:\/\/https?:\/\//i, "https://");
|
||||
if (!/^https?:\/\//i.test(url)) url = `https://${url}`;
|
||||
url = url.replace(/\/+$/, "");
|
||||
return url;
|
||||
}
|
||||
|
||||
function sanitizeFileStemFromPkgPath(pkgPath: string) {
|
||||
// Convert a package path into a stable filename.
|
||||
// Example: "internal/go-oidc/example" -> "internal-go-oidc-example"
|
||||
// Keep it readable and unique (uses full path).
|
||||
const parts = pkgPath
|
||||
.split("/")
|
||||
.filter(Boolean)
|
||||
.map((p) => p.replace(/[^A-Za-z0-9._-]+/g, "-"));
|
||||
const joined = parts.join("-");
|
||||
return joined.replace(/-+/g, "-").replace(/^-|-$/g, "");
|
||||
// Convert a package path into a stable filename.
|
||||
// Example: "internal/go-oidc/example" -> "internal-go-oidc-example"
|
||||
// Keep it readable and unique (uses full path).
|
||||
const parts = pkgPath
|
||||
.split("/")
|
||||
.filter(Boolean)
|
||||
.map((p) => p.replace(/[^A-Za-z0-9._-]+/g, "-"));
|
||||
const joined = parts.join("-");
|
||||
return joined.replace(/-+/g, "-").replace(/^-|-$/g, "");
|
||||
}
|
||||
|
||||
function splitUrlAndFragment(url: string): {
|
||||
urlNoFragment: string;
|
||||
fragment: string;
|
||||
urlNoFragment: string;
|
||||
fragment: string;
|
||||
} {
|
||||
const i = url.indexOf("#");
|
||||
if (i === -1) return { urlNoFragment: url, fragment: "" };
|
||||
return { urlNoFragment: url.slice(0, i), fragment: url.slice(i) };
|
||||
const i = url.indexOf("#");
|
||||
if (i === -1) return { urlNoFragment: url, fragment: "" };
|
||||
return { urlNoFragment: url.slice(0, i), fragment: url.slice(i) };
|
||||
}
|
||||
|
||||
function isExternalOrAbsoluteUrl(url: string) {
|
||||
// - absolute site links: "/foo"
|
||||
// - pure fragments: "#bar"
|
||||
// - external schemes: "https:", "mailto:", "vscode:", etc.
|
||||
// IMPORTANT: don't treat "config.go:29" as a scheme.
|
||||
if (url.startsWith("/") || url.startsWith("#")) return true;
|
||||
if (url.includes("://")) return true;
|
||||
return /^(https?|mailto|tel|vscode|file|data|ssh|git):/i.test(url);
|
||||
// - absolute site links: "/foo"
|
||||
// - pure fragments: "#bar"
|
||||
// - external schemes: "https:", "mailto:", "vscode:", etc.
|
||||
// IMPORTANT: don't treat "config.go:29" as a scheme.
|
||||
if (url.startsWith("/") || url.startsWith("#")) return true;
|
||||
if (url.includes("://")) return true;
|
||||
return /^(https?|mailto|tel|vscode|file|data|ssh|git):/i.test(url);
|
||||
}
|
||||
|
||||
function isRepoSourceFilePath(filePath: string) {
|
||||
// Conservative allow-list: avoid rewriting .md (non-README) which may be VitePress docs.
|
||||
return /\.(go|ts|tsx|js|jsx|py|sh|yml|yaml|json|toml|env|css|html|txt)$/i.test(
|
||||
filePath,
|
||||
);
|
||||
// Conservative allow-list: avoid rewriting .md (non-README) which may be VitePress docs.
|
||||
return /\.(go|ts|tsx|js|jsx|py|sh|yml|yaml|json|toml|env|css|html|txt)$/i.test(
|
||||
filePath,
|
||||
);
|
||||
}
|
||||
|
||||
function parseFileLineSuffix(urlNoFragment: string): {
|
||||
filePath: string;
|
||||
line?: string;
|
||||
filePath: string;
|
||||
line?: string;
|
||||
} {
|
||||
// Match "file.ext:123" (line suffix), while leaving "file.ext" untouched.
|
||||
const m = urlNoFragment.match(/^(.*?):(\d+)$/);
|
||||
if (!m) return { filePath: urlNoFragment };
|
||||
return { filePath: m[1] ?? urlNoFragment, line: m[2] };
|
||||
// Match "file.ext:123" (line suffix), while leaving "file.ext" untouched.
|
||||
const m = urlNoFragment.match(/^(.*?):(\d+)$/);
|
||||
if (!m) return { filePath: urlNoFragment };
|
||||
return { filePath: m[1] ?? urlNoFragment, line: m[2] };
|
||||
}
|
||||
|
||||
function rewriteMarkdownLinksOutsideFences(
|
||||
md: string,
|
||||
rewriteInline: (url: string) => string,
|
||||
md: string,
|
||||
rewriteInline: (url: string) => string,
|
||||
) {
|
||||
const lines = md.split("\n");
|
||||
let inFence = false;
|
||||
const lines = md.split("\n");
|
||||
let inFence = false;
|
||||
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const line = lines[i] ?? "";
|
||||
const trimmed = line.trimStart();
|
||||
if (trimmed.startsWith("```")) {
|
||||
inFence = !inFence;
|
||||
continue;
|
||||
}
|
||||
if (inFence) continue;
|
||||
for (let i = 0; i < lines.length; i++) {
|
||||
const line = lines[i] ?? "";
|
||||
const trimmed = line.trimStart();
|
||||
if (trimmed.startsWith("```")) {
|
||||
inFence = !inFence;
|
||||
continue;
|
||||
}
|
||||
if (inFence) continue;
|
||||
|
||||
// Inline markdown links/images: [text](url "title") / 
|
||||
lines[i] = line.replace(
|
||||
/\]\(([^)\s]+)(\s+"[^"]*")?\)/g,
|
||||
(_full, urlRaw: string, maybeTitle: string | undefined) => {
|
||||
const rewritten = rewriteInline(urlRaw);
|
||||
return `](${rewritten}${maybeTitle ?? ""})`;
|
||||
},
|
||||
);
|
||||
}
|
||||
// Inline markdown links/images: [text](url "title") / 
|
||||
lines[i] = line.replace(
|
||||
/\]\(([^)\s]+)(\s+"[^"]*")?\)/g,
|
||||
(_full, urlRaw: string, maybeTitle: string | undefined) => {
|
||||
const rewritten = rewriteInline(urlRaw);
|
||||
return `](${rewritten}${maybeTitle ?? ""})`;
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
return lines.join("\n");
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
function rewriteImplMarkdown(params: {
|
||||
md: string;
|
||||
pkgPath: string;
|
||||
readmeRelToDocRoute: Map<string, string>;
|
||||
dirPathToDocRoute: Map<string, string>;
|
||||
repoUrl: string;
|
||||
md: string;
|
||||
pkgPath: string;
|
||||
readmeRelToDocRoute: Map<string, string>;
|
||||
dirPathToDocRoute: Map<string, string>;
|
||||
repoUrl: string;
|
||||
}) {
|
||||
const { md, pkgPath, readmeRelToDocRoute, dirPathToDocRoute, repoUrl } =
|
||||
params;
|
||||
const { md, pkgPath, readmeRelToDocRoute, dirPathToDocRoute, repoUrl } =
|
||||
params;
|
||||
|
||||
return rewriteMarkdownLinksOutsideFences(md, (urlRaw) => {
|
||||
// Handle angle-bracketed destinations: (<./foo/README.md>)
|
||||
const angleWrapped =
|
||||
urlRaw.startsWith("<") && urlRaw.endsWith(">")
|
||||
? urlRaw.slice(1, -1)
|
||||
: urlRaw;
|
||||
return rewriteMarkdownLinksOutsideFences(md, (urlRaw) => {
|
||||
// Handle angle-bracketed destinations: (<./foo/README.md>)
|
||||
const angleWrapped =
|
||||
urlRaw.startsWith("<") && urlRaw.endsWith(">")
|
||||
? urlRaw.slice(1, -1)
|
||||
: urlRaw;
|
||||
|
||||
const { urlNoFragment, fragment } = splitUrlAndFragment(angleWrapped);
|
||||
if (!urlNoFragment) return urlRaw;
|
||||
if (isExternalOrAbsoluteUrl(urlNoFragment)) return urlRaw;
|
||||
const { urlNoFragment, fragment } = splitUrlAndFragment(angleWrapped);
|
||||
if (!urlNoFragment) return urlRaw;
|
||||
if (isExternalOrAbsoluteUrl(urlNoFragment)) return urlRaw;
|
||||
|
||||
// 1) Directory links like "common" or "common/" that have a README
|
||||
const dirPathNormalized = urlNoFragment.replace(/\/+$/, "");
|
||||
let rewritten: string | undefined;
|
||||
// First try exact match
|
||||
if (dirPathToDocRoute.has(dirPathNormalized)) {
|
||||
rewritten = `${dirPathToDocRoute.get(dirPathNormalized)}${fragment}`;
|
||||
} else {
|
||||
// Fallback: check parent directories for a README
|
||||
// This handles paths like "internal/watcher/events" where only the parent has a README
|
||||
let parentPath = dirPathNormalized;
|
||||
while (parentPath.includes("/")) {
|
||||
parentPath = parentPath.slice(0, parentPath.lastIndexOf("/"));
|
||||
if (dirPathToDocRoute.has(parentPath)) {
|
||||
rewritten = `${dirPathToDocRoute.get(parentPath)}${fragment}`;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (rewritten) {
|
||||
return angleWrapped === urlRaw ? rewritten : `<${rewritten}>`;
|
||||
}
|
||||
// 1) Directory links like "common" or "common/" that have a README
|
||||
const dirPathNormalized = urlNoFragment.replace(/\/+$/, "");
|
||||
let rewritten: string | undefined;
|
||||
// First try exact match
|
||||
if (dirPathToDocRoute.has(dirPathNormalized)) {
|
||||
rewritten = `${dirPathToDocRoute.get(dirPathNormalized)}${fragment}`;
|
||||
} else {
|
||||
// Fallback: check parent directories for a README
|
||||
// This handles paths like "internal/watcher/events" where only the parent has a README
|
||||
let parentPath = dirPathNormalized;
|
||||
while (parentPath.includes("/")) {
|
||||
parentPath = parentPath.slice(0, parentPath.lastIndexOf("/"));
|
||||
if (dirPathToDocRoute.has(parentPath)) {
|
||||
rewritten = `${dirPathToDocRoute.get(parentPath)}${fragment}`;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (rewritten) {
|
||||
return angleWrapped === urlRaw ? rewritten : `<${rewritten}>`;
|
||||
}
|
||||
|
||||
// 2) Intra-repo README links -> VitePress impl routes
|
||||
if (/(^|\/)README\.md$/.test(urlNoFragment)) {
|
||||
const targetReadmeRel = path.posix.normalize(
|
||||
path.posix.join(pkgPath, urlNoFragment),
|
||||
);
|
||||
const route = readmeRelToDocRoute.get(targetReadmeRel);
|
||||
if (route) {
|
||||
const rewritten = `${route}${fragment}`;
|
||||
return angleWrapped === urlRaw ? rewritten : `<${rewritten}>`;
|
||||
}
|
||||
return urlRaw;
|
||||
}
|
||||
// 2) Intra-repo README links -> VitePress impl routes
|
||||
if (/(^|\/)README\.md$/.test(urlNoFragment)) {
|
||||
const targetReadmeRel = path.posix.normalize(
|
||||
path.posix.join(pkgPath, urlNoFragment),
|
||||
);
|
||||
const route = readmeRelToDocRoute.get(targetReadmeRel);
|
||||
if (route) {
|
||||
const rewritten = `${route}${fragment}`;
|
||||
return angleWrapped === urlRaw ? rewritten : `<${rewritten}>`;
|
||||
}
|
||||
return urlRaw;
|
||||
}
|
||||
|
||||
// 3) Local source-file references like "config.go:29" -> GitHub blob link
|
||||
if (repoUrl) {
|
||||
const { filePath, line } = parseFileLineSuffix(urlNoFragment);
|
||||
if (isRepoSourceFilePath(filePath)) {
|
||||
const repoRel = path.posix.normalize(
|
||||
path.posix.join(pkgPath, filePath),
|
||||
);
|
||||
const githubUrl = `${repoUrl}/blob/main/${repoRel}${
|
||||
line ? `#L${line}` : ""
|
||||
}`;
|
||||
const rewritten = `${githubUrl}${fragment}`;
|
||||
return angleWrapped === urlRaw ? rewritten : `<${rewritten}>`;
|
||||
}
|
||||
}
|
||||
// 3) Local source-file references like "config.go:29" -> GitHub blob link
|
||||
if (repoUrl) {
|
||||
const { filePath, line } = parseFileLineSuffix(urlNoFragment);
|
||||
if (isRepoSourceFilePath(filePath)) {
|
||||
const repoRel = path.posix.normalize(
|
||||
path.posix.join(pkgPath, filePath),
|
||||
);
|
||||
const githubUrl = `${repoUrl}/blob/main/${repoRel}${
|
||||
line ? `#L${line}` : ""
|
||||
}`;
|
||||
const rewritten = `${githubUrl}${fragment}`;
|
||||
return angleWrapped === urlRaw ? rewritten : `<${rewritten}>`;
|
||||
}
|
||||
}
|
||||
|
||||
return urlRaw;
|
||||
});
|
||||
return urlRaw;
|
||||
});
|
||||
}
|
||||
|
||||
async function listRepoReadmes(repoRootAbs: string): Promise<string[]> {
|
||||
const glob = new Glob("**/README.md");
|
||||
const readmes: string[] = [];
|
||||
const glob = new Glob("**/README.md");
|
||||
const readmes: string[] = [];
|
||||
|
||||
for await (const rel of glob.scan({
|
||||
cwd: repoRootAbs,
|
||||
onlyFiles: true,
|
||||
dot: false,
|
||||
})) {
|
||||
// Bun returns POSIX-style rel paths.
|
||||
if (rel === "README.md") continue; // exclude root README
|
||||
if (rel.startsWith(".git/") || rel.includes("/.git/")) continue;
|
||||
if (rel.startsWith("node_modules/") || rel.includes("/node_modules/"))
|
||||
continue;
|
||||
let skip = false;
|
||||
for (const submodule of skipSubmodules) {
|
||||
if (rel.startsWith(submodule)) {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (skip) continue;
|
||||
readmes.push(rel);
|
||||
}
|
||||
for await (const rel of glob.scan({
|
||||
cwd: repoRootAbs,
|
||||
onlyFiles: true,
|
||||
dot: false,
|
||||
})) {
|
||||
// Bun returns POSIX-style rel paths.
|
||||
if (rel === "README.md") continue; // exclude root README
|
||||
if (rel.startsWith(".git/") || rel.includes("/.git/")) continue;
|
||||
if (rel.startsWith("node_modules/") || rel.includes("/node_modules/"))
|
||||
continue;
|
||||
let skip = false;
|
||||
for (const submodule of skipSubmodules) {
|
||||
if (rel.startsWith(submodule)) {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (skip) continue;
|
||||
readmes.push(rel);
|
||||
}
|
||||
|
||||
// Deterministic order.
|
||||
readmes.sort((a, b) => a.localeCompare(b));
|
||||
return readmes;
|
||||
// Deterministic order.
|
||||
readmes.sort((a, b) => a.localeCompare(b));
|
||||
return readmes;
|
||||
}
|
||||
|
||||
async function writeImplDocCopy(params: {
|
||||
srcAbs: string;
|
||||
dstAbs: string;
|
||||
pkgPath: string;
|
||||
readmeRelToDocRoute: Map<string, string>;
|
||||
dirPathToDocRoute: Map<string, string>;
|
||||
repoUrl: string;
|
||||
async function writeImplDocToMdx(params: {
|
||||
srcAbs: string;
|
||||
dstAbs: string;
|
||||
pkgPath: string;
|
||||
readmeRelToDocRoute: Map<string, string>;
|
||||
dirPathToDocRoute: Map<string, string>;
|
||||
repoUrl: string;
|
||||
}) {
|
||||
const {
|
||||
srcAbs,
|
||||
dstAbs,
|
||||
pkgPath,
|
||||
readmeRelToDocRoute,
|
||||
dirPathToDocRoute,
|
||||
repoUrl,
|
||||
} = params;
|
||||
await mkdir(path.dirname(dstAbs), { recursive: true });
|
||||
await rm(dstAbs, { force: true });
|
||||
const {
|
||||
srcAbs,
|
||||
dstAbs,
|
||||
pkgPath,
|
||||
readmeRelToDocRoute,
|
||||
dirPathToDocRoute,
|
||||
repoUrl,
|
||||
} = params;
|
||||
await mkdir(path.dirname(dstAbs), { recursive: true });
|
||||
|
||||
const original = await readFile(srcAbs, "utf8");
|
||||
const rewritten = rewriteImplMarkdown({
|
||||
md: original,
|
||||
pkgPath,
|
||||
readmeRelToDocRoute,
|
||||
dirPathToDocRoute,
|
||||
repoUrl,
|
||||
});
|
||||
await writeFile(dstAbs, md2mdx(rewritten));
|
||||
const original = await readFile(srcAbs, "utf8");
|
||||
const current = await readFile(dstAbs, "utf-8");
|
||||
const rewritten = md2mdx(
|
||||
rewriteImplMarkdown({
|
||||
md: original,
|
||||
pkgPath,
|
||||
readmeRelToDocRoute,
|
||||
dirPathToDocRoute,
|
||||
repoUrl,
|
||||
}),
|
||||
);
|
||||
|
||||
if (current === rewritten) {
|
||||
return;
|
||||
}
|
||||
|
||||
await writeFile(dstAbs, rewritten, "utf-8");
|
||||
console.log(`[W] ${srcAbs} -> ${dstAbs}`);
|
||||
}
|
||||
|
||||
async function syncImplDocs(
|
||||
repoRootAbs: string,
|
||||
wikiRootAbs: string,
|
||||
): Promise<ImplDoc[]> {
|
||||
const implDirAbs = path.join(wikiRootAbs, "content", "docs", "impl");
|
||||
await mkdir(implDirAbs, { recursive: true });
|
||||
repoRootAbs: string,
|
||||
wikiRootAbs: string,
|
||||
): Promise<void> {
|
||||
const implDirAbs = path.join(wikiRootAbs, "content", "docs", "impl");
|
||||
await mkdir(implDirAbs, { recursive: true });
|
||||
|
||||
const readmes = await listRepoReadmes(repoRootAbs);
|
||||
const docs: ImplDoc[] = [];
|
||||
const expectedFileNames = new Set<string>();
|
||||
expectedFileNames.add("index.mdx");
|
||||
expectedFileNames.add("meta.json");
|
||||
const readmes = await listRepoReadmes(repoRootAbs);
|
||||
const expectedFileNames = new Set<string>();
|
||||
expectedFileNames.add("index.mdx");
|
||||
expectedFileNames.add("meta.json");
|
||||
|
||||
const repoUrl = normalizeRepoUrl(
|
||||
Bun.env.REPO_URL ?? "https://github.com/yusing/godoxy",
|
||||
);
|
||||
const repoUrl = normalizeRepoUrl(
|
||||
Bun.env.REPO_URL ?? "https://github.com/yusing/godoxy",
|
||||
);
|
||||
|
||||
// Precompute mapping from repo-relative README path -> VitePress route.
|
||||
// This lets us rewrite intra-repo README links when copying content.
|
||||
const readmeRelToDocRoute = new Map<string, string>();
|
||||
// Precompute mapping from repo-relative README path -> VitePress route.
|
||||
// This lets us rewrite intra-repo README links when copying content.
|
||||
const readmeRelToDocRoute = new Map<string, string>();
|
||||
|
||||
// Also precompute mapping from directory path -> VitePress route.
|
||||
// This handles links like "[`common/`](common)" that point to directories with READMEs.
|
||||
const dirPathToDocRoute = new Map<string, string>();
|
||||
// Also precompute mapping from directory path -> VitePress route.
|
||||
// This handles links like "[`common/`](common)" that point to directories with READMEs.
|
||||
const dirPathToDocRoute = new Map<string, string>();
|
||||
|
||||
for (const readmeRel of readmes) {
|
||||
const pkgPath = path.posix.dirname(readmeRel);
|
||||
if (!pkgPath || pkgPath === ".") continue;
|
||||
for (const readmeRel of readmes) {
|
||||
const pkgPath = path.posix.dirname(readmeRel);
|
||||
if (!pkgPath || pkgPath === ".") continue;
|
||||
|
||||
const docStem = sanitizeFileStemFromPkgPath(pkgPath);
|
||||
if (!docStem) continue;
|
||||
const route = `/impl/${docStem}`;
|
||||
readmeRelToDocRoute.set(readmeRel, route);
|
||||
dirPathToDocRoute.set(pkgPath, route);
|
||||
}
|
||||
const docStem = sanitizeFileStemFromPkgPath(pkgPath);
|
||||
if (!docStem) continue;
|
||||
const route = `/impl/${docStem}`;
|
||||
readmeRelToDocRoute.set(readmeRel, route);
|
||||
dirPathToDocRoute.set(pkgPath, route);
|
||||
}
|
||||
|
||||
for (const readmeRel of readmes) {
|
||||
const pkgPath = path.posix.dirname(readmeRel);
|
||||
if (!pkgPath || pkgPath === ".") continue;
|
||||
for (const readmeRel of readmes) {
|
||||
const pkgPath = path.posix.dirname(readmeRel);
|
||||
if (!pkgPath || pkgPath === ".") continue;
|
||||
|
||||
const docStem = sanitizeFileStemFromPkgPath(pkgPath);
|
||||
if (!docStem) continue;
|
||||
const docFileName = `${docStem}.mdx`;
|
||||
const docRoute = `/impl/${docStem}`;
|
||||
const docStem = sanitizeFileStemFromPkgPath(pkgPath);
|
||||
if (!docStem) continue;
|
||||
const docFileName = `${docStem}.mdx`;
|
||||
|
||||
const srcPathAbs = path.join(repoRootAbs, readmeRel);
|
||||
const dstPathAbs = path.join(implDirAbs, docFileName);
|
||||
const srcPathAbs = path.join(repoRootAbs, readmeRel);
|
||||
const dstPathAbs = path.join(implDirAbs, docFileName);
|
||||
|
||||
await writeImplDocCopy({
|
||||
srcAbs: srcPathAbs,
|
||||
dstAbs: dstPathAbs,
|
||||
pkgPath,
|
||||
readmeRelToDocRoute,
|
||||
dirPathToDocRoute,
|
||||
repoUrl,
|
||||
});
|
||||
await writeImplDocToMdx({
|
||||
srcAbs: srcPathAbs,
|
||||
dstAbs: dstPathAbs,
|
||||
pkgPath,
|
||||
readmeRelToDocRoute,
|
||||
dirPathToDocRoute,
|
||||
repoUrl,
|
||||
});
|
||||
|
||||
docs.push({ pkgPath, docFileName, docRoute, srcPathAbs, dstPathAbs });
|
||||
expectedFileNames.add(docFileName);
|
||||
}
|
||||
expectedFileNames.add(docFileName);
|
||||
}
|
||||
|
||||
// Clean orphaned impl docs.
|
||||
const existing = await readdir(implDirAbs, { withFileTypes: true });
|
||||
for (const ent of existing) {
|
||||
if (!ent.isFile()) continue;
|
||||
if (!ent.name.endsWith(".md")) continue;
|
||||
if (expectedFileNames.has(ent.name)) continue;
|
||||
await rm(path.join(implDirAbs, ent.name), { force: true });
|
||||
}
|
||||
|
||||
// Deterministic for sidebar.
|
||||
docs.sort((a, b) => a.pkgPath.localeCompare(b.pkgPath));
|
||||
return docs;
|
||||
// Clean orphaned impl docs.
|
||||
const existing = await readdir(implDirAbs, { withFileTypes: true });
|
||||
for (const ent of existing) {
|
||||
if (!ent.isFile()) continue;
|
||||
if (!ent.name.endsWith(".md")) continue;
|
||||
if (expectedFileNames.has(ent.name)) continue;
|
||||
await rm(path.join(implDirAbs, ent.name), { force: true });
|
||||
}
|
||||
}
|
||||
|
||||
async function main() {
|
||||
// This script lives in `scripts/update-wiki/`, so repo root is two levels up.
|
||||
const repoRootAbs = path.resolve(import.meta.dir);
|
||||
// This script lives in `scripts/update-wiki/`, so repo root is two levels up.
|
||||
const repoRootAbs = path.resolve(import.meta.dir, "../..");
|
||||
|
||||
// Required by task, but allow overriding via env for convenience.
|
||||
const wikiRootAbs = Bun.env.DOCS_DIR
|
||||
? path.resolve(repoRootAbs, Bun.env.DOCS_DIR)
|
||||
: undefined;
|
||||
// Required by task, but allow overriding via env for convenience.
|
||||
const wikiRootAbs = Bun.env.DOCS_DIR
|
||||
? path.resolve(repoRootAbs, Bun.env.DOCS_DIR)
|
||||
: undefined;
|
||||
|
||||
if (!wikiRootAbs) {
|
||||
throw new Error("DOCS_DIR is not set");
|
||||
}
|
||||
if (!wikiRootAbs) {
|
||||
throw new Error("DOCS_DIR is not set");
|
||||
}
|
||||
|
||||
await syncImplDocs(repoRootAbs, wikiRootAbs);
|
||||
await syncImplDocs(repoRootAbs, wikiRootAbs);
|
||||
}
|
||||
|
||||
await main();
|
||||
|
||||
Reference in New Issue
Block a user