feat(rules): add post-request rules system with response manipulation (#160)

* Add comprehensive post-request rules support for response phase
* Enable response body, status, and header manipulation via set commands
* Refactor command handlers to support both request and response phases
* Implement response modifier system for post-request template execution
* Support response-based rule matching with status and header checks
* Add comprehensive benchmarks for matcher performance
* Refactor authentication and proxying commands for unified error handling
* Support negated conditions with !
* Enhance error handling, error formatting and validation
* Routes: add `rule_file` field with rule preset support
* Environment variable substitution: now supports variables without `GODOXY_` prefix

* new conditions:
  * `on resp_header <key> [<value>]`
  * `on status <status>`
* new commands:
  * `require_auth`
  * `set resp_header <key> <template>`
  * `set resp_body <template>`
  * `set status <code>`
  * `log <level> <path> <template>`
  * `notify <level> <provider> <title_template> <body_template>`
This commit is contained in:
Yuzerion
2025-10-14 23:53:06 +08:00
committed by GitHub
parent 19968834d2
commit 53f3397b7a
41 changed files with 4425 additions and 528 deletions

Submodule goutils updated: 26146bd560...e78e3c2d35

View File

@@ -58,3 +58,13 @@ func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
}
func AuthOrProceed(w http.ResponseWriter, r *http.Request) (proceed bool) {
err := defaultAuth.CheckToken(r)
if err != nil {
defaultAuth.LoginHandler(w, r)
return false
} else {
return true
}
}

View File

@@ -69,7 +69,7 @@ func (cfg *ConfigBase) Validate() gperr.Error {
// If only stdout is enabled, it returns nil, nil.
func (cfg *ConfigBase) IO() (WriterWithName, error) {
if cfg.Path != "" {
io, err := newFileIO(cfg.Path)
io, err := NewFileIO(cfg.Path)
if err != nil {
return nil, err
}

View File

@@ -26,7 +26,10 @@ var (
openedFilesMu sync.Mutex
)
func newFileIO(path string) (WriterWithName, error) {
// NewFileIO creates a new file writer with cleaned path.
//
// If the file is already opened, it will be returned.
func NewFileIO(path string) (WriterWithName, error) {
openedFilesMu.Lock()
defer openedFilesMu.Unlock()

View File

@@ -31,7 +31,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
wg.Add(1)
go func(index int) {
defer wg.Done()
file, err := newFileIO(cfg.Path)
file, err := NewFileIO(cfg.Path)
expect.NoError(t, err)
accessLogIOs[index] = file
}(i)

View File

@@ -61,6 +61,26 @@ func NewLogger(out ...io.Writer) zerolog.Logger {
).Level(level).With().Timestamp().Logger()
}
func NewLoggerWithFixedLevel(level zerolog.Level, out ...io.Writer) zerolog.Logger {
levelStr := level.String()
writer := zerolog.ConsoleWriter{
Out: zerolog.MultiLevelWriter(out...),
TimeFormat: timeFmt,
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
if msgI == nil {
return ""
}
return fmtMessage(msgI.(string))
},
FormatLevel: func(_ any) string {
return levelStr
},
}
return zerolog.New(
writer,
).Level(level).With().Timestamp().Logger()
}
func InitLogger(out ...io.Writer) {
logger = NewLogger(out...)
log.SetOutput(logger)

View File

@@ -8,11 +8,9 @@ import (
type Bypass []rules.RuleOn
func (b Bypass) ShouldBypass(r *http.Request) bool {
cached := rules.NewCache()
defer cached.Release()
func (b Bypass) ShouldBypass(w http.ResponseWriter, r *http.Request) bool {
for _, rule := range b {
if rule.Check(cached, r) {
if rule.Check(w, r) {
return true
}
}
@@ -26,14 +24,14 @@ type checkBypass struct {
}
func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
if c.modReq == nil || c.bypass.ShouldBypass(r) {
if c.modReq == nil || c.bypass.ShouldBypass(w, r) {
return true
}
return c.modReq.before(w, r)
}
func (c *checkBypass) modifyResponse(resp *http.Response) error {
if c.modRes == nil || c.bypass.ShouldBypass(resp.Request) {
func (c *checkBypass) modifyResponse(w http.ResponseWriter, resp *http.Response) error {
if c.modRes == nil || c.bypass.ShouldBypass(w, resp.Request) {
return nil
}
return c.modRes.modifyResponse(resp)

View File

@@ -20,10 +20,11 @@ type (
)
type (
FieldsBody []LogField
ListBody []string
MessageBody string
errorBody struct {
FieldsBody []LogField
ListBody []string
MessageBody string
MessageBodyBytes []byte
errorBody struct {
Error error
}
)
@@ -98,7 +99,15 @@ func (m MessageBody) Format(format LogFormat) ([]byte, error) {
case LogFormatRawJSON:
return sonic.Marshal(m)
}
return m.Format(LogFormatMarkdown)
return []byte(m), nil
}
func (m MessageBodyBytes) Format(format LogFormat) ([]byte, error) {
switch format {
case LogFormatRawJSON:
return sonic.Marshal(string(m))
}
return m, nil
}
func (e errorBody) Format(format LogFormat) ([]byte, error) {

View File

@@ -128,7 +128,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
}
if len(r.Rules) > 0 {
r.handler = r.Rules.BuildHandler(r.handler)
r.handler = r.Rules.BuildHandler(r.handler.ServeHTTP)
}
if r.HealthMon != nil {

View File

@@ -2,7 +2,11 @@ package route
import (
"context"
"errors"
"fmt"
"net/url"
"os"
"reflect"
"runtime"
"strings"
"sync"
@@ -17,6 +21,7 @@ import (
netutils "github.com/yusing/godoxy/internal/net"
nettypes "github.com/yusing/godoxy/internal/net/types"
"github.com/yusing/godoxy/internal/proxmox"
"github.com/yusing/godoxy/internal/serialization"
"github.com/yusing/godoxy/internal/types"
gperr "github.com/yusing/goutils/errs"
strutils "github.com/yusing/goutils/strings"
@@ -25,6 +30,7 @@ import (
"github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/route/rules"
rulepresets "github.com/yusing/godoxy/internal/route/rules/presets"
route "github.com/yusing/godoxy/internal/route/types"
"github.com/yusing/godoxy/internal/utils"
)
@@ -41,7 +47,8 @@ type (
route.HTTPConfig
PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"`
Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name" extension:"x-nullable"`
Rules rules.Rules `json:"rules,omitempty" extension:"x-nullable"`
RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"`
HealthCheck *types.HealthCheckConfig `json:"healthcheck"`
LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"`
Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"`
@@ -212,7 +219,10 @@ func (r *Route) Validate() gperr.Error {
}
}
errs := gperr.NewBuilder("entry validation failed")
var errs gperr.Builder
if err := r.validateRules(); err != nil {
errs.Add(err)
}
var impl types.Route
var err gperr.Error
@@ -267,6 +277,39 @@ func (r *Route) Validate() gperr.Error {
return nil
}
func (r *Route) validateRules() error {
if r.RuleFile != "" && len(r.Rules) > 0 {
return errors.New("`rule_file` and `rules` cannot be used together")
} else if r.RuleFile != "" {
src, err := url.Parse(r.RuleFile)
if err != nil {
return fmt.Errorf("failed to parse rule file url %q: %w", r.RuleFile, err)
}
switch src.Scheme {
case "embed": // embed://<preset_file_name>
rules, ok := rulepresets.GetRulePreset(src.Host)
if !ok {
return fmt.Errorf("rule preset %q not found", src.Host)
} else {
r.Rules = rules
}
case "file", "":
content, err := os.ReadFile(src.Path)
if err != nil {
return fmt.Errorf("failed to read rule file %q: %w", src.Path, err)
} else {
_, err = serialization.ConvertString(string(content), reflect.ValueOf(&r.Rules))
if err != nil {
return fmt.Errorf("failed to unmarshal rule file %q: %w", src.Path, err)
}
}
default:
return fmt.Errorf("unsupported rule file scheme %q", src.Scheme)
}
}
return nil
}
func (r *Route) Impl() types.Route {
return r.impl
}

View File

@@ -86,6 +86,13 @@ func TryGetUpstreamPort(r *http.Request) string {
return ""
}
func TryGetUpstreamHostPort(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.Host
}
return ""
}
func TryGetUpstreamAddr(r *http.Request) string {
if u := tryGetURL(r); u != nil {
return u.Host

View File

@@ -15,13 +15,13 @@ type (
)
const (
CacheKeyQueries = "queries"
CacheKeyCookies = "cookies"
CacheKeyRemoteIP = "remote_ip"
CacheKeyBasicAuth = "basic_auth"
cacheKeyQueries = "queries"
cacheKeyCookies = "cookies"
cacheKeyRemoteIP = "remote_ip"
cacheKeyBasicAuth = "basic_auth"
)
var cachePool = &sync.Pool{
var cachePool = sync.Pool{
New: func() any {
return make(Cache)
},
@@ -41,10 +41,10 @@ func (c Cache) Release() {
// GetQueries returns the queries.
// If r does not have queries, an empty map is returned.
func (c Cache) GetQueries(r *http.Request) url.Values {
v, ok := c[CacheKeyQueries]
v, ok := c[cacheKeyQueries]
if !ok {
v = r.URL.Query()
c[CacheKeyQueries] = v
c[cacheKeyQueries] = v
}
return v.(url.Values)
}
@@ -58,17 +58,17 @@ func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) {
// GetCookies returns the cookies.
// If r does not have cookies, an empty slice is returned.
func (c Cache) GetCookies(r *http.Request) []*http.Cookie {
v, ok := c[CacheKeyCookies]
v, ok := c[cacheKeyCookies]
if !ok {
v = r.Cookies()
c[CacheKeyCookies] = v
c[cacheKeyCookies] = v
}
return v.([]*http.Cookie)
}
func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) {
cookies := update(c.GetCookies(r))
c[CacheKeyCookies] = cookies
c[cacheKeyCookies] = cookies
r.Header.Del("Cookie")
for _, cookie := range cookies {
r.AddCookie(cookie)
@@ -78,14 +78,14 @@ func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie])
// GetRemoteIP returns the remote ip address.
// If r.RemoteAddr is not a valid ip address, nil is returned.
func (c Cache) GetRemoteIP(r *http.Request) net.IP {
v, ok := c[CacheKeyRemoteIP]
v, ok := c[cacheKeyRemoteIP]
if !ok {
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
host = r.RemoteAddr
}
v = net.ParseIP(host)
c[CacheKeyRemoteIP] = v
c[cacheKeyRemoteIP] = v
}
return v.(net.IP)
}
@@ -93,14 +93,14 @@ func (c Cache) GetRemoteIP(r *http.Request) net.IP {
// GetBasicAuth returns *Credentials the basic auth username and password.
// If r does not have basic auth, nil is returned.
func (c Cache) GetBasicAuth(r *http.Request) *Credentials {
v, ok := c[CacheKeyBasicAuth]
v, ok := c[cacheKeyBasicAuth]
if !ok {
u, p, ok := r.BasicAuth()
if ok {
v = &Credentials{u, []byte(p)}
c[CacheKeyBasicAuth] = v
c[cacheKeyBasicAuth] = v
} else {
c[CacheKeyBasicAuth] = nil
c[cacheKeyBasicAuth] = nil
return nil
}
}

View File

@@ -3,30 +3,30 @@ package rules
import "net/http"
type (
CheckFunc func(cached Cache, r *http.Request) bool
CheckFunc func(w http.ResponseWriter, r *http.Request) bool
Checker interface {
Check(cached Cache, r *http.Request) bool
Check(w http.ResponseWriter, r *http.Request) bool
}
CheckMatchSingle []Checker
CheckMatchAll []Checker
)
func (checker CheckFunc) Check(cached Cache, r *http.Request) bool {
return checker(cached, r)
func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool {
return checker(w, r)
}
func (checkers CheckMatchSingle) Check(cached Cache, r *http.Request) bool {
func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool {
for _, check := range checkers {
if check.Check(cached, r) {
if check.Check(w, r) {
return true
}
}
return false
}
func (checkers CheckMatchAll) Check(cached Cache, r *http.Request) bool {
func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool {
for _, check := range checkers {
if !check.Check(cached, r) {
if !check.Check(w, r) {
return false
}
}

View File

@@ -3,19 +3,21 @@ package rules
import "net/http"
type (
handlerFunc func(w http.ResponseWriter, r *http.Request) error
CommandHandler interface {
// CommandHandler can read and modify the values
// then handle the request
// finally proceed to next command (or return) base on situation
Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
Handle(w http.ResponseWriter, r *http.Request) error
IsResponseHandler() bool
}
// NonTerminatingCommand will run then proceed to next command or reverse proxy.
NonTerminatingCommand http.HandlerFunc
NonTerminatingCommand handlerFunc
// TerminatingCommand will run then return immediately.
TerminatingCommand http.HandlerFunc
// DynamicCommand will return base on the request
// and can read or modify the values.
DynamicCommand func(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
TerminatingCommand handlerFunc
// OnResponseCommand will run then return based on the response.
OnResponseCommand handlerFunc
// BypassCommand will skip all the following commands
// and directly return to reverse proxy.
BypassCommand struct{}
@@ -23,29 +25,55 @@ type (
Commands []CommandHandler
)
func (c NonTerminatingCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
c(w, r)
return true
func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
return c(w, r)
}
func (c TerminatingCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
c(w, r)
func (c NonTerminatingCommand) IsResponseHandler() bool {
return false
}
func (c DynamicCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
return c(cached, w, r)
func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
if err := c(w, r); err != nil {
return err
}
return errTerminated
}
func (c BypassCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
func (c TerminatingCommand) IsResponseHandler() bool {
return false
}
func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error {
return c(w, r)
}
func (c OnResponseCommand) IsResponseHandler() bool {
return true
}
func (c Commands) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error {
return errTerminated
}
func (c BypassCommand) IsResponseHandler() bool {
return false
}
func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error {
for _, cmd := range c {
if !cmd.Handle(cached, w, r) {
return false
if err := cmd.Handle(w, r); err != nil {
return err
}
}
return true
return nil
}
func (c Commands) IsResponseHandler() bool {
for _, cmd := range c {
if cmd.IsResponseHandler() {
return true
}
}
return false
}

View File

@@ -1,27 +1,41 @@
package rules
import (
"bytes"
"fmt"
"io"
"net/http"
"path"
"strconv"
"strings"
"github.com/rs/zerolog"
"github.com/yusing/godoxy/internal/auth"
"github.com/yusing/godoxy/internal/logging"
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
nettypes "github.com/yusing/godoxy/internal/net/types"
"github.com/yusing/godoxy/internal/notif"
"github.com/yusing/godoxy/internal/route/routes"
gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http"
"github.com/yusing/goutils/http/reverseproxy"
strutils "github.com/yusing/goutils/strings"
"github.com/yusing/goutils/synk"
)
type (
Command struct {
raw string
exec CommandHandler
raw string
exec CommandHandler
isResponseHandler bool
}
)
func (cmd *Command) IsResponseHandler() bool {
return cmd.isResponseHandler
}
const (
CommandRequireAuth = "require_auth"
CommandRewrite = "rewrite"
CommandServe = "serve"
CommandProxy = "proxy"
@@ -31,18 +45,46 @@ const (
CommandSet = "set"
CommandAdd = "add"
CommandRemove = "remove"
CommandLog = "log"
CommandNotify = "notify"
CommandPass = "pass"
CommandPassAlt = "bypass"
)
var commands = map[string]struct {
help Help
validate ValidateFunc
build func(args any) CommandHandler
help Help
validate ValidateFunc
build func(args any) CommandHandler
isResponseHandler bool
}{
CommandRequireAuth: {
help: Help{
command: CommandRequireAuth,
description: makeLines("Require HTTP authentication for incoming requests"),
args: map[string]string{},
},
validate: func(args []string) (any, gperr.Error) {
if len(args) != 0 {
return nil, ErrExpectNoArg
}
return nil, nil
},
build: func(args any) CommandHandler {
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
if !auth.AuthOrProceed(w, r) {
return errTerminated
}
return nil
})
},
},
CommandRewrite: {
help: Help{
command: CommandRewrite,
description: makeLines(
"Rewrite a request path from one prefix to another, e.g.:",
helpExample(CommandRewrite, "/foo", "/bar"),
),
args: map[string]string{
"from": "the path to rewrite, must start with /",
"to": "the path to rewrite to, must start with /",
@@ -67,24 +109,29 @@ var commands = map[string]struct {
},
build: func(args any) CommandHandler {
orig, repl := args.(*StrTuple).Unpack()
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
path := r.URL.Path
if len(path) > 0 && path[0] != '/' {
path = "/" + path
}
if !strings.HasPrefix(path, orig) {
return
return nil
}
path = repl + path[len(orig):]
r.URL.Path = path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
return nil
})
},
},
CommandServe: {
help: Help{
command: CommandServe,
description: makeLines(
"Serve static files from a local file system path, e.g.:",
helpExample(CommandServe, "/var/www"),
),
args: map[string]string{
"root": "the file system path to serve, must be an existing directory",
},
@@ -92,14 +139,19 @@ var commands = map[string]struct {
validate: validateFSPath,
build: func(args any) CommandHandler {
root := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
return nil
})
},
},
CommandRedirect: {
help: Help{
command: CommandRedirect,
description: makeLines(
"Redirect request to another URL, e.g.:",
helpExample(CommandRedirect, "https://example.com"),
),
args: map[string]string{
"to": "the url to redirect to, can be relative or absolute URL",
},
@@ -107,14 +159,19 @@ var commands = map[string]struct {
validate: validateURL,
build: func(args any) CommandHandler {
target := args.(*nettypes.URL).String()
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
return nil
})
},
},
CommandError: {
help: Help{
command: CommandError,
description: makeLines(
"Send an HTTP error response and terminate processing, e.g.:",
helpExample(CommandError, "400", "bad request"),
),
args: map[string]string{
"code": "the http status code to return",
"text": "the error message to return",
@@ -136,14 +193,21 @@ var commands = map[string]struct {
},
build: func(args any) CommandHandler {
code, text := args.(*Tuple[int, string]).Unpack()
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
// error command should overwrite the response body
GetInitResponseModifier(w).ResetBody()
http.Error(w, text, code)
return nil
})
},
},
CommandRequireBasicAuth: {
help: Help{
command: CommandRequireBasicAuth,
description: makeLines(
"Require HTTP basic authentication for incoming requests, e.g.:",
helpExample(CommandRequireBasicAuth, "Restricted Area"),
),
args: map[string]string{
"realm": "the authentication realm",
},
@@ -156,35 +220,63 @@ var commands = map[string]struct {
},
build: func(args any) CommandHandler {
realm := args.(string)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return nil
})
},
},
CommandProxy: {
help: Help{
command: CommandProxy,
description: makeLines(
"Proxy the request to the specified absolute URL, e.g.:",
helpExample(CommandProxy, "http://upstream:8080"),
),
args: map[string]string{
"to": "the url to proxy to, must be an absolute URL",
},
},
validate: validateAbsoluteURL,
validate: validateURL,
build: func(args any) CommandHandler {
target := args.(*nettypes.URL)
if target.Scheme == "" {
target.Scheme = "http"
}
if target.Host == "" {
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
url := target.URL
url.Host = routes.TryGetUpstreamHostPort(r)
if url.Host == "" {
return fmt.Errorf("no upstream host: %s", r.URL.String())
}
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
r.URL.Path = target.Path
r.URL.RawPath = r.URL.EscapedPath()
r.RequestURI = r.URL.RequestURI()
rp.ServeHTTP(w, r)
return nil
})
}
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
return TerminatingCommand(rp.ServeHTTP)
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
rp.ServeHTTP(w, r)
return nil
})
},
},
CommandSet: {
help: Help{
command: CommandSet,
description: makeLines(
"Set a field in the request or response, e.g.:",
helpExample(CommandSet, "header", "User-Agent", "godoxy"),
),
args: map[string]string{
"field": "the field to set",
"value": "the value to set",
"target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")),
"field": "the field to set",
"value": "the value to set",
},
},
validate: func(args []string) (any, gperr.Error) {
@@ -197,9 +289,14 @@ var commands = map[string]struct {
CommandAdd: {
help: Help{
command: CommandAdd,
description: makeLines(
"Add a value to a field in the request or response, e.g.:",
helpExample(CommandAdd, "header", "X-Foo", "bar"),
),
args: map[string]string{
"field": "the field to add",
"value": "the value to add",
"target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")),
"field": "the field to add",
"value": "the value to add",
},
},
validate: func(args []string) (any, gperr.Error) {
@@ -212,8 +309,13 @@ var commands = map[string]struct {
CommandRemove: {
help: Help{
command: CommandRemove,
description: makeLines(
"Remove a field from the request or response, e.g.:",
helpExample(CommandRemove, "header", "User-Agent"),
),
args: map[string]string{
"field": "the field to remove",
"target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")),
"field": "the field to remove",
},
},
validate: func(args []string) (any, gperr.Error) {
@@ -223,17 +325,157 @@ var commands = map[string]struct {
return args.(CommandHandler)
},
},
CommandLog: {
isResponseHandler: true,
help: Help{
command: CommandLog,
description: makeLines(
"The template supports the following variables:",
helpListItem("Request", "the request object"),
helpListItem("Response", "the response object"),
"",
"Example:",
helpExample(CommandLog, "info", "/dev/stdout", "{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"),
),
args: map[string]string{
"level": "the log level",
"path": "the log path (/dev/stdout for stdout, /dev/stderr for stderr)",
"template": "the template to log",
},
},
validate: func(args []string) (any, gperr.Error) {
if len(args) != 3 {
return nil, ErrExpectThreeArgs
}
tmpl, err := validateTemplate(args[2], true)
if err != nil {
return nil, err
}
level, err := validateLevel(args[0])
if err != nil {
return nil, err
}
// NOTE: file will stay opened forever
// it leverages accesslog.NewFileIO so
// it will be opened only once for the same path
f, err := openFile(args[1])
if err != nil {
return nil, err
}
return &onLogArgs{level, f, tmpl}, nil
},
build: func(args any) CommandHandler {
level, f, tmpl := args.(*onLogArgs).Unpack()
var logger io.Writer
if f == stdout || f == stderr {
logger = logging.NewLoggerWithFixedLevel(level, f)
} else {
logger = f
}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
err := executeReqRespTemplateTo(tmpl, logger, w, r)
if err != nil {
return err
}
return nil
})
},
},
CommandNotify: {
isResponseHandler: true,
help: Help{
command: CommandNotify,
description: makeLines(
"The template supports the following variables:",
helpListItem("Request", "the request object"),
helpListItem("Response", "the response object"),
"",
"Example:",
helpExample(CommandNotify, "info", "ntfy", "Received request to {{ .Request.URL }}", "{{ .Request.Method }} {{ .Response.StatusCode }}"),
),
args: map[string]string{
"level": "the log level",
"provider": "the notification provider (must be defined in config `providers.notification`)",
"title": "the title of the notification",
"body": "the body of the notification",
},
},
validate: func(args []string) (any, gperr.Error) {
if len(args) != 4 {
return nil, ErrExpectFourArgs
}
titleTmpl, err := validateTemplate(args[2], false)
if err != nil {
return nil, err
}
bodyTmpl, err := validateTemplate(args[3], false)
if err != nil {
return nil, err
}
level, err := validateLevel(args[0])
if err != nil {
return nil, err
}
// TODO: validate provider
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
// err = validateNotifProvider(args[1])
// if err != nil {
// return nil, err
// }
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
},
build: func(args any) CommandHandler {
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
to := []string{provider}
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
buf := bufPool.Get()
defer bufPool.Put(buf)
respBuf := bytes.NewBuffer(buf)
err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r)
if err != nil {
return err
}
titleLen := respBuf.Len()
err = executeReqRespTemplateTo(bodyTmpl, respBuf, w, r)
if err != nil {
return err
}
notif.Notify(&notif.LogMessage{
Level: level,
Title: string(buf[:titleLen]),
Body: notif.MessageBodyBytes(buf[titleLen:]),
To: to,
})
return nil
})
},
},
}
type reqResponseTemplateData struct {
Request *http.Request
Response struct {
StatusCode int
Header http.Header
}
}
var bufPool = synk.GetBytesPoolWithUniqueMemory()
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr]
type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr]
// Parse implements strutils.Parser.
func (cmd *Command) Parse(v string) error {
lines := strutils.SplitLine(v)
if len(lines) == 0 {
return nil
}
executors := make([]CommandHandler, 0, len(lines))
for _, line := range lines {
executors := make([]CommandHandler, 0)
isResponseHandler := false
for line := range strings.SplitSeq(v, "\n") {
if line == "" {
continue
}
@@ -257,13 +499,21 @@ func (cmd *Command) Parse(v string) error {
}
validArgs, err := builder.validate(args)
if err != nil {
return err.Subject(directive).Withf("%s", builder.help.String())
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
return err.Subject(directive).With(builder.help.Error())
}
executors = append(executors, builder.build(validArgs))
handler := builder.build(validArgs)
executors = append(executors, handler)
if builder.isResponseHandler || handler.IsResponseHandler() {
isResponseHandler = true
}
}
if len(executors) == 0 {
cmd.raw = v
cmd.exec = nil
cmd.isResponseHandler = false
return nil
}
@@ -274,10 +524,14 @@ func (cmd *Command) Parse(v string) error {
cmd.raw = v
cmd.exec = exec
if exec.IsResponseHandler() {
isResponseHandler = true
}
cmd.isResponseHandler = isResponseHandler
return nil
}
func buildCmd(executors []CommandHandler) (CommandHandler, error) {
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
for i, exec := range executors {
switch exec.(type) {
case TerminatingCommand, BypassCommand:
@@ -308,6 +562,10 @@ func (cmd *Command) isBypass() bool {
}
}
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
return cmd.exec.Handle(w, r)
}
func (cmd *Command) String() string {
return cmd.raw
}

View File

@@ -0,0 +1,400 @@
package rules
import (
"fmt"
"maps"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/serialization"
gperr "github.com/yusing/goutils/errs"
)
// mockUpstream creates a simple upstream handler for testing
func mockUpstream(status int, body string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status)
w.Write([]byte(body))
}
}
// mockUpstreamWithHeaders creates an upstream that returns specific headers
func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
maps.Copy(w.Header(), headers)
w.WriteHeader(status)
w.Write([]byte(body))
}
}
func parseRules(data string, target *Rules) gperr.Error {
_, err := serialization.ConvertString(data, reflect.ValueOf(target))
return err
}
func TestLogCommand_TemporaryFile(t *testing.T) {
upstream := mockUpstreamWithHeaders(200, "success response", http.Header{
"Content-Type": []string{"application/json"},
})
// Create a temporary file for logging
tempFile, err := os.CreateTemp("", "test-log-*.log")
require.NoError(t, err)
tempFile.Close()
defer os.Remove(tempFile.Name())
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: log-request-response
do: |
log info %q '{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }} {{ index (index .Response.Header "Content-Type") 0 }}'
`, tempFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("POST", "/api/users", nil)
req.Header.Set("User-Agent", "test-agent")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "success response", w.Body.String())
// Read and verify log content
content, err := os.ReadFile(tempFile.Name())
require.NoError(t, err)
logContent := string(content)
assert.Equal(t, "POST /api/users 200 application/json\n", logContent)
}
func TestLogCommand_StdoutAndStderr(t *testing.T) {
upstream := mockUpstream(200, "success")
var rules Rules
err := parseRules(`
- name: log-stdout
do: |
log info /dev/stdout "stdout: {{ .Request.Method }} {{ .Response.StatusCode }}"
- name: log-stderr
do: |
log error /dev/stderr "stderr: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
// Note: We can't easily capture stdout/stderr in unit tests,
// but we can verify no errors occurred and the handler completed
}
func TestLogCommand_DifferentLogLevels(t *testing.T) {
upstream := mockUpstream(404, "not found")
// Create temporary files for different log levels
infoFile, err := os.CreateTemp("", "test-info-*.log")
require.NoError(t, err)
infoFile.Close()
defer os.Remove(infoFile.Name())
warnFile, err := os.CreateTemp("", "test-warn-*.log")
require.NoError(t, err)
warnFile.Close()
defer os.Remove(warnFile.Name())
errorFile, err := os.CreateTemp("", "test-error-*.log")
require.NoError(t, err)
errorFile.Close()
defer os.Remove(errorFile.Name())
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: log-info
do: |
log info %s "INFO: {{ .Request.Method }} {{ .Response.StatusCode }}"
- name: log-warn
do: |
log warn %s "WARN: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
- name: log-error
do: |
log error %s "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("DELETE", "/api/resource/123", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
// Verify each log file
infoContent, err := os.ReadFile(infoFile.Name())
require.NoError(t, err)
assert.Equal(t, "INFO: DELETE 404", strings.TrimSpace(string(infoContent)))
warnContent, err := os.ReadFile(warnFile.Name())
require.NoError(t, err)
assert.Equal(t, "WARN: /api/resource/123 404", strings.TrimSpace(string(warnContent)))
errorContent, err := os.ReadFile(errorFile.Name())
require.NoError(t, err)
assert.Equal(t, "ERROR: DELETE /api/resource/123 404", strings.TrimSpace(string(errorContent)))
}
func TestLogCommand_TemplateVariables(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Header", "custom-value")
w.Header().Set("Content-Length", "42")
w.WriteHeader(201)
w.Write([]byte("created"))
})
// Create temporary file
tempFile, err := os.CreateTemp("", "test-template-*.log")
require.NoError(t, err)
tempFile.Close()
defer os.Remove(tempFile.Name())
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: log-with-templates
do: |
log info %s 'Request: {{ .Request.Method }} {{ .Request.URL }} Host: {{ .Request.Host }} User-Agent: {{ index .Request.Header "User-Agent" 0 }} Response: {{ .Response.StatusCode }} Custom-Header: {{ index .Response.Header "X-Custom-Header" 0 }} Content-Length: {{ index .Response.Header "Content-Length" 0 }}'
`, tempFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("PUT", "/api/resource", nil)
req.Header.Set("User-Agent", "test-client/1.0")
req.Host = "example.com"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 201, w.Code)
// Verify log content
content, err := os.ReadFile(tempFile.Name())
require.NoError(t, err)
logContent := strings.TrimSpace(string(content))
assert.Equal(t, "Request: PUT /api/resource Host: example.com User-Agent: test-client/1.0 Response: 201 Custom-Header: custom-value Content-Length: 42", logContent)
}
func TestLogCommand_ConditionalLogging(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/error":
w.WriteHeader(500)
w.Write([]byte("internal server error"))
case "/notfound":
w.WriteHeader(404)
w.Write([]byte("not found"))
default:
w.WriteHeader(200)
w.Write([]byte("success"))
}
})
// Create temporary files
successFile, err := os.CreateTemp("", "test-success-*.log")
require.NoError(t, err)
successFile.Close()
defer os.Remove(successFile.Name())
errorFile, err := os.CreateTemp("", "test-error-*.log")
require.NoError(t, err)
errorFile.Close()
defer os.Remove(errorFile.Name())
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: log-success
on: status 2xx
do: |
log info %q "SUCCESS: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
- name: log-error
on: status 4xx | status 5xx
do: |
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
`, successFile.Name(), errorFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test success request
req1 := httptest.NewRequest("GET", "/success", nil)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
// Test not found request
req2 := httptest.NewRequest("GET", "/notfound", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 404, w2.Code)
// Test server error request
req3 := httptest.NewRequest("POST", "/error", nil)
w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3)
assert.Equal(t, 500, w3.Code)
// Verify success log
successContent, err := os.ReadFile(successFile.Name())
require.NoError(t, err)
successLines := strings.Split(strings.TrimSpace(string(successContent)), "\n")
assert.Len(t, successLines, 1)
assert.Equal(t, "SUCCESS: GET /success 200", successLines[0])
// Verify error log
errorContent, err := os.ReadFile(errorFile.Name())
require.NoError(t, err)
errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n")
assert.Len(t, errorLines, 2)
assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0])
assert.Equal(t, "ERROR: POST /error 500", errorLines[1])
}
func TestLogCommand_MultipleLogEntries(t *testing.T) {
upstream := mockUpstream(200, "response")
// Create temporary file
tempFile, err := os.CreateTemp("", "test-multiple-*.log")
require.NoError(t, err)
tempFile.Close()
defer os.Remove(tempFile.Name())
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: log-multiple
do: |
log info %q "{{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"`, tempFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Make multiple requests
requests := []struct {
method string
path string
}{
{"GET", "/users"},
{"POST", "/users"},
{"PUT", "/users/1"},
{"DELETE", "/users/1"},
}
for _, reqInfo := range requests {
req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
}
// Verify all requests were logged
content, err := os.ReadFile(tempFile.Name())
require.NoError(t, err)
logContent := strings.TrimSpace(string(content))
lines := strings.Split(logContent, "\n")
assert.Len(t, lines, len(requests))
for i, reqInfo := range requests {
expectedLog := reqInfo.method + " " + reqInfo.path + " 200"
assert.Equal(t, expectedLog, lines[i])
}
}
func TestLogCommand_FilePermissions(t *testing.T) {
upstream := mockUpstream(200, "success")
// Create a temporary directory
tempDir, err := os.MkdirTemp("", "test-log-dir")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create a log file path within the temp directory
logFilePath := filepath.Join(tempDir, "test.log")
var rules Rules
err = parseRules(fmt.Sprintf(`
- on: status 2xx
do: log info %q "{{ .Request.Method }} {{ .Response.StatusCode }}"`, logFilePath), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
// Verify file was created and is writable
_, err = os.Stat(logFilePath)
require.NoError(t, err)
// Test writing to the file again to ensure it's not closed
req2 := httptest.NewRequest("POST", "/test2", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
// Verify both entries are in the file
content, err := os.ReadFile(logFilePath)
require.NoError(t, err)
logContent := strings.TrimSpace(string(content))
lines := strings.Split(logContent, "\n")
assert.Len(t, lines, 2)
assert.Equal(t, "GET 200", lines[0])
assert.Equal(t, "POST 200", lines[1])
}
func TestLogCommand_InvalidTemplate(t *testing.T) {
upstream := mockUpstream(200, "success")
var rules Rules
// Test with invalid template syntax
err := parseRules(`
- name: log-invalid
do: |
log info /dev/stdout "{{ .Invalid.Field }}"`, &rules)
// Should not error during parsing, but template execution will fail gracefully
assert.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Should not panic
assert.NotPanics(t, func() {
handler.ServeHTTP(w, req)
})
assert.Equal(t, 200, w.Code)
}

View File

@@ -0,0 +1,328 @@
package rules
import (
"bytes"
"io"
"net/http"
"net/url"
"strconv"
gperr "github.com/yusing/goutils/errs"
ioutils "github.com/yusing/goutils/io"
)
type (
FieldHandler struct {
set, add, remove CommandHandler
}
FieldModifier string
)
const (
ModFieldSet FieldModifier = "set"
ModFieldAdd FieldModifier = "add"
ModFieldRemove FieldModifier = "remove"
)
const (
FieldHeader = "header"
FieldResponseHeader = "resp_header"
FieldQuery = "query"
FieldCookie = "cookie"
FieldBody = "body"
FieldResponseBody = "resp_body"
FieldStatusCode = "status"
)
var AllFields = []string{FieldHeader, FieldResponseHeader, FieldQuery, FieldCookie, FieldBody, FieldResponseBody, FieldStatusCode}
// NOTE: should not use canonicalized header keys, respect to user's input
var modFields = map[string]struct {
help Help
validate ValidateFunc
builder func(args any) *FieldHandler
}{
FieldHeader: {
help: Help{
command: FieldHeader,
args: map[string]string{
"key": "the header key",
"value": "the header template",
},
},
validate: toKeyValueTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
r.Header[k] = []string{v}
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
r.Header[k] = append(r.Header[k], v)
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
delete(r.Header, k)
return nil
}),
}
},
},
FieldResponseHeader: {
help: Help{
command: FieldResponseHeader,
args: map[string]string{
"key": "the response header key",
"value": "the response header template",
},
},
validate: toKeyValueTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
w.Header()[k] = []string{v}
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
w.Header()[k] = append(w.Header()[k], v)
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
delete(w.Header(), k)
return nil
}),
}
},
},
FieldQuery: {
help: Help{
command: FieldQuery,
args: map[string]string{
"key": "the query key",
"value": "the query template",
},
},
validate: toKeyValueTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
queries.Set(k, v)
})
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
queries.Add(k, v)
})
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
queries.Del(k)
})
return nil
}),
}
},
},
FieldCookie: {
help: Help{
command: FieldCookie,
args: map[string]string{
"key": "the cookie key",
"value": "the cookie value",
},
},
validate: toKeyValueTemplate,
builder: func(args any) *FieldHandler {
k, tmpl := args.(*keyValueTemplate).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
for i, c := range cookies {
if c.Name == k {
cookies[i].Value = v
return cookies
}
}
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return nil
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
v, err := executeRequestTemplateString(tmpl, r)
if err != nil {
return err
}
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return nil
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
index := -1
for i, c := range cookies {
if c.Name == k {
index = i
break
}
}
if index != -1 {
if len(cookies) == 1 {
return []*http.Cookie{}
}
return append(cookies[:index], cookies[index+1:]...)
}
return cookies
})
return nil
}),
}
},
},
FieldBody: {
help: Help{
command: FieldBody,
description: makeLines(
"Override the request body that will be sent to the upstream",
"The template supports the following variables:",
helpListItem("Request", "the request object"),
"",
"Example:",
helpExample(FieldBody, "HTTP STATUS: {{ .Request.Method }} {{ .Request.URL.Path }}"),
),
args: map[string]string{
"template": "the body template",
},
},
validate: func(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
return validateTemplate(args[0], true)
},
builder: func(args any) *FieldHandler {
tmpl := args.(templateOrStr)
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
if r.Body != nil {
r.Body.Close()
r.Body = nil
}
buf := pool.Get()
b := bytes.NewBuffer(buf)
err := executeRequestTemplateTo(tmpl, b, r)
if err != nil {
return err
}
r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() {
pool.Put(buf)
})
return nil
}),
}
},
},
FieldResponseBody: {
help: Help{
command: FieldResponseBody,
description: makeLines(
"Override the response body that will be sent to the client",
"The template supports the following variables:",
helpListItem("Request", "the request object"),
helpListItem("Response", "the response object"),
"",
"Example:",
helpExample(FieldResponseBody, "HTTP STATUS: {{ .Request.Method }} {{ .Response.StatusCode }}"),
),
args: map[string]string{
"template": "the response body template",
},
},
validate: func(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
return validateTemplate(args[0], true)
},
builder: func(args any) *FieldHandler {
tmpl := args.(templateOrStr)
return &FieldHandler{
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
rm := GetInitResponseModifier(w)
rm.ResetBody()
return executeReqRespTemplateTo(tmpl, rm, rm, r)
}),
}
},
},
FieldStatusCode: {
help: Help{
command: FieldStatusCode,
description: makeLines(
"Override the status code that will be sent to the client, e.g.:",
helpExample(FieldStatusCode, "200"),
),
args: map[string]string{
"code": "the status code",
},
},
validate: func(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
status, err := strconv.Atoi(args[0])
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if status < 100 || status > 599 {
return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
}
return status, nil
},
builder: func(args any) *FieldHandler {
status := args.(int)
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
GetInitResponseModifier(w).WriteHeader(status)
return nil
}),
}
},
},
}

View File

@@ -0,0 +1,643 @@
package rules
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFieldHandler_Header(t *testing.T) {
tests := []struct {
name string
key string
value string
modifier FieldModifier
setup func(*http.Request)
verify func(*http.Request, *httptest.ResponseRecorder)
}{
{
name: "set header",
key: "X-Test",
value: "test-value",
modifier: ModFieldSet,
setup: func(r *http.Request) {
r.Header.Set("X-Test", "old-value")
},
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
got := r.Header.Get("X-Test")
assert.Equal(t, "test-value", got, "Expected header X-Test to be 'test-value'")
},
},
{
name: "add header",
key: "X-Test",
value: "new-value",
modifier: ModFieldAdd,
setup: func(r *http.Request) {
r.Header.Set("X-Test", "existing-value")
},
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
values := r.Header["X-Test"]
require.Len(t, values, 2, "Expected 2 header values")
assert.Equal(t, "existing-value", values[0], "Expected first value of X-Test header to be 'existing-value'")
assert.Equal(t, "new-value", values[1], "Expected second value of X-Test header to be 'new-value'")
},
},
{
name: "remove header",
key: "X-Test",
value: "",
modifier: ModFieldRemove,
setup: func(r *http.Request) {
r.Header.Set("X-Test", "to-be-removed")
},
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
got := r.Header.Get("X-Test")
assert.Empty(t, got, "Expected header X-Test to be removed")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
case ModFieldAdd:
cmd = handler.add
case ModFieldRemove:
cmd = handler.remove
}
err := cmd.Handle(w, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(req, w)
})
}
}
func TestFieldHandler_ResponseHeader(t *testing.T) {
tests := []struct {
name string
key string
value string
modifier FieldModifier
setup func(*httptest.ResponseRecorder)
verify func(*httptest.ResponseRecorder)
}{
{
name: "set response header",
key: "X-Response-Test",
value: "response-value",
modifier: ModFieldSet,
verify: func(w *httptest.ResponseRecorder) {
got := w.Header().Get("X-Response-Test")
assert.Equal(t, "response-value", got, "Expected response header X-Response-Test to be 'response-value'")
},
},
{
name: "add response header",
key: "X-Response-Test",
value: "additional-value",
modifier: ModFieldAdd,
setup: func(w *httptest.ResponseRecorder) {
w.Header().Set("X-Response-Test", "existing-value")
},
verify: func(w *httptest.ResponseRecorder) {
values := w.Header()["X-Response-Test"]
require.Len(t, values, 2)
assert.Equal(t, values[0], "existing-value")
assert.Equal(t, values[1], "additional-value")
},
},
{
name: "remove response header",
key: "X-Response-Test",
value: "",
modifier: ModFieldRemove,
verify: func(w *httptest.ResponseRecorder) {
assert.Empty(t, w.Header().Get("X-Response-Test"))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if tt.setup != nil {
tt.setup(w)
}
tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
case ModFieldAdd:
cmd = handler.add
case ModFieldRemove:
cmd = handler.remove
}
err := cmd.Handle(w, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(w)
})
}
}
func TestFieldHandler_Query(t *testing.T) {
tests := []struct {
name string
key string
value string
modifier FieldModifier
setup func(*http.Request)
verify func(*http.Request)
}{
{
name: "set query",
key: "test",
value: "new-value",
modifier: ModFieldSet,
setup: func(r *http.Request) {
r.URL.RawQuery = "test=old-value&other=keep"
},
verify: func(r *http.Request) {
got := r.URL.Query().Get("test")
assert.Equal(t, "new-value", got, "Expected query 'test' to be 'new-value'")
gotOther := r.URL.Query().Get("other")
assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'")
},
},
{
name: "add query",
key: "test",
value: "additional-value",
modifier: ModFieldAdd,
setup: func(r *http.Request) {
r.URL.RawQuery = "test=existing-value"
},
verify: func(r *http.Request) {
values := r.URL.Query()["test"]
require.Len(t, values, 2, "Expected 2 query values")
assert.Equal(t, "existing-value", values[0], "Expected first value of test query param to be 'existing-value'")
assert.Equal(t, "additional-value", values[1], "Expected second value of test query param to be 'additional-value'")
},
},
{
name: "remove query",
key: "test",
value: "",
modifier: ModFieldRemove,
setup: func(r *http.Request) {
r.URL.RawQuery = "test=to-be-removed&other=keep"
},
verify: func(r *http.Request) {
got := r.URL.Query().Get("test")
assert.Empty(t, got, "Expected query 'test' to be removed")
gotOther := r.URL.Query().Get("other")
assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
case ModFieldAdd:
cmd = handler.add
case ModFieldRemove:
cmd = handler.remove
}
err := cmd.Handle(w, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(req)
})
}
}
func TestFieldHandler_Cookie(t *testing.T) {
tests := []struct {
name string
key string
value string
modifier FieldModifier
setup func(*http.Request)
verify func(*http.Request)
}{
{
name: "set cookie",
key: "test",
value: "new-value",
modifier: ModFieldSet,
setup: func(r *http.Request) {
r.AddCookie(&http.Cookie{Name: "test", Value: "old-value"})
},
verify: func(r *http.Request) {
cookie, err := r.Cookie("test")
assert.NoError(t, err, "Expected cookie 'test' to exist")
if err == nil {
assert.Equal(t, "new-value", cookie.Value, "Expected cookie 'test' to be 'new-value'")
}
},
},
{
name: "add cookie",
key: "test",
value: "additional-value",
modifier: ModFieldAdd,
setup: func(r *http.Request) {
r.AddCookie(&http.Cookie{Name: "test", Value: "existing-value"})
},
verify: func(r *http.Request) {
cookies := r.Cookies()
testCookies := make([]string, 0)
for _, c := range cookies {
if c.Name == "test" {
testCookies = append(testCookies, c.Value)
}
}
require.Len(t, testCookies, 2, "Expected 2 cookies with name 'test'")
assert.Equal(t, "existing-value", testCookies[0], "Expected first value of 'test' cookie to be 'existing-value'")
assert.Equal(t, "additional-value", testCookies[1], "Expected second value of 'test' cookie to be 'additional-value'")
},
},
{
name: "remove cookie",
key: "test",
value: "",
modifier: ModFieldRemove,
setup: func(r *http.Request) {
r.AddCookie(&http.Cookie{Name: "test", Value: "to-be-removed"})
r.AddCookie(&http.Cookie{Name: "other", Value: "keep"})
},
verify: func(r *http.Request) {
_, err := r.Cookie("test")
assert.Error(t, err, "Expected cookie 'test' to be removed")
cookie, err := r.Cookie("other")
assert.NoError(t, err, "Expected cookie 'other' to exist")
if err == nil {
assert.Equal(t, "keep", cookie.Value, "Expected cookie 'other' to be 'keep'")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.value, false)
if tErr != nil {
t.Fatalf("Failed to validate template: %v", tErr)
}
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
var cmd CommandHandler
switch tt.modifier {
case ModFieldSet:
cmd = handler.set
case ModFieldAdd:
cmd = handler.add
case ModFieldRemove:
cmd = handler.remove
}
err := cmd.Handle(w, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(req)
})
}
}
func TestFieldHandler_Body(t *testing.T) {
tests := []struct {
name string
template string
setup func(*http.Request)
verify func(*http.Request)
}{
{
name: "set body with template",
template: "Hello {{ .Request.Method }} {{ .Request.URL.Path }}",
setup: func(r *http.Request) {
r.Method = "POST"
r.URL.Path = "/test"
},
verify: func(r *http.Request) {
body, err := io.ReadAll(r.Body)
assert.NoError(t, err, "Failed to read body")
expected := "Hello POST /test"
assert.Equal(t, expected, string(body), "Expected body content")
},
},
{
name: "set body with existing body",
template: "Overridden",
setup: func(r *http.Request) {
r.Body = io.NopCloser(strings.NewReader("original body"))
},
verify: func(r *http.Request) {
body, err := io.ReadAll(r.Body)
assert.NoError(t, err, "Failed to read body")
assert.Equal(t, "Overridden", string(body), "Expected body to be 'Overridden'")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
tmpl, tErr := validateTemplate(tt.template, false)
if tErr != nil {
t.Fatalf("Failed to parse template: %v", tErr)
}
handler := modFields[FieldBody].builder(tmpl)
err := handler.set.Handle(w, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(req)
})
}
}
func TestFieldHandler_ResponseBody(t *testing.T) {
tests := []struct {
name string
template string
setup func(*http.Request)
verify func(*ResponseModifier)
}{
{
name: "set response body with template",
template: "Response: {{ .Request.Method }} {{ .Request.URL.Path }}",
setup: func(r *http.Request) {
r.Method = "GET"
r.URL.Path = "/api/test"
},
verify: func(rm *ResponseModifier) {
content := rm.buf.String()
expected := "Response: GET /api/test"
assert.Equal(t, expected, content, "Expected response body")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
tt.setup(req)
w := httptest.NewRecorder()
// Create ResponseModifier wrapper
rm := NewResponseModifier(w)
tmpl, tErr := validateTemplate(tt.template, false)
if tErr != nil {
t.Fatalf("Failed to parse template: %v", tErr)
}
handler := modFields[FieldResponseBody].builder(tmpl)
err := handler.set.Handle(rm, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
tt.verify(rm)
})
}
}
func TestFieldHandler_StatusCode(t *testing.T) {
tests := []struct {
name string
status int
verify func(*httptest.ResponseRecorder)
}{
{
name: "set status code 200",
status: 200,
verify: func(w *httptest.ResponseRecorder) {
assert.Equal(t, 200, w.Code, "Expected status code 200")
},
},
{
name: "set status code 404",
status: 404,
verify: func(w *httptest.ResponseRecorder) {
assert.Equal(t, 404, w.Code, "Expected status code 404")
},
},
{
name: "set status code 500",
status: 500,
verify: func(w *httptest.ResponseRecorder) {
assert.Equal(t, 500, w.Code, "Expected status code 500")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
rm := NewResponseModifier(w)
var cmd Command
err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status))
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
err = cmd.ServeHTTP(rm, req)
if err != nil {
t.Fatalf("Handler returned error: %v", err)
}
rm.FlushRelease()
tt.verify(w)
})
}
}
func TestFieldValidation(t *testing.T) {
tests := []struct {
name string
field string
args []string
wantError bool
}{
{
name: "header valid",
field: FieldHeader,
args: []string{"key", "value"},
wantError: false,
},
{
name: "header invalid - missing value",
field: FieldHeader,
args: []string{"key"},
wantError: true,
},
{
name: "response header valid",
field: FieldResponseHeader,
args: []string{"key", "value"},
wantError: false,
},
{
name: "query valid",
field: FieldQuery,
args: []string{"key", "value"},
wantError: false,
},
{
name: "cookie valid",
field: FieldCookie,
args: []string{"key", "value"},
wantError: false,
},
{
name: "body valid template",
field: FieldBody,
args: []string{"Hello {{ .Request.Method }}"},
wantError: false,
},
{
name: "body invalid template syntax",
field: FieldBody,
args: []string{"Hello {{ .InvalidField "},
wantError: true,
},
{
name: "response body valid template",
field: FieldResponseBody,
args: []string{"Response: {{ .Request.Method }}"},
wantError: false,
},
{
name: "status code valid",
field: FieldStatusCode,
args: []string{"200"},
wantError: false,
},
{
name: "status code invalid - too low",
field: FieldStatusCode,
args: []string{"99"},
wantError: true,
},
{
name: "status code invalid - too high",
field: FieldStatusCode,
args: []string{"600"},
wantError: true,
},
{
name: "status code invalid - not a number",
field: FieldStatusCode,
args: []string{"not-a-number"},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
field, exists := modFields[tt.field]
assert.True(t, exists, "Field %s does not exist", tt.field)
_, err := field.validate(tt.args)
if tt.wantError {
assert.Error(t, err, "Expected error but got none")
} else {
assert.NoError(t, err, "Expected no error but got: %v", err)
}
})
}
}
func TestAllFields(t *testing.T) {
expectedFields := []string{
FieldHeader,
FieldResponseHeader,
FieldQuery,
FieldCookie,
FieldBody,
FieldResponseBody,
FieldStatusCode,
}
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
for _, expected := range expectedFields {
found := false
for _, actual := range AllFields {
if actual == expected {
found = true
break
}
}
assert.True(t, found, "Expected field %s not found in AllFields", expected)
}
}
func TestModFields(t *testing.T) {
for fieldName, field := range modFields {
// Test that each field has required components
assert.NotNil(t, field.validate, "Field %s has nil validate function", fieldName)
assert.NotNil(t, field.builder, "Field %s has nil builder function", fieldName)
assert.NotEmpty(t, field.help.command, "Field %s has empty help command", fieldName)
}
}

View File

@@ -99,10 +99,15 @@ func TestParseCommands(t *testing.T) {
},
// proxy directive tests
{
name: "proxy_valid",
name: "proxy_valid_abs",
input: "proxy http://localhost:8080",
wantErr: nil,
},
{
name: "proxy_valid_rel",
input: "proxy /foo/bar",
wantErr: nil,
},
{
name: "proxy_missing_target",
input: "proxy",

View File

@@ -0,0 +1,23 @@
package rules
import (
"testing"
gperr "github.com/yusing/goutils/errs"
)
func TestErrorFormat(t *testing.T) {
var rules Rules
err := parseRules(`
- on: error 405
do: error 405 error
- on: header too many args
do: error 405 error
- name: missing do
on: status 200
- on: header X-Header
do: set invalid_command
- do: set resp_body "{{ .Request.Method {{ .Request.URL.Path }}"
`, &rules)
gperr.LogError("error", err)
}

View File

@@ -7,15 +7,21 @@ import (
var (
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
ErrUnterminatedEnvVar = gperr.New("unterminated env var")
ErrUnknownDirective = gperr.New("unknown directive")
ErrUnknownModField = gperr.New("unknown field")
ErrEnvVarNotFound = gperr.New("env variable not found")
ErrInvalidArguments = gperr.New("invalid arguments")
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
ErrInvalidSetTarget = gperr.New("invalid `rule.set` target")
ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg")
ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg")
ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args")
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg")
ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg")
ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args")
ErrExpectTwoOrThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 or 3 args")
ErrExpectThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 3 args")
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
errTerminated = gperr.New("terminated")
)

View File

@@ -1,142 +0,0 @@
package rules
import (
"net/http"
"net/url"
)
type (
FieldHandler struct {
set, add, remove CommandHandler
}
FieldModifier string
)
const (
ModFieldSet FieldModifier = "set"
ModFieldAdd FieldModifier = "add"
ModFieldRemove FieldModifier = "remove"
)
const (
FieldHeader = "header"
FieldQuery = "query"
FieldCookie = "cookie"
)
var modFields = map[string]struct {
help Help
validate ValidateFunc
builder func(args any) *FieldHandler
}{
FieldHeader: {
help: Help{
command: FieldHeader,
args: map[string]string{
"key": "the header key",
"value": "the header value",
},
},
validate: toStrTuple,
builder: func(args any) *FieldHandler {
k, v := args.(*StrTuple).Unpack()
return &FieldHandler{
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
w.Header()[k] = []string{v}
}),
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h[k] = append(h[k], v)
}),
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
delete(w.Header(), k)
}),
}
},
},
FieldQuery: {
help: Help{
command: FieldQuery,
args: map[string]string{
"key": "the query key",
"value": "the query value",
},
},
validate: toStrTuple,
builder: func(args any) *FieldHandler {
k, v := args.(*StrTuple).Unpack()
return &FieldHandler{
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateQueries(r, func(queries url.Values) {
queries.Set(k, v)
})
return true
}),
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateQueries(r, func(queries url.Values) {
queries.Add(k, v)
})
return true
}),
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateQueries(r, func(queries url.Values) {
queries.Del(k)
})
return true
}),
}
},
},
FieldCookie: {
help: Help{
command: FieldCookie,
args: map[string]string{
"key": "the cookie key",
"value": "the cookie value",
},
},
validate: toStrTuple,
builder: func(args any) *FieldHandler {
k, v := args.(*StrTuple).Unpack()
return &FieldHandler{
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
for i, c := range cookies {
if c.Name == k {
cookies[i].Value = v
return cookies
}
}
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return true
}),
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
return append(cookies, &http.Cookie{Name: k, Value: v})
})
return true
}),
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
index := -1
for i, c := range cookies {
if c.Name == k {
index = i
break
}
}
if index != -1 {
if len(cookies) == 1 {
return []*http.Cookie{}
}
return append(cookies[:index], cookies[index+1:]...)
}
return cookies
})
return true
}),
}
},
},
}

View File

@@ -1,40 +1,134 @@
package rules
import "strings"
import (
"fmt"
"slices"
"strconv"
"strings"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/strings/ansi"
)
type Help struct {
command string
description string
description []string
args map[string]string // args[arg] -> description
}
/*
Generate help string, e.g.
func makeLines(lines ...string) []string {
return lines
}
rewrite <from> <to>
from: the path to rewrite, must start with /
to: the path to rewrite to, must start with /
*/
func (h *Help) String() string {
func helpExample(cmd string, args ...string) string {
var sb strings.Builder
sb.WriteString(h.command)
sb.WriteString(" ")
for arg := range h.args {
sb.WriteString(strings.ToUpper(arg))
sb.WriteRune(' ')
}
if h.description != "" {
sb.WriteString("\n\t")
sb.WriteString(h.description)
sb.WriteRune('\n')
}
sb.WriteRune('\n')
for arg, desc := range h.args {
sb.WriteRune('\t')
sb.WriteString(strings.ToUpper(arg))
sb.WriteString(": ")
sb.WriteString(desc)
sb.WriteRune('\n')
sb.WriteString(" ")
sb.WriteString(ansi.WithANSI(cmd, ansi.HighlightGreen))
for _, arg := range args {
var out strings.Builder
pos := 0
for {
start := strings.Index(arg[pos:], "{{")
if start == -1 {
if pos < len(arg) {
// If no template at all (pos == 0), cyan highlight for whole-arg
// Otherwise, for mixed strings containing templates, leave non-template text unhighlighted
if pos == 0 {
out.WriteString(ansi.WithANSI(arg[pos:], ansi.HighlightCyan))
} else {
out.WriteString(arg[pos:])
}
}
break
}
start += pos
if start > pos {
// Non-template text should not be highlighted
out.WriteString(arg[pos:start])
}
end := strings.Index(arg[start+2:], "}}")
if end == -1 {
// Unmatched template start; write remainder without highlighting
out.WriteString(arg[start:])
break
}
end += start + 2
inner := strings.TrimSpace(arg[start+2 : end])
parts := strings.Split(inner, ".")
out.WriteString(helpTemplateVar(parts...))
pos = end + 2
}
fmt.Fprintf(&sb, ` "%s"`, out.String())
}
return sb.String()
}
func helpListItem(key string, value string) string {
var sb strings.Builder
sb.WriteString(" ")
sb.WriteString(ansi.WithANSI(key, ansi.HighlightYellow))
sb.WriteString(": ")
sb.WriteString(value)
return sb.String()
}
// helpFuncCall generates a string like "fn(arg1, arg2, arg3)"
func helpFuncCall(fn string, args ...string) string {
var sb strings.Builder
sb.WriteString(ansi.WithANSI(fn, ansi.HighlightRed))
sb.WriteString("(")
for i, arg := range args {
fmt.Fprintf(&sb, `"%s"`, ansi.WithANSI(arg, ansi.HighlightCyan))
if i < len(args)-1 {
sb.WriteString(", ")
}
}
sb.WriteString(")")
return sb.String()
}
// helpTemplateVar generates a string like "{{ .Request.Method }} {{ .Request.URL.Path }}"
func helpTemplateVar(parts ...string) string {
var sb strings.Builder
sb.WriteString(ansi.WithANSI("{{ ", ansi.HighlightWhite))
for i, part := range parts {
sb.WriteString(ansi.WithANSI(part, ansi.HighlightCyan))
if i < len(parts)-1 {
sb.WriteString(".")
}
}
sb.WriteString(ansi.WithANSI(" }}", ansi.HighlightWhite))
return sb.String()
}
/*
Generate help string as error, e.g.
rewrite <from> <to>
from: the path to rewrite, must start with /
to: the path to rewrite to, must start with /
*/
func (h *Help) Error() gperr.Error {
var lines gperr.MultilineError
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
lines.AddStrings(h.description...)
lines.Adds(" args:")
argKeys := make([]string, 0, len(h.args))
longestArg := 0
for arg := range h.args {
if len(arg) > longestArg {
longestArg = len(arg)
}
argKeys = append(argKeys, arg)
}
// sort argKeys alphabetically to make output stable
slices.Sort(argKeys)
for _, arg := range argKeys {
desc := h.args[arg]
lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc)
}
return &lines
}

View File

@@ -0,0 +1,940 @@
package rules_test
import (
"fmt"
"maps"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/route"
"github.com/yusing/godoxy/internal/route/routes"
"github.com/yusing/godoxy/internal/serialization"
gperr "github.com/yusing/goutils/errs"
"golang.org/x/crypto/bcrypt"
. "github.com/yusing/godoxy/internal/route/rules"
)
// mockUpstream creates a simple upstream handler for testing
func mockUpstream(status int, body string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status)
w.Write([]byte(body))
}
}
// mockUpstreamWithHeaders creates an upstream that returns specific headers
func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
maps.Copy(w.Header(), headers)
w.WriteHeader(status)
w.Write([]byte(body))
}
}
func mockRoute(alias string) *route.FileServer {
return &route.FileServer{Route: &route.Route{Alias: alias}}
}
func parseRules(data string, target *Rules) gperr.Error {
_, err := serialization.ConvertString(strings.TrimSpace(data), reflect.ValueOf(target))
return err
}
func TestHTTPFlow_BasicPreRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
w.WriteHeader(200)
w.Write([]byte("upstream response"))
})
var rules Rules
err := parseRules(`
- name: add-header
on: path /
do: set header X-Custom-Header test-value
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "upstream response", w.Body.String())
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
}
func TestHTTPFlow_BypassRule(t *testing.T) {
upstream := mockUpstream(200, "upstream response")
var rules Rules
err := parseRules(`
- name: bypass-condition
on: path /bypass
do: bypass
- name: should-not-execute
on: path /bypass
do: error 500 "should not reach here"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/bypass", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "upstream response", w.Body.String())
}
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
upstream := mockUpstream(200, "should not be called")
var rules Rules
err := parseRules(`
- name: error-response
on: path /error
do: error 403 Forbidden
- name: should-not-execute
on: path /error
do: set header X-Header ignored
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/error", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "Forbidden\n", w.Body.String())
assert.Empty(t, w.Header().Get("X-Header"))
}
func TestHTTPFlow_RedirectFlow(t *testing.T) {
upstream := mockUpstream(200, "should not be called")
var rules Rules
err := parseRules(`
- name: redirect-rule
on: path /old-path
do: redirect /new-path
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/old-path", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 307, w.Code) // TemporaryRedirect
assert.Equal(t, "/new-path", w.Header().Get("Location"))
}
func TestHTTPFlow_RewriteFlow(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("path: " + r.URL.Path))
})
var rules Rules
err := parseRules(`
- name: rewrite-rule
on: path glob(/api/*)
do: rewrite /api/ /v1/
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/api/users", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "path: /v1/users", w.Body.String())
}
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
})
var rules Rules
err := parseRules(`
- name: add-request-id
on: path /
do: set header X-Request-Id req-123
- name: add-auth-header
on: path /
do: set header X-Auth-Token token-456
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "upstream: req-123", w.Body.String())
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
}
func TestHTTPFlow_PostResponseRule(t *testing.T) {
upstream := mockUpstreamWithHeaders(200, "success", http.Header{
"X-Upstream": []string{"upstream-value"},
})
tempFile, err := os.CreateTemp("", "test-log-*.txt")
// Create a temporary file for logging
require.NoError(t, err)
defer os.Remove(tempFile.Name())
tempFile.Close()
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: log-response
on: path /test
do: log info %s "{{ .Request.Method }} {{ .Response.StatusCode }}"
`, tempFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "success", w.Body.String())
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream"))
// Check log file
content, err := os.ReadFile(tempFile.Name())
require.NoError(t, err)
assert.Equal(t, "GET 200\n", string(content))
}
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/success" {
w.WriteHeader(200)
w.Write([]byte("success"))
} else {
w.WriteHeader(404)
w.Write([]byte("not found"))
}
})
var rules Rules
// Create a temporary file for logging
tempFile, err := os.CreateTemp("", "test-error-log-*.txt")
require.NoError(t, err)
defer os.Remove(tempFile.Name())
tempFile.Close()
err = parseRules(fmt.Sprintf(`
- name: log-errors
on: status 4xx
do: log error %s "{{ .Request.URL }} returned {{ .Response.StatusCode }}"
`, tempFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test successful request (should not log)
req1 := httptest.NewRequest("GET", "/success", nil)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
// Test error request (should log)
req2 := httptest.NewRequest("GET", "/notfound", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 404, w2.Code)
// Check log file
content, err := os.ReadFile(tempFile.Name())
require.NoError(t, err)
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
require.Len(t, lines, 1, "only 4xx requests should be logged")
assert.Equal(t, "/notfound returned 404", lines[0])
}
func TestHTTPFlow_ConditionalRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("hello " + r.Header.Get("X-Username")))
})
var rules Rules
err := parseRules(`
- name: auth-required
on: header Authorization
do: |
set header X-Username authenticated-user
set resp_header X-Username authenticated-user
- name: default
do: |
set header X-Username anonymous
set resp_header X-Username anonymous
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test with Authorization header
req1 := httptest.NewRequest("GET", "/", nil)
req1.Header.Set("Authorization", "Bearer token")
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "hello authenticated-user", w1.Body.String())
assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username"))
// Test without Authorization header
req2 := httptest.NewRequest("GET", "/", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
assert.Equal(t, "hello anonymous", w2.Body.String())
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
}
func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate different responses based on path
if r.URL.Path == "/protected" {
if r.Header.Get("X-Auth") != "valid" {
w.WriteHeader(401)
w.Write([]byte("unauthorized"))
return
}
}
w.Header().Set("X-Response-Time", "100ms")
w.WriteHeader(200)
w.Write([]byte("success"))
})
// Create temporary files for logging
logFile, err := os.CreateTemp("", "test-access-log-*.txt")
require.NoError(t, err)
defer os.Remove(logFile.Name())
logFile.Close()
errorLogFile, err := os.CreateTemp("", "test-error-log-*.txt")
require.NoError(t, err)
defer os.Remove(errorLogFile.Name())
errorLogFile.Close()
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: add-correlation-id
do: set resp_header X-Correlation-Id random_uuid
- name: validate-auth
on: path /protected
do: require_basic_auth "Protected Area"
- name: log-all-requests
do: |
log info %q "{{ .Request.Method }} {{ .Request.URL }} -> {{ .Response.StatusCode }}"
- name: log-errors
on: status 4xx
do: |
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"
`, logFile.Name(), errorLogFile.Name()), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test successful request
req1 := httptest.NewRequest("GET", "/public", nil)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "success", w1.Body.String())
assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id"))
assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time"))
// Test unauthorized protected request
req2 := httptest.NewRequest("GET", "/protected", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 401, w2.Code)
assert.Equal(t, w2.Body.String(), "Unauthorized\n")
// Test authorized protected request
req3 := httptest.NewRequest("GET", "/protected", nil)
req3.SetBasicAuth("user", "pass")
w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3)
// This should fail because our simple upstream expects X-Auth: valid header
// but the basic auth requirement should add the appropriate header
assert.Equal(t, 401, w3.Code)
// Check log files
logContent, err := os.ReadFile(logFile.Name())
require.NoError(t, err)
lines := strings.Split(strings.TrimSpace(string(logContent)), "\n")
require.Len(t, lines, 3, "all requests should be logged")
assert.Equal(t, "GET /public -> 200", lines[0])
assert.Equal(t, "GET /protected -> 401", lines[1])
assert.Equal(t, "GET /protected -> 401", lines[2])
errorLogContent, err := os.ReadFile(errorLogFile.Name())
require.NoError(t, err)
// Should have at least one 401 error logged
lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n")
require.Len(t, lines, 2, "all errors should be logged")
assert.Equal(t, "ERROR: GET /protected 401", lines[0])
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
}
func TestHTTPFlow_DefaultRule(t *testing.T) {
upstream := mockUpstream(200, "upstream response")
var rules Rules
err := parseRules(`
- name: default
do: set resp_header X-Default-Applied true
- name: special-rule
on: path /special
do: set resp_header X-Special-Handled true
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test default rule
req1 := httptest.NewRequest("GET", "/regular", nil)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
// Test special rule + default rule
req2 := httptest.NewRequest("GET", "/special", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
}
func TestHTTPFlow_HeaderManipulation(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Echo back a header
headerValue := r.Header.Get("X-Test-Header")
w.Header().Set("X-Echoed-Header", headerValue)
w.WriteHeader(200)
w.Write([]byte("header echoed"))
})
var rules Rules
err := parseRules(`
- name: remove-sensitive-header
do: remove resp_header X-Secret
- name: add-custom-header
do: add resp_header X-Custom-Header custom-value
- name: modify-existing-header
on: header X-Test-Header
do: set header X-Test-Header modified-value
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Secret", "secret-value")
req.Header.Set("X-Test-Header", "original-value")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
// Ensure the secret header was removed and not passed to upstream
// (we can't directly test this, but the upstream shouldn't see it)
}
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
w.WriteHeader(200)
w.Write([]byte("query: " + query.Get("param")))
})
var rules Rules
err := parseRules(`
- name: add-query-param
on: query param
do: set query param added-value
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/path?param=original", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
// The set command should have modified the query parameter
assert.Equal(t, "query: added-value", w.Body.String())
}
func TestHTTPFlow_ServeCommand(t *testing.T) {
// Create a temporary directory with test files
tempDir, err := os.MkdirTemp("", "test-serve-*")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
// Create test files directly in the temp directory
testFile := filepath.Join(tempDir, "index.html")
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0644)
require.NoError(t, err)
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: serve-static
on: path glob(/files/*)
do: serve %s
`, tempDir), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
// Test serving a file - serve command serves files relative to the root directory
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
// We need to create the file at the expected path
filesDir := filepath.Join(tempDir, "files")
err = os.Mkdir(filesDir, 0755)
require.NoError(t, err)
filesIndexFile := filepath.Join(filesDir, "index.html")
err = os.WriteFile(filesIndexFile, []byte("<h1>Test Page</h1>"), 0644)
require.NoError(t, err)
req1 := httptest.NewRequest("GET", "/files/index.html", nil)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
// The serve command should work, but might redirect
// Let's just verify it doesn't call the upstream
assert.NotEqual(t, "should not be called", w1.Body.String())
// Test file not found
req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 404, w2.Code)
}
func TestHTTPFlow_ProxyCommand(t *testing.T) {
// Create a mock upstream server
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Upstream-Header", "upstream-value")
w.WriteHeader(200)
w.Write([]byte("upstream response"))
}))
defer upstreamServer.Close()
var rules Rules
err := parseRules(fmt.Sprintf(`
- name: proxy-to-upstream
on: path glob(/api/*)
do: proxy %s
`, upstreamServer.URL), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
req := httptest.NewRequest("GET", "/api/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// The proxy command should forward the request to the upstream server
assert.Equal(t, 200, w.Code)
assert.Equal(t, "upstream response", w.Body.String())
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
}
func TestHTTPFlow_NotifyCommand(t *testing.T) {
// TODO:
}
func TestHTTPFlow_FormConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("form processed"))
})
var rules Rules
err := parseRules(`
- name: process-form
on: form username
do: set resp_header X-Username "{{ index .Request.Form.username 0 }}"
- name: process-postform
on: postform email
do: set resp_header X-Email "{{ index .Request.PostForm.email 0 }}"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test form condition
formData := url.Values{"username": {"john_doe"}}
req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode()))
req1.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "john_doe", w1.Header().Get("X-Username"))
// Test postform condition
postFormData := url.Values{"email": {"john@example.com"}}
req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode()))
req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
}
func TestHTTPFlow_RemoteConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("remote processed"))
})
var rules Rules
err := parseRules(`
- name: allow-localhost
on: remote 127.0.0.1
do: set resp_header X-Access "local"
- name: block-private
on: remote 192.168.0.0/16
do: error 403 "Private network blocked"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test localhost condition
req1 := httptest.NewRequest("GET", "/", nil)
req1.RemoteAddr = "127.0.0.1:12345"
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "local", w1.Header().Get("X-Access"))
// Test private network block
req2 := httptest.NewRequest("GET", "/", nil)
req2.RemoteAddr = "192.168.1.100:12345"
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 403, w2.Code)
assert.Equal(t, "Private network blocked\n", w2.Body.String())
}
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("auth processed"))
})
// Generate bcrypt hashes for passwords
adminHash, err := bcrypt.GenerateFromPassword([]byte("adminpass"), bcrypt.DefaultCost)
require.NoError(t, err)
guestHash, err := bcrypt.GenerateFromPassword([]byte("guestpass"), bcrypt.DefaultCost)
require.NoError(t, err)
var rules Rules
err = parseRules(fmt.Sprintf(`
- name: check-auth
on: basic_auth admin %s
do: set resp_header X-Auth-Status "admin"
- name: check-other-user
on: basic_auth guest %s
do: set resp_header X-Auth-Status "guest"
`, string(adminHash), string(guestHash)), &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test admin user
req1 := httptest.NewRequest("GET", "/", nil)
req1.SetBasicAuth("admin", "adminpass")
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status"))
// Test guest user
req2 := httptest.NewRequest("GET", "/", nil)
req2.SetBasicAuth("guest", "guestpass")
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
}
func TestHTTPFlow_RouteConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("route processed"))
})
var rules Rules
err := parseRules(`
- name: backend-route
on: route backend
do: set resp_header X-Route "backend"
- name: frontend-route
on: route frontend
do: set resp_header X-Route "frontend"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test API route
req1 := httptest.NewRequest("GET", "/", nil)
req1 = routes.WithRouteContext(req1, mockRoute("backend"))
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "backend", w1.Header().Get("X-Route"))
// Test admin route
req2 := httptest.NewRequest("GET", "/", nil)
req2 = routes.WithRouteContext(req2, mockRoute("frontend"))
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
}
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(405)
w.Write([]byte("method not allowed"))
})
var rules Rules
err := parseRules(`
- name: method-not-allowed
on: status 405
do: |
error 405 'error'
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 405, w.Code)
assert.Equal(t, "error\n", w.Body.String())
}
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Response-Header", "response header")
w.WriteHeader(200)
w.Write([]byte("processed"))
})
t.Run("any_value", func(t *testing.T) {
var rules Rules
err := parseRules(`
- on: resp_header X-Response-Header
do: |
error 405 "error"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 405, w.Code)
assert.Equal(t, "error\n", w.Body.String())
})
t.Run("with_value", func(t *testing.T) {
var rules Rules
err := parseRules(`
- on: resp_header X-Response-Header "response header"
do: |
error 405 "error"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 405, w.Code)
assert.Equal(t, "error\n", w.Body.String())
})
t.Run("with_value_not_matched", func(t *testing.T) {
var rules Rules
err := parseRules(`
- on: resp_header X-Response-Header "not-matched"
do: |
error 405 "error"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "processed", w.Body.String())
})
}
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("complex processed"))
})
var rules Rules
err := parseRules(`
- name: admin-api
on: |
path glob(/api/admin/*)
header Authorization
method POST
do: |
set resp_header X-Access-Level "admin"
set resp_header X-API-Version "v1"
- name: user-api
on: |
path glob(/api/users/*) & method GET
do: |
set resp_header X-Access-Level "user"
set resp_header X-API-Version "v1"
- name: public-api
on: |
path glob(/api/public/*) & method GET
do: |
set resp_header X-Access-Level "public"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
// Test admin API (should match first rule)
req1 := httptest.NewRequest("POST", "/api/admin/users", nil)
req1.Header.Set("Authorization", "Bearer token")
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
assert.Equal(t, 200, w1.Code)
assert.Equal(t, "admin", w1.Header().Get("X-Access-Level"))
assert.Equal(t, "v1", w1.Header()["X-API-Version"][0])
// Test user API (should match second rule)
req2 := httptest.NewRequest("GET", "/api/users/profile", nil)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
assert.Equal(t, 200, w2.Code)
assert.Equal(t, "user", w2.Header().Get("X-Access-Level"))
assert.Equal(t, "v1", w2.Header()["X-API-Version"][0])
// Test public API (should match third rule)
req3 := httptest.NewRequest("GET", "/api/public/info", nil)
w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3)
assert.Equal(t, 200, w3.Code)
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
assert.Empty(t, w3.Header()["X-API-Version"])
}
func TestHTTPFlow_ResponseModifier(t *testing.T) {
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("original response"))
})
var rules Rules
err := parseRules(`
- name: modify-response
do: |
set resp_header X-Modified "true"
set resp_body "Modified: {{ .Request.Method }} {{ .Request.URL.Path }}"
`, &rules)
require.NoError(t, err)
handler := rules.BuildHandler(upstream)
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Equal(t, "true", w.Header().Get("X-Modified"))
assert.Equal(t, "Modified: GET /test\n", w.Body.String())
}

View File

@@ -0,0 +1,36 @@
package rules
import (
"io"
"os"
"github.com/yusing/godoxy/internal/logging/accesslog"
gperr "github.com/yusing/goutils/errs"
)
type noopWriteCloser struct {
io.Writer
}
func (n noopWriteCloser) Close() error {
return nil
}
var (
stdout io.WriteCloser = noopWriteCloser{os.Stdout}
stderr io.WriteCloser = noopWriteCloser{os.Stderr}
)
func openFile(path string) (io.WriteCloser, gperr.Error) {
switch path {
case "/dev/stdout":
return stdout, nil
case "/dev/stderr":
return stderr, nil
}
f, err := accesslog.NewFileIO(path)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return f, nil
}

View File

@@ -0,0 +1,120 @@
package rules
import (
"regexp"
"strings"
"github.com/gobwas/glob"
gperr "github.com/yusing/goutils/errs"
)
type (
Matcher func(string) bool
MatcherType string
)
const (
MatcherTypeString MatcherType = "string"
MatcherTypeGlob MatcherType = "glob"
MatcherTypeRegex MatcherType = "regex"
)
func unquoteExpr(s string) (string, gperr.Error) {
if s == "" {
return "", nil
}
switch s[0] {
case '"', '\'', '`':
if s[0] != s[len(s)-1] {
return "", ErrUnterminatedQuotes
}
return s[1 : len(s)-1], nil
default:
return s, nil
}
}
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
idx := strings.IndexByte(s, '(')
if idx == -1 {
return MatcherTypeString, s, nil
}
idxEnd := strings.LastIndexByte(s, ')')
if idxEnd == -1 {
return "", "", ErrUnterminatedBrackets
}
expr, err = unquoteExpr(s[idx+1 : idxEnd])
if err != nil {
return "", "", err
}
matcherType = MatcherType(strings.ToLower(s[:idx]))
switch matcherType {
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
return
default:
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
}
}
func ParseMatcher(expr string) (Matcher, gperr.Error) {
negate := false
if strings.HasPrefix(expr, "!") {
negate = true
expr = expr[1:]
}
t, expr, err := ExtractExpr(expr)
if err != nil {
return nil, err
}
switch t {
case MatcherTypeString:
return StringMatcher(expr, negate)
case MatcherTypeGlob:
return GlobMatcher(expr, negate)
case MatcherTypeRegex:
return RegexMatcher(expr, negate)
}
// won't reach here
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
}
func StringMatcher(s string, negate bool) (Matcher, gperr.Error) {
if negate {
return func(s2 string) bool {
return s != s2
}, nil
}
return func(s2 string) bool {
return s == s2
}, nil
}
func GlobMatcher(expr string, negate bool) (Matcher, gperr.Error) {
g, err := glob.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if negate {
return func(s string) bool {
return !g.Match(s)
}, nil
}
return g.Match, nil
}
func RegexMatcher(expr string, negate bool) (Matcher, gperr.Error) {
re, err := regexp.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if negate {
return func(s string) bool {
return !re.MatchString(s)
}, nil
}
return re.MatchString, nil
}

View File

@@ -0,0 +1,35 @@
package rules
import "testing"
func BenchmarkMatcher(b *testing.B) {
b.Run("StringMatcher", func(b *testing.B) {
matcher, err := StringMatcher("foo", false)
if err != nil {
b.Fatal(err)
}
for b.Loop() {
matcher("foo")
}
})
b.Run("GlobMatcher", func(b *testing.B) {
matcher, err := GlobMatcher("foo*bar?baz*[abc]*.txt", false)
if err != nil {
b.Fatal(err)
}
for b.Loop() {
matcher("foooooobarzbazcb.txt")
}
})
b.Run("RegexMatcher", func(b *testing.B) {
matcher, err := RegexMatcher(`^(foo\d+|bar(_baz)?)[a-z]{3,}\.txt$`, false)
if err != nil {
b.Fatal(err)
}
for b.Loop() {
matcher("foo123abcd.txt")
}
})
}

View File

@@ -49,6 +49,18 @@ func TestExtractExpr(t *testing.T) {
wantT: MatcherTypeRegex,
wantExpr: "^[A-Z]+$",
},
{
name: "regex with parentheses",
in: "regex(test(group))",
wantT: MatcherTypeRegex,
wantExpr: "test(group)",
},
{
name: "regex complex",
in: `regex("^(_next/static|_next/image|favicon.ico).*$")`,
wantT: MatcherTypeRegex,
wantExpr: "^(_next/static|_next/image|favicon.ico).*$",
},
{
name: "quoted expr",
in: "glob(`'foo'`)",
@@ -96,3 +108,62 @@ func TestExtractExprInvalid(t *testing.T) {
})
}
}
func TestNegated(t *testing.T) {
tests := []struct {
name string
expr string
in string
want bool
}{
{
name: "negated_string_match",
expr: "!string(`foo`)",
in: "foo",
want: false,
},
{
name: "negated_string_no_match",
expr: "!string(`foo`)",
in: "bar",
want: true,
},
{
name: "negated_glob_match",
expr: "!glob(`foo`)",
in: "foo",
want: false,
},
{
name: "negated_glob_no_match",
expr: "!glob(`foo`)",
in: "bar",
want: true,
},
{
name: "negated_regex_match",
expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)",
in: "favicon.ico",
want: false,
},
{
name: "negated_regex_no_match",
expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)",
in: "bar",
want: true,
},
{
name: "negated_regex_no_match2",
expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)",
in: "/",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matcher, err := ParseMatcher(tt.expr)
expect.NoError(t, err)
expect.Equal(t, tt.want, matcher(tt.in))
})
}
}

View File

@@ -8,16 +8,20 @@ import (
"github.com/yusing/godoxy/internal/route/routes"
gperr "github.com/yusing/goutils/errs"
strutils "github.com/yusing/goutils/strings"
)
type RuleOn struct {
raw string
checker Checker
raw string
checker Checker
isResponseChecker bool
}
func (on *RuleOn) Check(cached Cache, r *http.Request) bool {
return on.checker.Check(cached, r)
func (on *RuleOn) IsResponseChecker() bool {
return on.isResponseChecker
}
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
return on.checker.Check(w, r)
}
const (
@@ -32,20 +36,27 @@ const (
OnRemote = "remote"
OnBasicAuth = "basic_auth"
OnRoute = "route"
// on response
OnResponseHeader = "resp_header"
OnStatus = "status"
)
var checkers = map[string]struct {
help Help
validate ValidateFunc
builder func(args any) CheckFunc
help Help
validate ValidateFunc
builder func(args any) CheckFunc
isResponseChecker bool
}{
OnHeader: {
help: Help{
command: OnHeader,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
header username "user"
header username glob("user*")
header username regex("user.*")`,
description: makeLines(
"Value supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnHeader, "username", "user"),
helpExample(OnHeader, "username", helpFuncCall("glob", "user*")),
helpExample(OnHeader, "username", helpFuncCall("regex", "user.*")),
),
args: map[string]string{
"key": "the header key",
"[value]": "the header value",
@@ -55,22 +66,52 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return len(r.Header[k]) > 0
}
}
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(r.Header[k], matcher)
}
},
},
OnResponseHeader: {
isResponseChecker: true,
help: Help{
command: OnResponseHeader,
description: makeLines(
"Value supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnResponseHeader, "username", "user"),
helpExample(OnResponseHeader, "username", helpFuncCall("glob", "user*")),
helpExample(OnResponseHeader, "username", helpFuncCall("regex", "user.*")),
),
args: map[string]string{
"key": "the response header key",
"[value]": "the response header value",
},
},
validate: toKVOptionalVMatcher,
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(w http.ResponseWriter, r *http.Request) bool {
return len(GetInitResponseModifier(w).Header()[k]) > 0
}
}
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(GetInitResponseModifier(w).Header()[k], matcher)
}
},
},
OnQuery: {
help: Help{
command: OnQuery,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
query username "user"
query username glob("user*")
query username regex("user.*")`,
description: makeLines(
"Value supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnQuery, "username", "user"),
helpExample(OnQuery, "username", helpFuncCall("glob", "user*")),
helpExample(OnQuery, "username", helpFuncCall("regex", "user.*")),
),
args: map[string]string{
"key": "the query key",
"[value]": "the query value",
@@ -80,22 +121,24 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return len(cached.GetQueries(r)[k]) > 0
return func(w http.ResponseWriter, r *http.Request) bool {
return len(GetSharedData(w).GetQueries(r)[k]) > 0
}
}
return func(cached Cache, r *http.Request) bool {
return slices.ContainsFunc(cached.GetQueries(r)[k], matcher)
return func(w http.ResponseWriter, r *http.Request) bool {
return slices.ContainsFunc(GetSharedData(w).GetQueries(r)[k], matcher)
}
},
},
OnCookie: {
help: Help{
command: OnCookie,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
cookie username "user"
cookie username glob("user*")
cookie username regex("user.*")`,
description: makeLines(
"Value supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnCookie, "username", "user"),
helpExample(OnCookie, "username", helpFuncCall("glob", "user*")),
helpExample(OnCookie, "username", helpFuncCall("regex", "user.*")),
),
args: map[string]string{
"key": "the cookie key",
"[value]": "the cookie value",
@@ -105,8 +148,8 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
cookies := cached.GetCookies(r)
return func(w http.ResponseWriter, r *http.Request) bool {
cookies := GetSharedData(w).GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k {
return true
@@ -115,8 +158,8 @@ var checkers = map[string]struct {
return false
}
}
return func(cached Cache, r *http.Request) bool {
cookies := cached.GetCookies(r)
return func(w http.ResponseWriter, r *http.Request) bool {
cookies := GetSharedData(w).GetCookies(r)
for _, cookie := range cookies {
if cookie.Name == k {
if matcher(cookie.Value) {
@@ -131,10 +174,12 @@ var checkers = map[string]struct {
OnForm: {
help: Help{
command: OnForm,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
form username "user"
form username glob("user*")
form username regex("user.*")`,
description: makeLines(
"Value supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnForm, "username", "user"),
helpExample(OnForm, "username", helpFuncCall("glob", "user*")),
helpExample(OnForm, "username", helpFuncCall("regex", "user.*")),
),
args: map[string]string{
"key": "the form key",
"[value]": "the form value",
@@ -144,11 +189,11 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return r.FormValue(k) != ""
}
}
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return matcher(r.FormValue(k))
}
},
@@ -156,10 +201,12 @@ var checkers = map[string]struct {
OnPostForm: {
help: Help{
command: OnPostForm,
description: `Value supports string, glob pattern, or regex pattern, e.g.:
postform username "user"
postform username glob("user*")
postform username regex("user.*")`,
description: makeLines(
"Value supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnPostForm, "username", "user"),
helpExample(OnPostForm, "username", helpFuncCall("glob", "user*")),
helpExample(OnPostForm, "username", helpFuncCall("regex", "user.*")),
),
args: map[string]string{
"key": "the form key",
"[value]": "the form value",
@@ -169,11 +216,11 @@ var checkers = map[string]struct {
builder: func(args any) CheckFunc {
k, matcher := args.(*MapValueMatcher).Unpack()
if matcher == nil {
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return r.PostFormValue(k) != ""
}
}
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return matcher(r.PostFormValue(k))
}
},
@@ -188,7 +235,7 @@ var checkers = map[string]struct {
validate: validateMethod,
builder: func(args any) CheckFunc {
method := args.(string)
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return r.Method == method
}
},
@@ -196,11 +243,13 @@ var checkers = map[string]struct {
OnHost: {
help: Help{
command: OnHost,
description: `Supports string, glob pattern, or regex pattern, e.g.:
host example.com
host glob(example*.com)
host regex(example\w+\.com)
host regex(example\.com$)`,
description: makeLines(
"Supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnHost, "example.com"),
helpExample(OnHost, helpFuncCall("glob", "example*.com")),
helpExample(OnHost, helpFuncCall("regex", `(example\w+\.com)`)),
helpExample(OnHost, helpFuncCall("regex", `example\.com$`)),
),
args: map[string]string{
"host": "the host name",
},
@@ -208,7 +257,7 @@ var checkers = map[string]struct {
validate: validateSingleMatcher,
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
return matcher(r.Host)
}
},
@@ -216,11 +265,13 @@ var checkers = map[string]struct {
OnPath: {
help: Help{
command: OnPath,
description: `Supports string, glob pattern, or regex pattern, e.g.:
path /path/to
path glob(/path/to/*)
path regex(^/path/to/.*$)
path regex(/path/[A-Z]+/)`,
description: makeLines(
"Supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnPath, "/path/to"),
helpExample(OnPath, helpFuncCall("glob", "/path/to/*")),
helpExample(OnPath, helpFuncCall("regex", `^/path/to/.*$`)),
helpExample(OnPath, helpFuncCall("regex", `/path/[A-Z]+/`)),
),
args: map[string]string{
"path": "the request path",
},
@@ -228,7 +279,7 @@ var checkers = map[string]struct {
validate: validateURLPathMatcher,
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(cached Cache, r *http.Request) bool {
return func(w http.ResponseWriter, r *http.Request) bool {
reqPath := r.URL.Path
if len(reqPath) > 0 && reqPath[0] != '/' {
reqPath = "/" + reqPath
@@ -250,16 +301,16 @@ var checkers = map[string]struct {
// for /32 (IPv4) or /128 (IPv6), just compare the IP
if ones, bits := ipnet.Mask.Size(); ones == bits {
wantIP := ipnet.IP
return func(cached Cache, r *http.Request) bool {
ip := cached.GetRemoteIP(r)
return func(w http.ResponseWriter, r *http.Request) bool {
ip := GetSharedData(w).GetRemoteIP(r)
if ip == nil {
return false
}
return ip.Equal(wantIP)
}
}
return func(cached Cache, r *http.Request) bool {
ip := cached.GetRemoteIP(r)
return func(w http.ResponseWriter, r *http.Request) bool {
ip := GetSharedData(w).GetRemoteIP(r)
if ip == nil {
return false
}
@@ -278,18 +329,20 @@ var checkers = map[string]struct {
validate: validateUserBCryptPassword,
builder: func(args any) CheckFunc {
cred := args.(*HashedCrendentials)
return func(cached Cache, r *http.Request) bool {
return cred.Match(cached.GetBasicAuth(r))
return func(w http.ResponseWriter, r *http.Request) bool {
return cred.Match(GetSharedData(w).GetBasicAuth(r))
}
},
},
OnRoute: {
help: Help{
command: OnRoute,
description: `Supports string, glob pattern, or regex pattern, e.g.:
route example
route glob(example*)
route regex(example\w+)`,
description: makeLines(
"Supports string, glob pattern, or regex pattern, e.g.:",
helpExample(OnRoute, "example"),
helpExample(OnRoute, helpFuncCall("glob", "example*")),
helpExample(OnRoute, helpFuncCall("regex", "example\\w+")),
),
args: map[string]string{
"route": "the route name",
},
@@ -297,11 +350,43 @@ var checkers = map[string]struct {
validate: validateSingleMatcher,
builder: func(args any) CheckFunc {
matcher := args.(Matcher)
return func(_ Cache, r *http.Request) bool {
return func(_ http.ResponseWriter, r *http.Request) bool {
return matcher(routes.TryGetUpstreamName(r))
}
},
},
OnStatus: {
isResponseChecker: true,
help: Help{
command: OnStatus,
description: makeLines(
"Supported formats are:",
helpExample(OnStatus, "<status>"),
helpExample(OnStatus, "<status>-<status>"),
helpExample(OnStatus, "1xx"),
helpExample(OnStatus, "2xx"),
helpExample(OnStatus, "3xx"),
helpExample(OnStatus, "4xx"),
helpExample(OnStatus, "5xx"),
),
args: map[string]string{
"status": "the status code range",
},
},
validate: validateStatusRange,
builder: func(args any) CheckFunc {
beg, end := args.(*IntTuple).Unpack()
if beg == end {
return func(w http.ResponseWriter, _ *http.Request) bool {
return GetInitResponseModifier(w).StatusCode() == beg
}
}
return func(w http.ResponseWriter, _ *http.Request) bool {
statusCode := GetInitResponseModifier(w).StatusCode()
return statusCode >= beg && statusCode <= end
}
},
},
}
var (
@@ -367,6 +452,66 @@ func splitAnd(s string) []string {
return a[:i]
}
// splitPipe splits a string by "|" but respects quotes, brackets, and escaped characters.
// It's similar to the parser.go logic but specifically for pipe splitting.
func splitPipe(s string) []string {
if s == "" {
return []string{}
}
var result []string
var current strings.Builder
escaped := false
quote := rune(0)
brackets := 0
for _, r := range s {
if escaped {
current.WriteRune(r)
escaped = false
continue
}
switch r {
case '\\':
escaped = true
current.WriteRune(r)
case '"', '\'', '`':
if quote == 0 && brackets == 0 {
quote = r
} else if r == quote {
quote = 0
}
current.WriteRune(r)
case '(':
brackets++
current.WriteRune(r)
case ')':
if brackets > 0 {
brackets--
}
current.WriteRune(r)
case '|':
if quote == 0 && brackets == 0 {
// Found a pipe outside quotes/brackets, split here
result = append(result, strings.TrimSpace(current.String()))
current.Reset()
} else {
current.WriteRune(r)
}
default:
current.WriteRune(r)
}
}
// Add the last part
if current.Len() > 0 {
result = append(result, strings.TrimSpace(current.String()))
}
return result
}
// Parse implements strutils.Parser.
func (on *RuleOn) Parse(v string) error {
on.raw = v
@@ -375,19 +520,24 @@ func (on *RuleOn) Parse(v string) error {
checkAnd := make(CheckMatchAll, 0, len(rules))
errs := gperr.NewBuilder("rule.on syntax errors")
isResponseChecker := false
for i, rule := range rules {
if rule == "" {
continue
}
parsed, err := parseOn(rule)
parsed, isResp, err := parseOn(rule)
if err != nil {
errs.Add(err.Subjectf("line %d", i+1))
continue
}
if isResp {
isResponseChecker = true
}
checkAnd = append(checkAnd, parsed)
}
on.checker = checkAnd
on.isResponseChecker = isResponseChecker
return errs.Error()
}
@@ -399,40 +549,57 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
return []byte(on.String()), nil
}
func parseOn(line string) (Checker, gperr.Error) {
ors := strutils.SplitRune(line, '|')
func parseOn(line string) (Checker, bool, gperr.Error) {
ors := splitPipe(line)
if len(ors) > 1 {
errs := gperr.NewBuilder("rule.on syntax errors")
checkOr := make(CheckMatchSingle, len(ors))
isResponseChecker := false
for i, or := range ors {
curCheckers, err := parseOn(or)
curCheckers, isResp, err := parseOn(or)
if err != nil {
errs.Add(err)
continue
}
if isResp {
isResponseChecker = true
}
checkOr[i] = curCheckers.(CheckFunc)
}
if err := errs.Error(); err != nil {
return nil, err
return nil, false, err
}
return checkOr, nil
return checkOr, isResponseChecker, nil
}
subject, args, err := parse(line)
if err != nil {
return nil, err
return nil, false, err
}
negate := false
if strings.HasPrefix(subject, "!") {
negate = true
subject = subject[1:]
}
checker, ok := checkers[subject]
if !ok {
return nil, ErrInvalidOnTarget.Subject(subject)
return nil, false, ErrInvalidOnTarget.Subject(subject)
}
validArgs, err := checker.validate(args)
if err != nil {
return nil, err.Subject(subject).Withf("%s", checker.help.String())
return nil, false, err.Subject(subject).With(checker.help.Error())
}
return checker.builder(validArgs), nil
checkFunc := checker.builder(validArgs)
if negate {
origCheckFunc := checkFunc
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
return !origCheckFunc(w, r)
}
}
return checkFunc, checker.isResponseChecker, nil
}

View File

@@ -7,6 +7,86 @@ import (
expect "github.com/yusing/goutils/testing"
)
func TestSplitPipe(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{
name: "empty",
input: "",
want: []string{},
},
{
name: "single",
input: "rule",
want: []string{"rule"},
},
{
name: "simple_pipe",
input: "rule1 | rule2",
want: []string{"rule1", "rule2"},
},
{
name: "multiple_pipes",
input: "rule1 | rule2 | rule3",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "pipe_in_quotes",
input: `path regex("^(_next/static|_next/image|favicon.ico).*$")`,
want: []string{`path regex("^(_next/static|_next/image|favicon.ico).*$")`},
},
{
name: "pipe_in_single_quotes",
input: `path regex('^(_next/static|_next/image|favicon.ico).*$')`,
want: []string{`path regex('^(_next/static|_next/image|favicon.ico).*$')`},
},
{
name: "pipe_in_backticks",
input: "path regex(`^(_next/static|_next/image|favicon.ico).*$`)",
want: []string{"path regex(`^(_next/static|_next/image|favicon.ico).*$`)"},
},
{
name: "pipe_in_brackets",
input: "path regex(^(_next/static|_next/image|favicon.ico).*$)",
want: []string{"path regex(^(_next/static|_next/image|favicon.ico).*$)"},
},
{
name: "escaped_pipe",
input: `path regex("^(_next/static\|_next/image\|favicon.ico).*$")`,
want: []string{`path regex("^(_next/static\|_next/image\|favicon.ico).*$")`},
},
{
name: "mixed_quotes_and_pipes",
input: `rule1 | path regex("^(_next/static|_next/image|favicon.ico).*$") | rule3`,
want: []string{"rule1", `path regex("^(_next/static|_next/image|favicon.ico).*$")`, "rule3"},
},
{
name: "nested_brackets",
input: "path regex(^(foo|bar(baz|qux)).*$)",
want: []string{"path regex(^(foo|bar(baz|qux)).*$)"},
},
{
name: "spaces_around",
input: " rule1 | rule2 | rule3 ",
want: []string{"rule1", "rule2", "rule3"},
},
{
name: "empty_segments",
input: "rule1 || rule2 | | rule3",
want: []string{"rule1", "", "rule2", "", "rule3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := splitPipe(tt.input)
expect.Equal(t, got, tt.want)
})
}
}
func TestSplitAnd(t *testing.T) {
tests := []struct {
name string
@@ -179,6 +259,27 @@ func TestParseOn(t *testing.T) {
input: "route example1 example2",
wantErr: ErrExpectOneArg,
},
// pipe splitting tests
{
name: "pipe_simple",
input: "method GET | method POST",
wantErr: nil,
},
{
name: "pipe_in_quotes",
input: `path regex("^(_next/static|_next/image|favicon.ico).*$")`,
wantErr: nil,
},
{
name: "pipe_in_brackets",
input: "path regex(^(_next/static|_next/image|favicon.ico).*$)",
wantErr: nil,
},
{
name: "pipe_mixed",
input: `method GET | path regex("^(_next/static|_next/image|favicon.ico).*$") | header Authorization`,
wantErr: nil,
},
}
for _, tt := range tests {

View File

@@ -4,6 +4,7 @@ import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
@@ -47,6 +48,18 @@ func genCorrectnessTestCases(field string, genRequest func(k, v string) *http.Re
input: genRequest("bar", "abcd"),
want: false,
},
{
name: field + "_negated_match",
checker: "!" + field + " foo",
input: genRequest("foo", "bar"),
want: false,
},
{
name: field + "_negated_no_match",
checker: "!" + field + " foo",
input: genRequest("bar", "foo"),
want: true,
},
}
}
@@ -64,6 +77,18 @@ func TestOnCorrectness(t *testing.T) {
input: &http.Request{Method: http.MethodPost},
want: false,
},
{
name: "method_negated_match",
checker: "!method GET",
input: &http.Request{Method: http.MethodGet},
want: false,
},
{
name: "method_negated_no_match",
checker: "!method GET",
input: &http.Request{Method: http.MethodPost},
want: true,
},
{
name: "host_match",
checker: "host example.com",
@@ -80,6 +105,22 @@ func TestOnCorrectness(t *testing.T) {
},
want: false,
},
{
name: "host_negated_match",
checker: "!host example.com",
input: &http.Request{
Host: "example.com",
},
want: false,
},
{
name: "host_negated_no_match",
checker: "!host example.com",
input: &http.Request{
Host: "example.org",
},
want: true,
},
{
name: "path_exact_match",
checker: "path /example",
@@ -88,6 +129,22 @@ func TestOnCorrectness(t *testing.T) {
},
want: true,
},
{
name: "path_negated_match",
checker: "!path /example",
input: &http.Request{
URL: &url.URL{Path: "/example"},
},
want: false,
},
{
name: "path_negated_no_match",
checker: "!path /example",
input: &http.Request{
URL: &url.URL{Path: "/example/foo"},
},
want: true,
},
{
name: "remote_match",
checker: "remote 192.168.1.0/24",
@@ -96,6 +153,22 @@ func TestOnCorrectness(t *testing.T) {
},
want: true,
},
{
name: "remote_negated_match",
checker: "!remote 192.168.1.0/24",
input: &http.Request{
RemoteAddr: "192.168.1.5",
},
want: false,
},
{
name: "remote_negated_no_match",
checker: "!remote 192.168.1.0/24",
input: &http.Request{
RemoteAddr: "192.168.2.5",
},
want: true,
},
{
name: "remote_no_match",
checker: "remote 192.168.1.0/24",
@@ -124,6 +197,26 @@ func TestOnCorrectness(t *testing.T) {
},
want: false,
},
{
name: "basic_auth_negated_match",
checker: "!basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
},
},
want: false,
},
{
name: "basic_auth_negated_no_match",
checker: "!basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
input: &http.Request{
Header: http.Header{
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
},
},
want: true,
},
{
name: "route_match",
checker: "route example",
@@ -141,6 +234,23 @@ func TestOnCorrectness(t *testing.T) {
},
want: false,
},
{
name: "route_negated_match",
checker: "!route example",
input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{
Alias: "example",
Root: "/",
}))),
want: false,
},
{
name: "route_negated_no_match",
checker: "!route example",
input: &http.Request{
Header: http.Header{},
},
want: true,
},
{
name: "regex_match",
checker: `host regex(example\w+\.com)`,
@@ -157,6 +267,22 @@ func TestOnCorrectness(t *testing.T) {
},
want: false,
},
{
name: "regex_negated_match",
checker: `!host regex(example\w+\.com)`,
input: &http.Request{
Host: "example.org",
},
want: true,
},
{
name: "regex_negated_no_match",
checker: `!host regex(example\w+\.com)`,
input: &http.Request{
Host: "exampleabc.com",
},
want: false,
},
{
name: "glob match",
checker: `host glob(*.example.com)`,
@@ -181,6 +307,22 @@ func TestOnCorrectness(t *testing.T) {
},
want: false,
},
{
name: "glob negated_match",
checker: `!host glob(*.example.com)`,
input: &http.Request{
Host: "example.com",
},
want: true,
},
{
name: "glob negated_no_match",
checker: `!host glob(*.example.com)`,
input: &http.Request{
Host: "a.example.com",
},
want: false,
},
}
tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request {
@@ -219,10 +361,11 @@ func TestOnCorrectness(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
var on RuleOn
err := on.Parse(tt.checker)
expect.NoError(t, err)
got := on.Check(Cache{}, tt.input)
got := on.Check(w, tt.input)
expect.Equal(t, tt.want, got, fmt.Sprintf("expect %s to %v", tt.checker, tt.want))
})
}

View File

@@ -3,9 +3,9 @@ package rules
import (
"bytes"
"fmt"
"os"
"unicode"
"github.com/yusing/goutils/env"
gperr "github.com/yusing/goutils/errs"
)
@@ -33,7 +33,7 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
brackets := 0
var envVar bytes.Buffer
var missingEnvVars bytes.Buffer
var missingEnvVars []string
inEnvVar := false
expectingBrace := false
@@ -70,6 +70,10 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
escaped = false
continue
}
if expectingBrace && r != '{' && r != '$' { // not escaped and not env var
buf.WriteRune('$')
expectingBrace = false
}
switch r {
case '\\':
escaped = true
@@ -90,9 +94,11 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
}
case '}':
if inEnvVar {
envValue, ok := os.LookupEnv(envVar.String())
// NOTE: use env.LookupEnv instead of os.LookupEnv to support environment variable prefixes
// like ${API_ADDR} will lookup for GODOXY_API_ADDR, GOPROXY_API_ADDR and API_ADDR.
envValue, ok := env.LookupEnv(envVar.String())
if !ok {
fmt.Fprintf(&missingEnvVars, "%q, ", envVar.String())
missingEnvVars = append(missingEnvVars, envVar.String())
} else {
buf.WriteString(envValue)
}
@@ -140,15 +146,21 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
}
}
if expectingBrace {
buf.WriteRune('$')
}
if quote != 0 {
err = ErrUnterminatedQuotes
} else if brackets != 0 {
err = ErrUnterminatedBrackets
} else if inEnvVar {
err = ErrUnterminatedEnvVar
} else {
flush(false)
}
if missingEnvVars.Len() > 0 {
err = gperr.Join(err, ErrEnvVarNotFound.Subject(missingEnvVars.String()))
if len(missingEnvVars) > 0 {
err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...)))
}
return subject, args, err
}

View File

@@ -49,9 +49,9 @@ func TestParser(t *testing.T) {
},
{
name: "regex_escaped",
input: `foo regex(\b\B\s\S\w\W\d\D\$\.)`,
input: `foo regex(\b\B\s\S\w\W\d\D\$\.\(\)\{\}\|\?\"\')`,
subject: "foo",
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.)`},
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.\(\)\{\}\|\?"')`},
},
{
name: "quote inside argument",
@@ -71,6 +71,12 @@ func TestParser(t *testing.T) {
subject: "foo",
args: []string{"glob(\"`/**/to/path`\")"},
},
{
name: "complex_regex",
input: `path !regex("^(_next/static|_next/image|favicon.ico).*$")`,
subject: "path",
args: []string{`!regex("^(_next/static|_next/image|favicon.ico).*$")`},
},
{
name: "chaos",
input: `error 403 "Forbidden "foo" "bar""`,
@@ -170,6 +176,53 @@ func TestParser(t *testing.T) {
})
}
})
t.Run("negated", func(t *testing.T) {
test := `!error 403 "Forbidden"`
subject, args, err := parse(test)
expect.NoError(t, err)
expect.Equal(t, subject, "!error")
expect.Equal(t, args, []string{"403", "Forbidden"})
})
}
func TestFullParse(t *testing.T) {
input := `
- name: login page
on: path /login
do: pass
- name: require auth
on: path !regex("^(_next/static|_next/image|favicon.ico).*$")
do: require_auth
- name: redirect to login
on: status 401 | status 403
do: proxy /login
- name: proxy to backend
on: path glob("/api/v1/*")
do: proxy http://localhost:8999/
- name: proxy to backend (old /auth)
on: path glob("/auth/*")
do: proxy http://localhost:8999/api/v1/`
var rules Rules
err := parseRules(input, &rules)
expect.NoError(t, err)
expect.Equal(t, len(rules), 5)
expect.Equal(t, rules[0].Name, "login page")
expect.Equal(t, rules[0].On.String(), "path /login")
expect.Equal(t, rules[0].Do.String(), "pass")
expect.Equal(t, rules[1].Name, "require auth")
expect.Equal(t, rules[1].On.String(), `path !regex("^(_next/static|_next/image|favicon.ico).*$")`)
expect.Equal(t, rules[1].Do.String(), "require_auth")
expect.Equal(t, rules[2].Name, "redirect to login")
expect.Equal(t, rules[2].On.String(), "status 401 | status 403")
expect.Equal(t, rules[2].Do.String(), "proxy /login")
expect.Equal(t, rules[3].Name, "proxy to backend")
expect.Equal(t, rules[3].On.String(), `path glob("/api/v1/*")`)
expect.Equal(t, rules[3].Do.String(), "proxy http://localhost:8999/")
expect.Equal(t, rules[4].Name, "proxy to backend (old /auth)")
expect.Equal(t, rules[4].On.String(), `path glob("/auth/*")`)
expect.Equal(t, rules[4].Do.String(), "proxy http://localhost:8999/api/v1/")
}
func BenchmarkParser(b *testing.B) {

View File

@@ -0,0 +1,48 @@
package rulepresets
import (
"embed"
"reflect"
"sync"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/route/rules"
"github.com/yusing/godoxy/internal/serialization"
)
//go:embed *.yml
var fs embed.FS
var rulePresets = make(map[string]rules.Rules)
var once sync.Once
func GetRulePreset(name string) (rules.Rules, bool) {
once.Do(initPresets)
rules, ok := rulePresets[name]
return rules, ok
}
// init all rule presetsl lazily
func initPresets() {
files, err := fs.ReadDir(".")
if err != nil {
log.Error().Err(err).Msg("failed to read rule presets")
return
}
for _, file := range files {
var rules rules.Rules
content, err := fs.ReadFile(file.Name())
if err != nil {
log.Error().Str("name", file.Name()).Err(err).Msg("failed to read rule preset")
continue
}
_, err = serialization.ConvertString(string(content), reflect.ValueOf(&rules))
if err != nil {
log.Error().Str("name", file.Name()).Err(err).Msg("failed to unmarshal rule preset")
continue
}
rulePresets[file.Name()] = rules
log.Debug().Str("name", file.Name()).Msg("loaded rule preset")
}
}

View File

@@ -0,0 +1,17 @@
- name: login page
on: path /login
do: pass
- name: protected
on: |
!path regex("(_next/static|_next/image|favicon.ico).*")
!path glob("/api/v1/auth/*")
!path /api/v1/version
do: require_auth
- name: proxy to backend
on: path glob("/api/v1/*")
do: proxy http://${API_ADDR}/
- name: proxy to auth api
on: path glob("/auth/*")
do: |
rewrite /auth /api/v1/auth
proxy http://${API_ADDR}/

View File

@@ -0,0 +1,173 @@
package rules
import (
"bufio"
"bytes"
"errors"
"net"
"net/http"
"strconv"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/synk"
)
type ResponseModifier struct {
w http.ResponseWriter
b []byte // the bytes got from pool
buf *bytes.Buffer
statusCode int
shared Cache
hijacked bool
errs gperr.Builder
}
type Response struct {
StatusCode int
Header http.Header
}
var pool = synk.GetBytesPoolWithUniqueMemory()
func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
for {
switch ww := w.(type) {
case *ResponseModifier:
return ww
case interface{ Unwrap() http.ResponseWriter }:
w = ww.Unwrap()
default:
return nil
}
}
}
// GetInitResponseModifier returns the response modifier for the given response writer.
// If the response writer is already wrapped, it will return the wrapped response modifier.
// Otherwise, it will return a new response modifier.
func GetInitResponseModifier(w http.ResponseWriter) *ResponseModifier {
if rm := unwrapResponseModifier(w); rm != nil {
return rm
}
return NewResponseModifier(w)
}
// GetSharedData returns the shared data for the given response writer.
// It will initialize the shared data if not initialized.
func GetSharedData(w http.ResponseWriter) Cache {
rm := GetInitResponseModifier(w)
if rm.shared == nil {
rm.shared = NewCache()
}
return rm.shared
}
// NewResponseModifier returns a new response modifier for the given response writer.
//
// It should only be called once, at the very beginning of the request.
func NewResponseModifier(w http.ResponseWriter) *ResponseModifier {
b := pool.Get()
return &ResponseModifier{
w: w,
buf: bytes.NewBuffer(b),
b: b,
}
}
// func (rm *ResponseModifier) Unwrap() http.ResponseWriter {
// return rm.w
// }
func (rm *ResponseModifier) WriteHeader(code int) {
rm.statusCode = code
}
func (rm *ResponseModifier) ResetBody() {
rm.buf.Reset()
}
func (rm *ResponseModifier) ContentLength() int {
return rm.buf.Len()
}
func (rm *ResponseModifier) StatusCode() int {
if rm.statusCode == 0 {
return http.StatusOK
}
return rm.statusCode
}
func (rm *ResponseModifier) Header() http.Header {
return rm.w.Header()
}
func (rm *ResponseModifier) Response() Response {
return Response{StatusCode: rm.StatusCode(), Header: rm.Header()}
}
func (rm *ResponseModifier) Write(b []byte) (int, error) {
return rm.buf.Write(b)
}
// AppendError appends an error to the response modifier
// the error will be formatted as "rule <rule.Name> error: <err>"
//
// It will be aggregated and returned in FlushRelease.
func (rm *ResponseModifier) AppendError(rule Rule, err error) {
rm.errs.Addf("rule %q error: %w", rule.Name, err)
}
func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := rm.w.(http.Hijacker); ok {
rm.hijacked = true
return hijacker.Hijack()
}
return nil, nil, errors.New("hijack not supported")
}
// FlushRelease flushes the response modifier and releases the resources
// it returns the number of bytes written and the aggregated error
// if there is any error (rule errors or write error), it will be returned
func (rm *ResponseModifier) FlushRelease() (int, error) {
n := 0
if !rm.hijacked {
h := rm.w.Header()
// for k := range h {
// if strings.EqualFold(k, "content-length") {
// h.Del(k)
// }
// }
h.Set("Content-Length", strconv.Itoa(rm.buf.Len()))
rm.w.WriteHeader(rm.StatusCode())
nn, werr := rm.w.Write(rm.buf.Bytes())
n += nn
if werr != nil {
rm.errs.Addf("write error: %w", werr)
}
// flush the response writer
if flusher, ok := rm.w.(http.Flusher); ok {
flusher.Flush()
} else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok {
ferr := errFlusher.Flush()
if ferr != nil {
rm.errs.Addf("flush error: %w", ferr)
}
}
}
// release the buffer and reset the pointers
pool.Put(rm.b)
rm.b = nil
rm.buf = nil
// release the shared data
if rm.shared != nil {
rm.shared.Release()
rm.shared = nil
}
return n, rm.errs.Error()
}

View File

@@ -1,9 +1,12 @@
package rules
import (
"errors"
"fmt"
"net/http"
"github.com/bytedance/sonic"
gperr "github.com/yusing/goutils/errs"
)
type (
@@ -46,15 +49,16 @@ type (
}
)
func (rule *Rule) IsResponseRule() bool {
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
}
// BuildHandler returns a http.HandlerFunc that implements the rules.
//
// if a bypass rule matches,
// the request is passed to the upstream and no more rules are executed.
//
// if no rule matches, the default rule is executed
// if no rule matches and default rule is not set,
// the request is passed to the upstream.
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
if len(rules) == 0 {
return up
}
defaultRule := Rule{
Name: "default",
Do: Command{
@@ -63,55 +67,168 @@ func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
},
}
nonDefaultRules := make(Rules, 0, len(rules))
for _, rule := range rules {
var nonDefaultRules Rules
hasDefaultRule := false
for i, rule := range rules {
if rule.Name == "default" {
defaultRule = rule
hasDefaultRule = true
} else {
// set name to index if name is empty
if rule.Name == "" {
rule.Name = fmt.Sprintf("rule[%d]", i)
}
nonDefaultRules = append(nonDefaultRules, rule)
}
}
if len(nonDefaultRules) == 0 {
if defaultRule.Do.isBypass() {
return up.ServeHTTP
return up
}
if defaultRule.IsResponseRule() {
return func(w http.ResponseWriter, r *http.Request) {
rm := NewResponseModifier(w)
w = rm
up(w, r)
err := defaultRule.Do.exec.Handle(w, r)
if err != nil && !errors.Is(err, errTerminated) {
rm.AppendError(defaultRule, err)
}
}
}
return func(w http.ResponseWriter, r *http.Request) {
cache := NewCache()
defer cache.Release()
if defaultRule.Do.exec.Handle(cache, w, r) {
up.ServeHTTP(w, r)
rm := NewResponseModifier(w)
w = rm
err := defaultRule.Do.exec.Handle(w, r)
if err == nil {
up(w, r)
return
}
if !errors.Is(err, errTerminated) {
rm.AppendError(defaultRule, err)
}
}
}
if len(nonDefaultRules) == 0 {
nonDefaultRules = rules
preRules := make(Rules, 0, len(nonDefaultRules))
postRules := make(Rules, 0, len(nonDefaultRules))
for _, rule := range nonDefaultRules {
if rule.IsResponseRule() {
postRules = append(postRules, rule)
} else {
preRules = append(preRules, rule)
}
}
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
return func(w http.ResponseWriter, r *http.Request) {
cache := NewCache()
defer cache.Release()
rm := NewResponseModifier(w)
defer func() {
if _, err := rm.FlushRelease(); err != nil {
gperr.LogError("error executing rules", err)
}
}()
for _, rule := range nonDefaultRules {
if rule.Check(cache, r) {
if rule.Do.isBypass() {
up.ServeHTTP(w, r)
return
w = rm
shouldCallUpstream := true
preMatched := false
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
if defaultRule.Do.isBypass() {
// continue to upstream
} else {
err := defaultRule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
rm.AppendError(defaultRule, err)
}
shouldCallUpstream = false
}
if !rule.Handle(cache, w, r) {
}
}
if shouldCallUpstream {
for _, rule := range preRules {
if rule.Check(w, r) {
preMatched = true
if rule.Do.isBypass() {
break // post rules should still execute
}
err := rule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
rm.AppendError(rule, err)
}
shouldCallUpstream = false
break
}
}
}
}
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
if defaultRule.Do.isBypass() {
// continue to upstream
} else {
err := defaultRule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
rm.AppendError(defaultRule, err)
return
}
shouldCallUpstream = false
}
}
}
if shouldCallUpstream {
up(w, r)
}
// if no post rules, we are done here
if len(postRules) == 0 && !isDefaultRulePost {
return
}
for _, rule := range postRules {
if rule.Check(w, r) {
err := rule.Handle(w, r)
if err != nil {
if !errors.Is(err, errTerminated) {
rm.AppendError(rule, err)
}
return
}
}
}
// bypass or proceed
if defaultRule.Do.isBypass() || defaultRule.Handle(cache, w, r) {
up.ServeHTTP(w, r)
if isDefaultRulePost {
err := defaultRule.Handle(w, r)
if err != nil && !errors.Is(err, errTerminated) {
rm.AppendError(defaultRule, err)
}
}
}
}
func isTerminatingHandler(handler CommandHandler) bool {
switch h := handler.(type) {
case TerminatingCommand:
return true
case Commands:
if len(h) == 0 {
return false
}
return isTerminatingHandler(h[len(h)-1])
default:
return false
}
}
func (rules Rules) MarshalJSON() ([]byte, error) {
names := make([]string, len(rules))
for i, rule := range rules {
@@ -124,11 +241,14 @@ func (rule *Rule) String() string {
return rule.Name
}
func (rule *Rule) Check(cached Cache, r *http.Request) bool {
return rule.On.checker.Check(cached, r)
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
if rule.On.checker == nil {
return true
}
v := rule.On.checker.Check(w, r)
return v
}
func (rule *Rule) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
proceed = rule.Do.exec.Handle(cached, w, r)
return proceed
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
return rule.Do.exec.Handle(w, r)
}

View File

@@ -0,0 +1,74 @@
package rules
import (
"bytes"
"io"
"net/http"
"net/url"
"testing"
)
func BenchmarkRules(b *testing.B) {
var rules Rules
err := parseRules(`
- name: admin-api
on: |
path glob(/api/admin/*)
header Authorization
method POST
do: |
set resp_header X-Access-Level "admin"
set resp_header X-API-Version "v1"
- name: user-api
on: |
path glob(/api/users/*) & method GET
do: |
set resp_header X-Access-Level "user"
set resp_header X-API-Version "v1"
- name: public-api
on: |
path glob(/api/public/*) & method GET
do: |
set resp_header X-Access-Level "public"
`, &rules)
if err != nil {
b.Fatal(err)
}
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
b.Run("BuildHandler", func(b *testing.B) {
for b.Loop() {
rules.BuildHandler(upstream)
}
})
b.Run("RunHandler", func(b *testing.B) {
var r = &http.Request{
Body: io.NopCloser(bytes.NewReader([]byte(""))),
URL: &url.URL{Path: "/api/users/"},
}
var w noopResponseWriter
handler := rules.BuildHandler(upstream)
b.ResetTimer()
for b.Loop() {
handler.ServeHTTP(w, r)
}
})
}
type noopResponseWriter struct {
}
func (w noopResponseWriter) Header() http.Header {
return http.Header{}
}
func (w noopResponseWriter) Write(b []byte) (int, error) {
return len(b), nil
}
func (w noopResponseWriter) WriteHeader(int) {
}

View File

@@ -1,46 +0,0 @@
package rules
import (
"testing"
"github.com/yusing/godoxy/internal/serialization"
expect "github.com/yusing/goutils/testing"
)
func TestParseRule(t *testing.T) {
test := []map[string]any{
{
"name": "test",
"on": "method POST",
"do": "error 403 Forbidden",
},
{
"name": "auth",
"on": `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`,
"do": "bypass",
},
{
"name": "default",
"do": "require_basic_auth any_realm",
},
}
var rules struct {
Rules Rules
}
err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
expect.NoError(t, err)
expect.Equal(t, len(rules.Rules), len(test))
expect.Equal(t, rules.Rules[0].Name, "test")
expect.Equal(t, rules.Rules[0].On.String(), "method POST")
expect.Equal(t, rules.Rules[0].Do.String(), "error 403 Forbidden")
expect.Equal(t, rules.Rules[1].Name, "auth")
expect.Equal(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`)
expect.Equal(t, rules.Rules[1].Do.String(), "bypass")
expect.Equal(t, rules.Rules[2].Name, "default")
expect.Equal(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm")
}
// TODO: real tests.

View File

@@ -0,0 +1,43 @@
package rules
import (
"bytes"
"io"
"net/http"
)
type templateOrStr interface {
Execute(w io.Writer, data any) error
}
type strTemplate string
func (t strTemplate) Execute(w io.Writer, _ any) error {
n, err := w.Write([]byte(t))
if err != nil {
return err
}
if n != len(t) {
return io.ErrShortWrite
}
return nil
}
type keyValueTemplate = Tuple[string, templateOrStr]
func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) {
var buf bytes.Buffer
err := tmpl.Execute(&buf, reqResponseTemplateData{Request: r})
if err != nil {
return "", err
}
return buf.String(), nil
}
func executeRequestTemplateTo(tmpl templateOrStr, o io.Writer, r *http.Request) error {
return tmpl.Execute(o, reqResponseTemplateData{Request: r})
}
func executeReqRespTemplateTo(tmpl templateOrStr, o io.Writer, w http.ResponseWriter, r *http.Request) error {
return tmpl.Execute(o, reqResponseTemplateData{Request: r, Response: GetInitResponseModifier(w).Response()})
}

View File

@@ -6,10 +6,11 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/gobwas/glob"
"github.com/rs/zerolog"
nettypes "github.com/yusing/godoxy/internal/net/types"
gperr "github.com/yusing/goutils/errs"
httputils "github.com/yusing/goutils/http"
@@ -21,6 +22,17 @@ type (
First T1
Second T2
}
Tuple3[T1, T2, T3 any] struct {
First T1
Second T2
Third T3
}
Tuple4[T1, T2, T3, T4 any] struct {
First T1
Second T2
Third T3
Fourth T4
}
StrTuple = Tuple[string, string]
IntTuple = Tuple[int, int]
MapValueMatcher = Tuple[string, Matcher]
@@ -30,97 +42,24 @@ func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
return t.First, t.Second
}
func (t *Tuple3[T1, T2, T3]) Unpack() (T1, T2, T3) {
return t.First, t.Second, t.Third
}
func (t *Tuple4[T1, T2, T3, T4]) Unpack() (T1, T2, T3, T4) {
return t.First, t.Second, t.Third, t.Fourth
}
func (t *Tuple[T1, T2]) String() string {
return fmt.Sprintf("%v:%v", t.First, t.Second)
}
type (
Matcher func(string) bool
MatcherType string
)
const (
MatcherTypeString MatcherType = "string"
MatcherTypeGlob MatcherType = "glob"
MatcherTypeRegex MatcherType = "regex"
)
func unquoteExpr(s string) (string, gperr.Error) {
if s == "" {
return "", nil
}
switch s[0] {
case '"', '\'', '`':
if s[0] != s[len(s)-1] {
return "", ErrUnterminatedQuotes
}
return s[1 : len(s)-1], nil
default:
return s, nil
}
func (t *Tuple3[T1, T2, T3]) String() string {
return fmt.Sprintf("%v:%v:%v", t.First, t.Second, t.Third)
}
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
idx := strings.IndexByte(s, '(')
if idx == -1 {
return MatcherTypeString, s, nil
}
idxEnd := strings.LastIndexByte(s, ')')
if idxEnd == -1 {
return "", "", ErrUnterminatedBrackets
}
expr, err = unquoteExpr(s[idx+1 : idxEnd])
if err != nil {
return "", "", err
}
matcherType = MatcherType(strings.ToLower(s[:idx]))
switch matcherType {
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
return
default:
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
}
}
func ParseMatcher(expr string) (Matcher, gperr.Error) {
t, expr, err := ExtractExpr(expr)
if err != nil {
return nil, err
}
switch t {
case MatcherTypeString:
return StringMatcher(expr)
case MatcherTypeGlob:
return GlobMatcher(expr)
case MatcherTypeRegex:
return RegexMatcher(expr)
}
// won't reach here
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
}
func StringMatcher(s string) (Matcher, gperr.Error) {
return func(s2 string) bool {
return s == s2
}, nil
}
func GlobMatcher(expr string) (Matcher, gperr.Error) {
g, err := glob.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return g.Match, nil
}
func RegexMatcher(expr string) (Matcher, gperr.Error) {
re, err := regexp.Compile(expr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return re.MatchString, nil
func (t *Tuple4[T1, T2, T3, T4]) String() string {
return fmt.Sprintf("%v:%v:%v:%v", t.First, t.Second, t.Third, t.Fourth)
}
// validateSingleMatcher returns Matcher with the matcher validated.
@@ -131,14 +70,6 @@ func validateSingleMatcher(args []string) (any, gperr.Error) {
return ParseMatcher(args[0])
}
// toStrTuple returns *StrTuple.
func toStrTuple(args []string) (any, gperr.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
return &StrTuple{args[0], args[1]}, nil
}
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
switch len(args) {
@@ -155,6 +86,18 @@ func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
}
}
func toKeyValueTemplate(args []string) (any, gperr.Error) {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
tmpl, err := validateTemplate(args[1], false)
if err != nil {
return nil, err
}
return &keyValueTemplate{args[0], tmpl}, nil
}
// validateURL returns types.URL with the URL validated.
func validateURL(args []string) (any, gperr.Error) {
if len(args) != 1 {
@@ -164,6 +107,12 @@ func validateURL(args []string) (any, gperr.Error) {
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
if u.Scheme == "" {
// expect relative URL, must starts with /
if !strings.HasPrefix(u.Path, "/") {
return nil, ErrInvalidArguments.Withf("relative URL must starts with /")
}
}
return u, nil
}
@@ -250,6 +199,57 @@ func validateMethod(args []string) (any, gperr.Error) {
return method, nil
}
func validateStatusCode(status string) (int, error) {
statusCode, err := strconv.Atoi(status)
if err != nil {
return 0, err
}
if statusCode < 100 || statusCode > 599 {
return 0, fmt.Errorf("status code out of range: %s", status)
}
return statusCode, nil
}
// validateStatusRange returns Tuple[int, int] with the status range validated.
// accepted formats are:
// - <status>
// - <status>-<status>
// - 1xx
// - 2xx
// - 3xx
// - 4xx
// - 5xx
func validateStatusRange(args []string) (any, gperr.Error) {
if len(args) != 1 {
return nil, ErrExpectOneArg
}
beg, end, ok := strings.Cut(args[0], "-")
if !ok { // <status>
end = beg
}
switch beg {
case "1xx":
return &IntTuple{100, 199}, nil
case "2xx":
return &IntTuple{200, 299}, nil
case "3xx":
return &IntTuple{300, 399}, nil
case "4xx":
return &IntTuple{400, 499}, nil
case "5xx":
return &IntTuple{500, 599}, nil
}
begInt, begErr := validateStatusCode(beg)
endInt, endErr := validateStatusCode(end)
if begErr != nil || endErr != nil {
return nil, ErrInvalidArguments.With(gperr.Join(begErr, endErr))
}
return &IntTuple{begInt, endInt}, nil
}
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
if len(args) != 2 {
@@ -260,20 +260,77 @@ func validateUserBCryptPassword(args []string) (any, gperr.Error) {
// validateModField returns CommandHandler with the field validated.
func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.Error) {
if len(args) == 0 {
return nil, ErrExpectTwoOrThreeArgs
}
setField, ok := modFields[args[0]]
if !ok {
return nil, ErrInvalidSetTarget.Subject(args[0])
return nil, ErrUnknownModField.Subject(args[0])
}
if mod == ModFieldRemove {
if len(args) != 2 {
return nil, ErrExpectTwoArgs
}
// setField expect validateStrTuple
args = append(args, "")
}
validArgs, err := setField.validate(args[1:])
if err != nil {
return nil, err.Withf(setField.help.String())
return nil, err.With(setField.help.Error())
}
modder := setField.builder(validArgs)
switch mod {
case ModFieldAdd:
return modder.add, nil
add := modder.add
if add == nil {
return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod)
}
return add, nil
case ModFieldRemove:
return modder.remove, nil
remove := modder.remove
if remove == nil {
return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod)
}
return remove, nil
}
return modder.set, nil
set := modder.set
if set == nil {
return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod)
}
return set, nil
}
func isTemplate(tmplStr string) bool {
return strings.Contains(tmplStr, "{{")
}
func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) {
if newline && !strings.HasSuffix(tmplStr, "\n") {
tmplStr += "\n"
}
if !isTemplate(tmplStr) {
return strTemplate(tmplStr), nil
}
tmpl, err := template.New("template").Parse(tmplStr)
if err != nil {
return nil, ErrInvalidArguments.With(err)
}
return tmpl, nil
}
func validateLevel(level string) (zerolog.Level, gperr.Error) {
l, err := zerolog.ParseLevel(level)
if err != nil {
return zerolog.NoLevel, ErrInvalidArguments.With(err)
}
return l, nil
}
// func validateNotifProvider(provider string) gperr.Error {
// if !notif.HasProvider(provider) {
// return ErrInvalidArguments.Subject(provider)
// }
// return nil
// }