mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
feat(rules): add post-request rules system with response manipulation (#160)
* Add comprehensive post-request rules support for response phase * Enable response body, status, and header manipulation via set commands * Refactor command handlers to support both request and response phases * Implement response modifier system for post-request template execution * Support response-based rule matching with status and header checks * Add comprehensive benchmarks for matcher performance * Refactor authentication and proxying commands for unified error handling * Support negated conditions with ! * Enhance error handling, error formatting and validation * Routes: add `rule_file` field with rule preset support * Environment variable substitution: now supports variables without `GODOXY_` prefix * new conditions: * `on resp_header <key> [<value>]` * `on status <status>` * new commands: * `require_auth` * `set resp_header <key> <template>` * `set resp_body <template>` * `set status <code>` * `log <level> <path> <template>` * `notify <level> <provider> <title_template> <body_template>`
This commit is contained in:
2
goutils
2
goutils
Submodule goutils updated: 26146bd560...e78e3c2d35
@@ -58,3 +58,13 @@ func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func AuthOrProceed(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
err := defaultAuth.CheckToken(r)
|
||||
if err != nil {
|
||||
defaultAuth.LoginHandler(w, r)
|
||||
return false
|
||||
} else {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func (cfg *ConfigBase) Validate() gperr.Error {
|
||||
// If only stdout is enabled, it returns nil, nil.
|
||||
func (cfg *ConfigBase) IO() (WriterWithName, error) {
|
||||
if cfg.Path != "" {
|
||||
io, err := newFileIO(cfg.Path)
|
||||
io, err := NewFileIO(cfg.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -26,7 +26,10 @@ var (
|
||||
openedFilesMu sync.Mutex
|
||||
)
|
||||
|
||||
func newFileIO(path string) (WriterWithName, error) {
|
||||
// NewFileIO creates a new file writer with cleaned path.
|
||||
//
|
||||
// If the file is already opened, it will be returned.
|
||||
func NewFileIO(path string) (WriterWithName, error) {
|
||||
openedFilesMu.Lock()
|
||||
defer openedFilesMu.Unlock()
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
file, err := newFileIO(cfg.Path)
|
||||
file, err := NewFileIO(cfg.Path)
|
||||
expect.NoError(t, err)
|
||||
accessLogIOs[index] = file
|
||||
}(i)
|
||||
|
||||
@@ -61,6 +61,26 @@ func NewLogger(out ...io.Writer) zerolog.Logger {
|
||||
).Level(level).With().Timestamp().Logger()
|
||||
}
|
||||
|
||||
func NewLoggerWithFixedLevel(level zerolog.Level, out ...io.Writer) zerolog.Logger {
|
||||
levelStr := level.String()
|
||||
writer := zerolog.ConsoleWriter{
|
||||
Out: zerolog.MultiLevelWriter(out...),
|
||||
TimeFormat: timeFmt,
|
||||
FormatMessage: func(msgI interface{}) string { // pad spaces for each line
|
||||
if msgI == nil {
|
||||
return ""
|
||||
}
|
||||
return fmtMessage(msgI.(string))
|
||||
},
|
||||
FormatLevel: func(_ any) string {
|
||||
return levelStr
|
||||
},
|
||||
}
|
||||
return zerolog.New(
|
||||
writer,
|
||||
).Level(level).With().Timestamp().Logger()
|
||||
}
|
||||
|
||||
func InitLogger(out ...io.Writer) {
|
||||
logger = NewLogger(out...)
|
||||
log.SetOutput(logger)
|
||||
|
||||
@@ -8,11 +8,9 @@ import (
|
||||
|
||||
type Bypass []rules.RuleOn
|
||||
|
||||
func (b Bypass) ShouldBypass(r *http.Request) bool {
|
||||
cached := rules.NewCache()
|
||||
defer cached.Release()
|
||||
func (b Bypass) ShouldBypass(w http.ResponseWriter, r *http.Request) bool {
|
||||
for _, rule := range b {
|
||||
if rule.Check(cached, r) {
|
||||
if rule.Check(w, r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -26,14 +24,14 @@ type checkBypass struct {
|
||||
}
|
||||
|
||||
func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) {
|
||||
if c.modReq == nil || c.bypass.ShouldBypass(r) {
|
||||
if c.modReq == nil || c.bypass.ShouldBypass(w, r) {
|
||||
return true
|
||||
}
|
||||
return c.modReq.before(w, r)
|
||||
}
|
||||
|
||||
func (c *checkBypass) modifyResponse(resp *http.Response) error {
|
||||
if c.modRes == nil || c.bypass.ShouldBypass(resp.Request) {
|
||||
func (c *checkBypass) modifyResponse(w http.ResponseWriter, resp *http.Response) error {
|
||||
if c.modRes == nil || c.bypass.ShouldBypass(w, resp.Request) {
|
||||
return nil
|
||||
}
|
||||
return c.modRes.modifyResponse(resp)
|
||||
|
||||
@@ -20,10 +20,11 @@ type (
|
||||
)
|
||||
|
||||
type (
|
||||
FieldsBody []LogField
|
||||
ListBody []string
|
||||
MessageBody string
|
||||
errorBody struct {
|
||||
FieldsBody []LogField
|
||||
ListBody []string
|
||||
MessageBody string
|
||||
MessageBodyBytes []byte
|
||||
errorBody struct {
|
||||
Error error
|
||||
}
|
||||
)
|
||||
@@ -98,7 +99,15 @@ func (m MessageBody) Format(format LogFormat) ([]byte, error) {
|
||||
case LogFormatRawJSON:
|
||||
return sonic.Marshal(m)
|
||||
}
|
||||
return m.Format(LogFormatMarkdown)
|
||||
return []byte(m), nil
|
||||
}
|
||||
|
||||
func (m MessageBodyBytes) Format(format LogFormat) ([]byte, error) {
|
||||
switch format {
|
||||
case LogFormatRawJSON:
|
||||
return sonic.Marshal(string(m))
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (e errorBody) Format(format LogFormat) ([]byte, error) {
|
||||
|
||||
@@ -128,7 +128,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
|
||||
}
|
||||
|
||||
if len(r.Rules) > 0 {
|
||||
r.handler = r.Rules.BuildHandler(r.handler)
|
||||
r.handler = r.Rules.BuildHandler(r.handler.ServeHTTP)
|
||||
}
|
||||
|
||||
if r.HealthMon != nil {
|
||||
|
||||
@@ -2,7 +2,11 @@ package route
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -17,6 +21,7 @@ import (
|
||||
netutils "github.com/yusing/godoxy/internal/net"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/godoxy/internal/proxmox"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
"github.com/yusing/godoxy/internal/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
@@ -25,6 +30,7 @@ import (
|
||||
"github.com/yusing/godoxy/internal/common"
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
rulepresets "github.com/yusing/godoxy/internal/route/rules/presets"
|
||||
route "github.com/yusing/godoxy/internal/route/types"
|
||||
"github.com/yusing/godoxy/internal/utils"
|
||||
)
|
||||
@@ -41,7 +47,8 @@ type (
|
||||
|
||||
route.HTTPConfig
|
||||
PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"`
|
||||
Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name" extension:"x-nullable"`
|
||||
Rules rules.Rules `json:"rules,omitempty" extension:"x-nullable"`
|
||||
RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"`
|
||||
HealthCheck *types.HealthCheckConfig `json:"healthcheck"`
|
||||
LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"`
|
||||
Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"`
|
||||
@@ -212,7 +219,10 @@ func (r *Route) Validate() gperr.Error {
|
||||
}
|
||||
}
|
||||
|
||||
errs := gperr.NewBuilder("entry validation failed")
|
||||
var errs gperr.Builder
|
||||
if err := r.validateRules(); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
|
||||
var impl types.Route
|
||||
var err gperr.Error
|
||||
@@ -267,6 +277,39 @@ func (r *Route) Validate() gperr.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) validateRules() error {
|
||||
if r.RuleFile != "" && len(r.Rules) > 0 {
|
||||
return errors.New("`rule_file` and `rules` cannot be used together")
|
||||
} else if r.RuleFile != "" {
|
||||
src, err := url.Parse(r.RuleFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse rule file url %q: %w", r.RuleFile, err)
|
||||
}
|
||||
switch src.Scheme {
|
||||
case "embed": // embed://<preset_file_name>
|
||||
rules, ok := rulepresets.GetRulePreset(src.Host)
|
||||
if !ok {
|
||||
return fmt.Errorf("rule preset %q not found", src.Host)
|
||||
} else {
|
||||
r.Rules = rules
|
||||
}
|
||||
case "file", "":
|
||||
content, err := os.ReadFile(src.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read rule file %q: %w", src.Path, err)
|
||||
} else {
|
||||
_, err = serialization.ConvertString(string(content), reflect.ValueOf(&r.Rules))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal rule file %q: %w", src.Path, err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported rule file scheme %q", src.Scheme)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Route) Impl() types.Route {
|
||||
return r.impl
|
||||
}
|
||||
|
||||
@@ -86,6 +86,13 @@ func TryGetUpstreamPort(r *http.Request) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamHostPort(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.Host
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func TryGetUpstreamAddr(r *http.Request) string {
|
||||
if u := tryGetURL(r); u != nil {
|
||||
return u.Host
|
||||
|
||||
@@ -15,13 +15,13 @@ type (
|
||||
)
|
||||
|
||||
const (
|
||||
CacheKeyQueries = "queries"
|
||||
CacheKeyCookies = "cookies"
|
||||
CacheKeyRemoteIP = "remote_ip"
|
||||
CacheKeyBasicAuth = "basic_auth"
|
||||
cacheKeyQueries = "queries"
|
||||
cacheKeyCookies = "cookies"
|
||||
cacheKeyRemoteIP = "remote_ip"
|
||||
cacheKeyBasicAuth = "basic_auth"
|
||||
)
|
||||
|
||||
var cachePool = &sync.Pool{
|
||||
var cachePool = sync.Pool{
|
||||
New: func() any {
|
||||
return make(Cache)
|
||||
},
|
||||
@@ -41,10 +41,10 @@ func (c Cache) Release() {
|
||||
// GetQueries returns the queries.
|
||||
// If r does not have queries, an empty map is returned.
|
||||
func (c Cache) GetQueries(r *http.Request) url.Values {
|
||||
v, ok := c[CacheKeyQueries]
|
||||
v, ok := c[cacheKeyQueries]
|
||||
if !ok {
|
||||
v = r.URL.Query()
|
||||
c[CacheKeyQueries] = v
|
||||
c[cacheKeyQueries] = v
|
||||
}
|
||||
return v.(url.Values)
|
||||
}
|
||||
@@ -58,17 +58,17 @@ func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) {
|
||||
// GetCookies returns the cookies.
|
||||
// If r does not have cookies, an empty slice is returned.
|
||||
func (c Cache) GetCookies(r *http.Request) []*http.Cookie {
|
||||
v, ok := c[CacheKeyCookies]
|
||||
v, ok := c[cacheKeyCookies]
|
||||
if !ok {
|
||||
v = r.Cookies()
|
||||
c[CacheKeyCookies] = v
|
||||
c[cacheKeyCookies] = v
|
||||
}
|
||||
return v.([]*http.Cookie)
|
||||
}
|
||||
|
||||
func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) {
|
||||
cookies := update(c.GetCookies(r))
|
||||
c[CacheKeyCookies] = cookies
|
||||
c[cacheKeyCookies] = cookies
|
||||
r.Header.Del("Cookie")
|
||||
for _, cookie := range cookies {
|
||||
r.AddCookie(cookie)
|
||||
@@ -78,14 +78,14 @@ func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie])
|
||||
// GetRemoteIP returns the remote ip address.
|
||||
// If r.RemoteAddr is not a valid ip address, nil is returned.
|
||||
func (c Cache) GetRemoteIP(r *http.Request) net.IP {
|
||||
v, ok := c[CacheKeyRemoteIP]
|
||||
v, ok := c[cacheKeyRemoteIP]
|
||||
if !ok {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
host = r.RemoteAddr
|
||||
}
|
||||
v = net.ParseIP(host)
|
||||
c[CacheKeyRemoteIP] = v
|
||||
c[cacheKeyRemoteIP] = v
|
||||
}
|
||||
return v.(net.IP)
|
||||
}
|
||||
@@ -93,14 +93,14 @@ func (c Cache) GetRemoteIP(r *http.Request) net.IP {
|
||||
// GetBasicAuth returns *Credentials the basic auth username and password.
|
||||
// If r does not have basic auth, nil is returned.
|
||||
func (c Cache) GetBasicAuth(r *http.Request) *Credentials {
|
||||
v, ok := c[CacheKeyBasicAuth]
|
||||
v, ok := c[cacheKeyBasicAuth]
|
||||
if !ok {
|
||||
u, p, ok := r.BasicAuth()
|
||||
if ok {
|
||||
v = &Credentials{u, []byte(p)}
|
||||
c[CacheKeyBasicAuth] = v
|
||||
c[cacheKeyBasicAuth] = v
|
||||
} else {
|
||||
c[CacheKeyBasicAuth] = nil
|
||||
c[cacheKeyBasicAuth] = nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,30 +3,30 @@ package rules
|
||||
import "net/http"
|
||||
|
||||
type (
|
||||
CheckFunc func(cached Cache, r *http.Request) bool
|
||||
CheckFunc func(w http.ResponseWriter, r *http.Request) bool
|
||||
Checker interface {
|
||||
Check(cached Cache, r *http.Request) bool
|
||||
Check(w http.ResponseWriter, r *http.Request) bool
|
||||
}
|
||||
CheckMatchSingle []Checker
|
||||
CheckMatchAll []Checker
|
||||
)
|
||||
|
||||
func (checker CheckFunc) Check(cached Cache, r *http.Request) bool {
|
||||
return checker(cached, r)
|
||||
func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
return checker(w, r)
|
||||
}
|
||||
|
||||
func (checkers CheckMatchSingle) Check(cached Cache, r *http.Request) bool {
|
||||
func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
for _, check := range checkers {
|
||||
if check.Check(cached, r) {
|
||||
if check.Check(w, r) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (checkers CheckMatchAll) Check(cached Cache, r *http.Request) bool {
|
||||
func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
for _, check := range checkers {
|
||||
if !check.Check(cached, r) {
|
||||
if !check.Check(w, r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,19 +3,21 @@ package rules
|
||||
import "net/http"
|
||||
|
||||
type (
|
||||
handlerFunc func(w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
CommandHandler interface {
|
||||
// CommandHandler can read and modify the values
|
||||
// then handle the request
|
||||
// finally proceed to next command (or return) base on situation
|
||||
Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
Handle(w http.ResponseWriter, r *http.Request) error
|
||||
IsResponseHandler() bool
|
||||
}
|
||||
// NonTerminatingCommand will run then proceed to next command or reverse proxy.
|
||||
NonTerminatingCommand http.HandlerFunc
|
||||
NonTerminatingCommand handlerFunc
|
||||
// TerminatingCommand will run then return immediately.
|
||||
TerminatingCommand http.HandlerFunc
|
||||
// DynamicCommand will return base on the request
|
||||
// and can read or modify the values.
|
||||
DynamicCommand func(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||
TerminatingCommand handlerFunc
|
||||
// OnResponseCommand will run then return based on the response.
|
||||
OnResponseCommand handlerFunc
|
||||
// BypassCommand will skip all the following commands
|
||||
// and directly return to reverse proxy.
|
||||
BypassCommand struct{}
|
||||
@@ -23,29 +25,55 @@ type (
|
||||
Commands []CommandHandler
|
||||
)
|
||||
|
||||
func (c NonTerminatingCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
c(w, r)
|
||||
return true
|
||||
func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return c(w, r)
|
||||
}
|
||||
|
||||
func (c TerminatingCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
c(w, r)
|
||||
func (c NonTerminatingCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c DynamicCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
return c(cached, w, r)
|
||||
func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
if err := c(w, r); err != nil {
|
||||
return err
|
||||
}
|
||||
return errTerminated
|
||||
}
|
||||
|
||||
func (c BypassCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
func (c TerminatingCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return c(w, r)
|
||||
}
|
||||
|
||||
func (c OnResponseCommand) IsResponseHandler() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c Commands) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return errTerminated
|
||||
}
|
||||
|
||||
func (c BypassCommand) IsResponseHandler() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
for _, cmd := range c {
|
||||
if !cmd.Handle(cached, w, r) {
|
||||
return false
|
||||
if err := cmd.Handle(w, r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Commands) IsResponseHandler() bool {
|
||||
for _, cmd := range c {
|
||||
if cmd.IsResponseHandler() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,27 +1,41 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/godoxy/internal/auth"
|
||||
"github.com/yusing/godoxy/internal/logging"
|
||||
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/godoxy/internal/notif"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"github.com/yusing/goutils/http/reverseproxy"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
"github.com/yusing/goutils/synk"
|
||||
)
|
||||
|
||||
type (
|
||||
Command struct {
|
||||
raw string
|
||||
exec CommandHandler
|
||||
raw string
|
||||
exec CommandHandler
|
||||
isResponseHandler bool
|
||||
}
|
||||
)
|
||||
|
||||
func (cmd *Command) IsResponseHandler() bool {
|
||||
return cmd.isResponseHandler
|
||||
}
|
||||
|
||||
const (
|
||||
CommandRequireAuth = "require_auth"
|
||||
CommandRewrite = "rewrite"
|
||||
CommandServe = "serve"
|
||||
CommandProxy = "proxy"
|
||||
@@ -31,18 +45,46 @@ const (
|
||||
CommandSet = "set"
|
||||
CommandAdd = "add"
|
||||
CommandRemove = "remove"
|
||||
CommandLog = "log"
|
||||
CommandNotify = "notify"
|
||||
CommandPass = "pass"
|
||||
CommandPassAlt = "bypass"
|
||||
)
|
||||
|
||||
var commands = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
build func(args any) CommandHandler
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
build func(args any) CommandHandler
|
||||
isResponseHandler bool
|
||||
}{
|
||||
CommandRequireAuth: {
|
||||
help: Help{
|
||||
command: CommandRequireAuth,
|
||||
description: makeLines("Require HTTP authentication for incoming requests"),
|
||||
args: map[string]string{},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
if len(args) != 0 {
|
||||
return nil, ErrExpectNoArg
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if !auth.AuthOrProceed(w, r) {
|
||||
return errTerminated
|
||||
}
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandRewrite: {
|
||||
help: Help{
|
||||
command: CommandRewrite,
|
||||
description: makeLines(
|
||||
"Rewrite a request path from one prefix to another, e.g.:",
|
||||
helpExample(CommandRewrite, "/foo", "/bar"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"from": "the path to rewrite, must start with /",
|
||||
"to": "the path to rewrite to, must start with /",
|
||||
@@ -67,24 +109,29 @@ var commands = map[string]struct {
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
orig, repl := args.(*StrTuple).Unpack()
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
path := r.URL.Path
|
||||
if len(path) > 0 && path[0] != '/' {
|
||||
path = "/" + path
|
||||
}
|
||||
if !strings.HasPrefix(path, orig) {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
path = repl + path[len(orig):]
|
||||
r.URL.Path = path
|
||||
r.URL.RawPath = r.URL.EscapedPath()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandServe: {
|
||||
help: Help{
|
||||
command: CommandServe,
|
||||
description: makeLines(
|
||||
"Serve static files from a local file system path, e.g.:",
|
||||
helpExample(CommandServe, "/var/www"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"root": "the file system path to serve, must be an existing directory",
|
||||
},
|
||||
@@ -92,14 +139,19 @@ var commands = map[string]struct {
|
||||
validate: validateFSPath,
|
||||
build: func(args any) CommandHandler {
|
||||
root := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path)))
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandRedirect: {
|
||||
help: Help{
|
||||
command: CommandRedirect,
|
||||
description: makeLines(
|
||||
"Redirect request to another URL, e.g.:",
|
||||
helpExample(CommandRedirect, "https://example.com"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"to": "the url to redirect to, can be relative or absolute URL",
|
||||
},
|
||||
@@ -107,14 +159,19 @@ var commands = map[string]struct {
|
||||
validate: validateURL,
|
||||
build: func(args any) CommandHandler {
|
||||
target := args.(*nettypes.URL).String()
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
http.Redirect(w, r, target, http.StatusTemporaryRedirect)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandError: {
|
||||
help: Help{
|
||||
command: CommandError,
|
||||
description: makeLines(
|
||||
"Send an HTTP error response and terminate processing, e.g.:",
|
||||
helpExample(CommandError, "400", "bad request"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"code": "the http status code to return",
|
||||
"text": "the error message to return",
|
||||
@@ -136,14 +193,21 @@ var commands = map[string]struct {
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
code, text := args.(*Tuple[int, string]).Unpack()
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
// error command should overwrite the response body
|
||||
GetInitResponseModifier(w).ResetBody()
|
||||
http.Error(w, text, code)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandRequireBasicAuth: {
|
||||
help: Help{
|
||||
command: CommandRequireBasicAuth,
|
||||
description: makeLines(
|
||||
"Require HTTP basic authentication for incoming requests, e.g.:",
|
||||
helpExample(CommandRequireBasicAuth, "Restricted Area"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"realm": "the authentication realm",
|
||||
},
|
||||
@@ -156,35 +220,63 @@ var commands = map[string]struct {
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
realm := args.(string)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandProxy: {
|
||||
help: Help{
|
||||
command: CommandProxy,
|
||||
description: makeLines(
|
||||
"Proxy the request to the specified absolute URL, e.g.:",
|
||||
helpExample(CommandProxy, "http://upstream:8080"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"to": "the url to proxy to, must be an absolute URL",
|
||||
},
|
||||
},
|
||||
validate: validateAbsoluteURL,
|
||||
validate: validateURL,
|
||||
build: func(args any) CommandHandler {
|
||||
target := args.(*nettypes.URL)
|
||||
if target.Scheme == "" {
|
||||
target.Scheme = "http"
|
||||
}
|
||||
if target.Host == "" {
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
url := target.URL
|
||||
url.Host = routes.TryGetUpstreamHostPort(r)
|
||||
if url.Host == "" {
|
||||
return fmt.Errorf("no upstream host: %s", r.URL.String())
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport())
|
||||
r.URL.Path = target.Path
|
||||
r.URL.RawPath = r.URL.EscapedPath()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
rp.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport())
|
||||
return TerminatingCommand(rp.ServeHTTP)
|
||||
return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rp.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandSet: {
|
||||
help: Help{
|
||||
command: CommandSet,
|
||||
description: makeLines(
|
||||
"Set a field in the request or response, e.g.:",
|
||||
helpExample(CommandSet, "header", "User-Agent", "godoxy"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"field": "the field to set",
|
||||
"value": "the value to set",
|
||||
"target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")),
|
||||
"field": "the field to set",
|
||||
"value": "the value to set",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
@@ -197,9 +289,14 @@ var commands = map[string]struct {
|
||||
CommandAdd: {
|
||||
help: Help{
|
||||
command: CommandAdd,
|
||||
description: makeLines(
|
||||
"Add a value to a field in the request or response, e.g.:",
|
||||
helpExample(CommandAdd, "header", "X-Foo", "bar"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"field": "the field to add",
|
||||
"value": "the value to add",
|
||||
"target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")),
|
||||
"field": "the field to add",
|
||||
"value": "the value to add",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
@@ -212,8 +309,13 @@ var commands = map[string]struct {
|
||||
CommandRemove: {
|
||||
help: Help{
|
||||
command: CommandRemove,
|
||||
description: makeLines(
|
||||
"Remove a field from the request or response, e.g.:",
|
||||
helpExample(CommandRemove, "header", "User-Agent"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"field": "the field to remove",
|
||||
"target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")),
|
||||
"field": "the field to remove",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
@@ -223,17 +325,157 @@ var commands = map[string]struct {
|
||||
return args.(CommandHandler)
|
||||
},
|
||||
},
|
||||
CommandLog: {
|
||||
isResponseHandler: true,
|
||||
help: Help{
|
||||
command: CommandLog,
|
||||
description: makeLines(
|
||||
"The template supports the following variables:",
|
||||
helpListItem("Request", "the request object"),
|
||||
helpListItem("Response", "the response object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(CommandLog, "info", "/dev/stdout", "{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"level": "the log level",
|
||||
"path": "the log path (/dev/stdout for stdout, /dev/stderr for stderr)",
|
||||
"template": "the template to log",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrExpectThreeArgs
|
||||
}
|
||||
tmpl, err := validateTemplate(args[2], true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
level, err := validateLevel(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// NOTE: file will stay opened forever
|
||||
// it leverages accesslog.NewFileIO so
|
||||
// it will be opened only once for the same path
|
||||
f, err := openFile(args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &onLogArgs{level, f, tmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
level, f, tmpl := args.(*onLogArgs).Unpack()
|
||||
var logger io.Writer
|
||||
if f == stdout || f == stderr {
|
||||
logger = logging.NewLoggerWithFixedLevel(level, f)
|
||||
} else {
|
||||
logger = f
|
||||
}
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
err := executeReqRespTemplateTo(tmpl, logger, w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
CommandNotify: {
|
||||
isResponseHandler: true,
|
||||
help: Help{
|
||||
command: CommandNotify,
|
||||
description: makeLines(
|
||||
"The template supports the following variables:",
|
||||
helpListItem("Request", "the request object"),
|
||||
helpListItem("Response", "the response object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(CommandNotify, "info", "ntfy", "Received request to {{ .Request.URL }}", "{{ .Request.Method }} {{ .Response.StatusCode }}"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"level": "the log level",
|
||||
"provider": "the notification provider (must be defined in config `providers.notification`)",
|
||||
"title": "the title of the notification",
|
||||
"body": "the body of the notification",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
if len(args) != 4 {
|
||||
return nil, ErrExpectFourArgs
|
||||
}
|
||||
titleTmpl, err := validateTemplate(args[2], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bodyTmpl, err := validateTemplate(args[3], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
level, err := validateLevel(args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: validate provider
|
||||
// currently it is not possible, because rule validation happens on UnmarshalYAMLValidate
|
||||
// and we cannot call config.ActiveConfig.Load() because it will cause import cycle
|
||||
|
||||
// err = validateNotifProvider(args[1])
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil
|
||||
},
|
||||
build: func(args any) CommandHandler {
|
||||
level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack()
|
||||
to := []string{provider}
|
||||
|
||||
return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
buf := bufPool.Get()
|
||||
defer bufPool.Put(buf)
|
||||
|
||||
respBuf := bytes.NewBuffer(buf)
|
||||
|
||||
err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
titleLen := respBuf.Len()
|
||||
err = executeReqRespTemplateTo(bodyTmpl, respBuf, w, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: level,
|
||||
Title: string(buf[:titleLen]),
|
||||
Body: notif.MessageBodyBytes(buf[titleLen:]),
|
||||
To: to,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
type reqResponseTemplateData struct {
|
||||
Request *http.Request
|
||||
Response struct {
|
||||
StatusCode int
|
||||
Header http.Header
|
||||
}
|
||||
}
|
||||
|
||||
var bufPool = synk.GetBytesPoolWithUniqueMemory()
|
||||
|
||||
type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr]
|
||||
type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr]
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (cmd *Command) Parse(v string) error {
|
||||
lines := strutils.SplitLine(v)
|
||||
if len(lines) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
executors := make([]CommandHandler, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
executors := make([]CommandHandler, 0)
|
||||
isResponseHandler := false
|
||||
for line := range strings.SplitSeq(v, "\n") {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
@@ -257,13 +499,21 @@ func (cmd *Command) Parse(v string) error {
|
||||
}
|
||||
validArgs, err := builder.validate(args)
|
||||
if err != nil {
|
||||
return err.Subject(directive).Withf("%s", builder.help.String())
|
||||
// Only attach help for the directive that failed, avoid bringing in unrelated KV errors
|
||||
return err.Subject(directive).With(builder.help.Error())
|
||||
}
|
||||
|
||||
executors = append(executors, builder.build(validArgs))
|
||||
handler := builder.build(validArgs)
|
||||
executors = append(executors, handler)
|
||||
if builder.isResponseHandler || handler.IsResponseHandler() {
|
||||
isResponseHandler = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(executors) == 0 {
|
||||
cmd.raw = v
|
||||
cmd.exec = nil
|
||||
cmd.isResponseHandler = false
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -274,10 +524,14 @@ func (cmd *Command) Parse(v string) error {
|
||||
|
||||
cmd.raw = v
|
||||
cmd.exec = exec
|
||||
if exec.IsResponseHandler() {
|
||||
isResponseHandler = true
|
||||
}
|
||||
cmd.isResponseHandler = isResponseHandler
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCmd(executors []CommandHandler) (CommandHandler, error) {
|
||||
func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) {
|
||||
for i, exec := range executors {
|
||||
switch exec.(type) {
|
||||
case TerminatingCommand, BypassCommand:
|
||||
@@ -308,6 +562,10 @@ func (cmd *Command) isBypass() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
return cmd.exec.Handle(w, r)
|
||||
}
|
||||
|
||||
func (cmd *Command) String() string {
|
||||
return cmd.raw
|
||||
}
|
||||
|
||||
400
internal/route/rules/do_log_test.go
Normal file
400
internal/route/rules/do_log_test.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
|
||||
// mockUpstreamWithHeaders creates an upstream that returns specific headers
|
||||
func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
maps.Copy(w.Header(), headers)
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
|
||||
func parseRules(data string, target *Rules) gperr.Error {
|
||||
_, err := serialization.ConvertString(data, reflect.ValueOf(target))
|
||||
return err
|
||||
}
|
||||
|
||||
func TestLogCommand_TemporaryFile(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(200, "success response", http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
})
|
||||
|
||||
// Create a temporary file for logging
|
||||
tempFile, err := os.CreateTemp("", "test-log-*.log")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-request-response
|
||||
do: |
|
||||
log info %q '{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }} {{ index (index .Response.Header "Content-Type") 0 }}'
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/users", nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "success response", w.Body.String())
|
||||
|
||||
// Read and verify log content
|
||||
content, err := os.ReadFile(tempFile.Name())
|
||||
require.NoError(t, err)
|
||||
logContent := string(content)
|
||||
|
||||
assert.Equal(t, "POST /api/users 200 application/json\n", logContent)
|
||||
}
|
||||
|
||||
func TestLogCommand_StdoutAndStderr(t *testing.T) {
|
||||
upstream := mockUpstream(200, "success")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: log-stdout
|
||||
do: |
|
||||
log info /dev/stdout "stdout: {{ .Request.Method }} {{ .Response.StatusCode }}"
|
||||
- name: log-stderr
|
||||
do: |
|
||||
log error /dev/stderr "stderr: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
// Note: We can't easily capture stdout/stderr in unit tests,
|
||||
// but we can verify no errors occurred and the handler completed
|
||||
}
|
||||
|
||||
func TestLogCommand_DifferentLogLevels(t *testing.T) {
|
||||
upstream := mockUpstream(404, "not found")
|
||||
|
||||
// Create temporary files for different log levels
|
||||
infoFile, err := os.CreateTemp("", "test-info-*.log")
|
||||
require.NoError(t, err)
|
||||
infoFile.Close()
|
||||
defer os.Remove(infoFile.Name())
|
||||
|
||||
warnFile, err := os.CreateTemp("", "test-warn-*.log")
|
||||
require.NoError(t, err)
|
||||
warnFile.Close()
|
||||
defer os.Remove(warnFile.Name())
|
||||
|
||||
errorFile, err := os.CreateTemp("", "test-error-*.log")
|
||||
require.NoError(t, err)
|
||||
errorFile.Close()
|
||||
defer os.Remove(errorFile.Name())
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-info
|
||||
do: |
|
||||
log info %s "INFO: {{ .Request.Method }} {{ .Response.StatusCode }}"
|
||||
- name: log-warn
|
||||
do: |
|
||||
log warn %s "WARN: {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
- name: log-error
|
||||
do: |
|
||||
log error %s "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/resource/123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 404, w.Code)
|
||||
|
||||
// Verify each log file
|
||||
infoContent, err := os.ReadFile(infoFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "INFO: DELETE 404", strings.TrimSpace(string(infoContent)))
|
||||
|
||||
warnContent, err := os.ReadFile(warnFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "WARN: /api/resource/123 404", strings.TrimSpace(string(warnContent)))
|
||||
|
||||
errorContent, err := os.ReadFile(errorFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ERROR: DELETE /api/resource/123 404", strings.TrimSpace(string(errorContent)))
|
||||
}
|
||||
|
||||
func TestLogCommand_TemplateVariables(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", "custom-value")
|
||||
w.Header().Set("Content-Length", "42")
|
||||
w.WriteHeader(201)
|
||||
w.Write([]byte("created"))
|
||||
})
|
||||
|
||||
// Create temporary file
|
||||
tempFile, err := os.CreateTemp("", "test-template-*.log")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-with-templates
|
||||
do: |
|
||||
log info %s 'Request: {{ .Request.Method }} {{ .Request.URL }} Host: {{ .Request.Host }} User-Agent: {{ index .Request.Header "User-Agent" 0 }} Response: {{ .Response.StatusCode }} Custom-Header: {{ index .Response.Header "X-Custom-Header" 0 }} Content-Length: {{ index .Response.Header "Content-Length" 0 }}'
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("PUT", "/api/resource", nil)
|
||||
req.Header.Set("User-Agent", "test-client/1.0")
|
||||
req.Host = "example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 201, w.Code)
|
||||
|
||||
// Verify log content
|
||||
content, err := os.ReadFile(tempFile.Name())
|
||||
require.NoError(t, err)
|
||||
logContent := strings.TrimSpace(string(content))
|
||||
|
||||
assert.Equal(t, "Request: PUT /api/resource Host: example.com User-Agent: test-client/1.0 Response: 201 Custom-Header: custom-value Content-Length: 42", logContent)
|
||||
}
|
||||
|
||||
func TestLogCommand_ConditionalLogging(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/error":
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte("internal server error"))
|
||||
case "/notfound":
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("not found"))
|
||||
default:
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("success"))
|
||||
}
|
||||
})
|
||||
|
||||
// Create temporary files
|
||||
successFile, err := os.CreateTemp("", "test-success-*.log")
|
||||
require.NoError(t, err)
|
||||
successFile.Close()
|
||||
defer os.Remove(successFile.Name())
|
||||
|
||||
errorFile, err := os.CreateTemp("", "test-error-*.log")
|
||||
require.NoError(t, err)
|
||||
errorFile.Close()
|
||||
defer os.Remove(errorFile.Name())
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-success
|
||||
on: status 2xx
|
||||
do: |
|
||||
log info %q "SUCCESS: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
- name: log-error
|
||||
on: status 4xx | status 5xx
|
||||
do: |
|
||||
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"
|
||||
`, successFile.Name(), errorFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test success request
|
||||
req1 := httptest.NewRequest("GET", "/success", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
|
||||
// Test not found request
|
||||
req2 := httptest.NewRequest("GET", "/notfound", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
|
||||
// Test server error request
|
||||
req3 := httptest.NewRequest("POST", "/error", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
assert.Equal(t, 500, w3.Code)
|
||||
|
||||
// Verify success log
|
||||
successContent, err := os.ReadFile(successFile.Name())
|
||||
require.NoError(t, err)
|
||||
successLines := strings.Split(strings.TrimSpace(string(successContent)), "\n")
|
||||
assert.Len(t, successLines, 1)
|
||||
assert.Equal(t, "SUCCESS: GET /success 200", successLines[0])
|
||||
|
||||
// Verify error log
|
||||
errorContent, err := os.ReadFile(errorFile.Name())
|
||||
require.NoError(t, err)
|
||||
errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n")
|
||||
assert.Len(t, errorLines, 2)
|
||||
assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0])
|
||||
assert.Equal(t, "ERROR: POST /error 500", errorLines[1])
|
||||
}
|
||||
|
||||
func TestLogCommand_MultipleLogEntries(t *testing.T) {
|
||||
upstream := mockUpstream(200, "response")
|
||||
|
||||
// Create temporary file
|
||||
tempFile, err := os.CreateTemp("", "test-multiple-*.log")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-multiple
|
||||
do: |
|
||||
log info %q "{{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Make multiple requests
|
||||
requests := []struct {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"GET", "/users"},
|
||||
{"POST", "/users"},
|
||||
{"PUT", "/users/1"},
|
||||
{"DELETE", "/users/1"},
|
||||
}
|
||||
|
||||
for _, reqInfo := range requests {
|
||||
req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
// Verify all requests were logged
|
||||
content, err := os.ReadFile(tempFile.Name())
|
||||
require.NoError(t, err)
|
||||
logContent := strings.TrimSpace(string(content))
|
||||
lines := strings.Split(logContent, "\n")
|
||||
|
||||
assert.Len(t, lines, len(requests))
|
||||
|
||||
for i, reqInfo := range requests {
|
||||
expectedLog := reqInfo.method + " " + reqInfo.path + " 200"
|
||||
assert.Equal(t, expectedLog, lines[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogCommand_FilePermissions(t *testing.T) {
|
||||
upstream := mockUpstream(200, "success")
|
||||
|
||||
// Create a temporary directory
|
||||
tempDir, err := os.MkdirTemp("", "test-log-dir")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create a log file path within the temp directory
|
||||
logFilePath := filepath.Join(tempDir, "test.log")
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- on: status 2xx
|
||||
do: log info %q "{{ .Request.Method }} {{ .Response.StatusCode }}"`, logFilePath), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
|
||||
// Verify file was created and is writable
|
||||
_, err = os.Stat(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test writing to the file again to ensure it's not closed
|
||||
req2 := httptest.NewRequest("POST", "/test2", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
|
||||
// Verify both entries are in the file
|
||||
content, err := os.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
logContent := strings.TrimSpace(string(content))
|
||||
lines := strings.Split(logContent, "\n")
|
||||
|
||||
assert.Len(t, lines, 2)
|
||||
assert.Equal(t, "GET 200", lines[0])
|
||||
assert.Equal(t, "POST 200", lines[1])
|
||||
}
|
||||
|
||||
func TestLogCommand_InvalidTemplate(t *testing.T) {
|
||||
upstream := mockUpstream(200, "success")
|
||||
|
||||
var rules Rules
|
||||
|
||||
// Test with invalid template syntax
|
||||
err := parseRules(`
|
||||
- name: log-invalid
|
||||
do: |
|
||||
log info /dev/stdout "{{ .Invalid.Field }}"`, &rules)
|
||||
// Should not error during parsing, but template execution will fail gracefully
|
||||
assert.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
handler.ServeHTTP(w, req)
|
||||
})
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
328
internal/route/rules/do_set.go
Normal file
328
internal/route/rules/do_set.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
|
||||
type (
|
||||
FieldHandler struct {
|
||||
set, add, remove CommandHandler
|
||||
}
|
||||
FieldModifier string
|
||||
)
|
||||
|
||||
const (
|
||||
ModFieldSet FieldModifier = "set"
|
||||
ModFieldAdd FieldModifier = "add"
|
||||
ModFieldRemove FieldModifier = "remove"
|
||||
)
|
||||
|
||||
const (
|
||||
FieldHeader = "header"
|
||||
FieldResponseHeader = "resp_header"
|
||||
FieldQuery = "query"
|
||||
FieldCookie = "cookie"
|
||||
FieldBody = "body"
|
||||
FieldResponseBody = "resp_body"
|
||||
FieldStatusCode = "status"
|
||||
)
|
||||
|
||||
var AllFields = []string{FieldHeader, FieldResponseHeader, FieldQuery, FieldCookie, FieldBody, FieldResponseBody, FieldStatusCode}
|
||||
|
||||
// NOTE: should not use canonicalized header keys, respect to user's input
|
||||
var modFields = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) *FieldHandler
|
||||
}{
|
||||
FieldHeader: {
|
||||
help: Help{
|
||||
command: FieldHeader,
|
||||
args: map[string]string{
|
||||
"key": "the header key",
|
||||
"value": "the header template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Header[k] = []string{v}
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Header[k] = append(r.Header[k], v)
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
delete(r.Header, k)
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldResponseHeader: {
|
||||
help: Help{
|
||||
command: FieldResponseHeader,
|
||||
args: map[string]string{
|
||||
"key": "the response header key",
|
||||
"value": "the response header template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header()[k] = []string{v}
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header()[k] = append(w.Header()[k], v)
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
delete(w.Header(), k)
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldQuery: {
|
||||
help: Help{
|
||||
command: FieldQuery,
|
||||
args: map[string]string{
|
||||
"key": "the query key",
|
||||
"value": "the query template",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Set(k, v)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Add(k, v)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
GetSharedData(w).UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Del(k)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldCookie: {
|
||||
help: Help{
|
||||
command: FieldCookie,
|
||||
args: map[string]string{
|
||||
"key": "the cookie key",
|
||||
"value": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toKeyValueTemplate,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, tmpl := args.(*keyValueTemplate).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
cookies[i].Value = v
|
||||
return cookies
|
||||
}
|
||||
}
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
v, err := executeRequestTemplateString(tmpl, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
index := -1
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if index != -1 {
|
||||
if len(cookies) == 1 {
|
||||
return []*http.Cookie{}
|
||||
}
|
||||
return append(cookies[:index], cookies[index+1:]...)
|
||||
}
|
||||
return cookies
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldBody: {
|
||||
help: Help{
|
||||
command: FieldBody,
|
||||
description: makeLines(
|
||||
"Override the request body that will be sent to the upstream",
|
||||
"The template supports the following variables:",
|
||||
helpListItem("Request", "the request object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(FieldBody, "HTTP STATUS: {{ .Request.Method }} {{ .Request.URL.Path }}"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"template": "the body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
return validateTemplate(args[0], true)
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateOrStr)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Body != nil {
|
||||
r.Body.Close()
|
||||
r.Body = nil
|
||||
}
|
||||
|
||||
buf := pool.Get()
|
||||
b := bytes.NewBuffer(buf)
|
||||
|
||||
err := executeRequestTemplateTo(tmpl, b, r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() {
|
||||
pool.Put(buf)
|
||||
})
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldResponseBody: {
|
||||
help: Help{
|
||||
command: FieldResponseBody,
|
||||
description: makeLines(
|
||||
"Override the response body that will be sent to the client",
|
||||
"The template supports the following variables:",
|
||||
helpListItem("Request", "the request object"),
|
||||
helpListItem("Response", "the response object"),
|
||||
"",
|
||||
"Example:",
|
||||
helpExample(FieldResponseBody, "HTTP STATUS: {{ .Request.Method }} {{ .Response.StatusCode }}"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"template": "the response body template",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
return validateTemplate(args[0], true)
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
tmpl := args.(templateOrStr)
|
||||
return &FieldHandler{
|
||||
set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
rm := GetInitResponseModifier(w)
|
||||
rm.ResetBody()
|
||||
return executeReqRespTemplateTo(tmpl, rm, rm, r)
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldStatusCode: {
|
||||
help: Help{
|
||||
command: FieldStatusCode,
|
||||
description: makeLines(
|
||||
"Override the status code that will be sent to the client, e.g.:",
|
||||
helpExample(FieldStatusCode, "200"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"code": "the status code",
|
||||
},
|
||||
},
|
||||
validate: func(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
status, err := strconv.Atoi(args[0])
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if status < 100 || status > 599 {
|
||||
return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status)
|
||||
}
|
||||
return status, nil
|
||||
},
|
||||
builder: func(args any) *FieldHandler {
|
||||
status := args.(int)
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error {
|
||||
GetInitResponseModifier(w).WriteHeader(status)
|
||||
return nil
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
643
internal/route/rules/do_set_test.go
Normal file
643
internal/route/rules/do_set_test.go
Normal file
@@ -0,0 +1,643 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFieldHandler_Header(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
modifier FieldModifier
|
||||
setup func(*http.Request)
|
||||
verify func(*http.Request, *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "set header",
|
||||
key: "X-Test",
|
||||
value: "test-value",
|
||||
modifier: ModFieldSet,
|
||||
setup: func(r *http.Request) {
|
||||
r.Header.Set("X-Test", "old-value")
|
||||
},
|
||||
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
|
||||
got := r.Header.Get("X-Test")
|
||||
assert.Equal(t, "test-value", got, "Expected header X-Test to be 'test-value'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add header",
|
||||
key: "X-Test",
|
||||
value: "new-value",
|
||||
modifier: ModFieldAdd,
|
||||
setup: func(r *http.Request) {
|
||||
r.Header.Set("X-Test", "existing-value")
|
||||
},
|
||||
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
|
||||
values := r.Header["X-Test"]
|
||||
require.Len(t, values, 2, "Expected 2 header values")
|
||||
assert.Equal(t, "existing-value", values[0], "Expected first value of X-Test header to be 'existing-value'")
|
||||
assert.Equal(t, "new-value", values[1], "Expected second value of X-Test header to be 'new-value'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove header",
|
||||
key: "X-Test",
|
||||
value: "",
|
||||
modifier: ModFieldRemove,
|
||||
setup: func(r *http.Request) {
|
||||
r.Header.Set("X-Test", "to-be-removed")
|
||||
},
|
||||
verify: func(r *http.Request, w *httptest.ResponseRecorder) {
|
||||
got := r.Header.Get("X-Test")
|
||||
assert.Empty(t, got, "Expected header X-Test to be removed")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
case ModFieldAdd:
|
||||
cmd = handler.add
|
||||
case ModFieldRemove:
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(req, w)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldHandler_ResponseHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
modifier FieldModifier
|
||||
setup func(*httptest.ResponseRecorder)
|
||||
verify func(*httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "set response header",
|
||||
key: "X-Response-Test",
|
||||
value: "response-value",
|
||||
modifier: ModFieldSet,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
got := w.Header().Get("X-Response-Test")
|
||||
assert.Equal(t, "response-value", got, "Expected response header X-Response-Test to be 'response-value'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add response header",
|
||||
key: "X-Response-Test",
|
||||
value: "additional-value",
|
||||
modifier: ModFieldAdd,
|
||||
setup: func(w *httptest.ResponseRecorder) {
|
||||
w.Header().Set("X-Response-Test", "existing-value")
|
||||
},
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
values := w.Header()["X-Response-Test"]
|
||||
require.Len(t, values, 2)
|
||||
assert.Equal(t, values[0], "existing-value")
|
||||
assert.Equal(t, values[1], "additional-value")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove response header",
|
||||
key: "X-Response-Test",
|
||||
value: "",
|
||||
modifier: ModFieldRemove,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Empty(t, w.Header().Get("X-Response-Test"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
if tt.setup != nil {
|
||||
tt.setup(w)
|
||||
}
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
case ModFieldAdd:
|
||||
cmd = handler.add
|
||||
case ModFieldRemove:
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(w)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldHandler_Query(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
modifier FieldModifier
|
||||
setup func(*http.Request)
|
||||
verify func(*http.Request)
|
||||
}{
|
||||
{
|
||||
name: "set query",
|
||||
key: "test",
|
||||
value: "new-value",
|
||||
modifier: ModFieldSet,
|
||||
setup: func(r *http.Request) {
|
||||
r.URL.RawQuery = "test=old-value&other=keep"
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
got := r.URL.Query().Get("test")
|
||||
assert.Equal(t, "new-value", got, "Expected query 'test' to be 'new-value'")
|
||||
gotOther := r.URL.Query().Get("other")
|
||||
assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add query",
|
||||
key: "test",
|
||||
value: "additional-value",
|
||||
modifier: ModFieldAdd,
|
||||
setup: func(r *http.Request) {
|
||||
r.URL.RawQuery = "test=existing-value"
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
values := r.URL.Query()["test"]
|
||||
require.Len(t, values, 2, "Expected 2 query values")
|
||||
assert.Equal(t, "existing-value", values[0], "Expected first value of test query param to be 'existing-value'")
|
||||
assert.Equal(t, "additional-value", values[1], "Expected second value of test query param to be 'additional-value'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove query",
|
||||
key: "test",
|
||||
value: "",
|
||||
modifier: ModFieldRemove,
|
||||
setup: func(r *http.Request) {
|
||||
r.URL.RawQuery = "test=to-be-removed&other=keep"
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
got := r.URL.Query().Get("test")
|
||||
assert.Empty(t, got, "Expected query 'test' to be removed")
|
||||
gotOther := r.URL.Query().Get("other")
|
||||
assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
case ModFieldAdd:
|
||||
cmd = handler.add
|
||||
case ModFieldRemove:
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldHandler_Cookie(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
modifier FieldModifier
|
||||
setup func(*http.Request)
|
||||
verify func(*http.Request)
|
||||
}{
|
||||
{
|
||||
name: "set cookie",
|
||||
key: "test",
|
||||
value: "new-value",
|
||||
modifier: ModFieldSet,
|
||||
setup: func(r *http.Request) {
|
||||
r.AddCookie(&http.Cookie{Name: "test", Value: "old-value"})
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
cookie, err := r.Cookie("test")
|
||||
assert.NoError(t, err, "Expected cookie 'test' to exist")
|
||||
if err == nil {
|
||||
assert.Equal(t, "new-value", cookie.Value, "Expected cookie 'test' to be 'new-value'")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add cookie",
|
||||
key: "test",
|
||||
value: "additional-value",
|
||||
modifier: ModFieldAdd,
|
||||
setup: func(r *http.Request) {
|
||||
r.AddCookie(&http.Cookie{Name: "test", Value: "existing-value"})
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
cookies := r.Cookies()
|
||||
testCookies := make([]string, 0)
|
||||
for _, c := range cookies {
|
||||
if c.Name == "test" {
|
||||
testCookies = append(testCookies, c.Value)
|
||||
}
|
||||
}
|
||||
require.Len(t, testCookies, 2, "Expected 2 cookies with name 'test'")
|
||||
assert.Equal(t, "existing-value", testCookies[0], "Expected first value of 'test' cookie to be 'existing-value'")
|
||||
assert.Equal(t, "additional-value", testCookies[1], "Expected second value of 'test' cookie to be 'additional-value'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove cookie",
|
||||
key: "test",
|
||||
value: "",
|
||||
modifier: ModFieldRemove,
|
||||
setup: func(r *http.Request) {
|
||||
r.AddCookie(&http.Cookie{Name: "test", Value: "to-be-removed"})
|
||||
r.AddCookie(&http.Cookie{Name: "other", Value: "keep"})
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
_, err := r.Cookie("test")
|
||||
assert.Error(t, err, "Expected cookie 'test' to be removed")
|
||||
cookie, err := r.Cookie("other")
|
||||
assert.NoError(t, err, "Expected cookie 'other' to exist")
|
||||
if err == nil {
|
||||
assert.Equal(t, "keep", cookie.Value, "Expected cookie 'other' to be 'keep'")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.value, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to validate template: %v", tErr)
|
||||
}
|
||||
handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl})
|
||||
var cmd CommandHandler
|
||||
switch tt.modifier {
|
||||
case ModFieldSet:
|
||||
cmd = handler.set
|
||||
case ModFieldAdd:
|
||||
cmd = handler.add
|
||||
case ModFieldRemove:
|
||||
cmd = handler.remove
|
||||
}
|
||||
|
||||
err := cmd.Handle(w, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldHandler_Body(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
setup func(*http.Request)
|
||||
verify func(*http.Request)
|
||||
}{
|
||||
{
|
||||
name: "set body with template",
|
||||
template: "Hello {{ .Request.Method }} {{ .Request.URL.Path }}",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "POST"
|
||||
r.URL.Path = "/test"
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.NoError(t, err, "Failed to read body")
|
||||
expected := "Hello POST /test"
|
||||
assert.Equal(t, expected, string(body), "Expected body content")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set body with existing body",
|
||||
template: "Overridden",
|
||||
setup: func(r *http.Request) {
|
||||
r.Body = io.NopCloser(strings.NewReader("original body"))
|
||||
},
|
||||
verify: func(r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
assert.NoError(t, err, "Failed to read body")
|
||||
assert.Equal(t, "Overridden", string(body), "Expected body to be 'Overridden'")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.template, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to parse template: %v", tErr)
|
||||
}
|
||||
|
||||
handler := modFields[FieldBody].builder(tmpl)
|
||||
err := handler.set.Handle(w, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldHandler_ResponseBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
setup func(*http.Request)
|
||||
verify func(*ResponseModifier)
|
||||
}{
|
||||
{
|
||||
name: "set response body with template",
|
||||
template: "Response: {{ .Request.Method }} {{ .Request.URL.Path }}",
|
||||
setup: func(r *http.Request) {
|
||||
r.Method = "GET"
|
||||
r.URL.Path = "/api/test"
|
||||
},
|
||||
verify: func(rm *ResponseModifier) {
|
||||
content := rm.buf.String()
|
||||
expected := "Response: GET /api/test"
|
||||
assert.Equal(t, expected, content, "Expected response body")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
tt.setup(req)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Create ResponseModifier wrapper
|
||||
rm := NewResponseModifier(w)
|
||||
|
||||
tmpl, tErr := validateTemplate(tt.template, false)
|
||||
if tErr != nil {
|
||||
t.Fatalf("Failed to parse template: %v", tErr)
|
||||
}
|
||||
|
||||
handler := modFields[FieldResponseBody].builder(tmpl)
|
||||
err := handler.set.Handle(rm, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
|
||||
tt.verify(rm)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldHandler_StatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
verify func(*httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "set status code 200",
|
||||
status: 200,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 200, w.Code, "Expected status code 200")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set status code 404",
|
||||
status: 404,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 404, w.Code, "Expected status code 404")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set status code 500",
|
||||
status: 500,
|
||||
verify: func(w *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, 500, w.Code, "Expected status code 500")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
rm := NewResponseModifier(w)
|
||||
var cmd Command
|
||||
err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status))
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
err = cmd.ServeHTTP(rm, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Handler returned error: %v", err)
|
||||
}
|
||||
rm.FlushRelease()
|
||||
|
||||
tt.verify(w)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
args []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "header valid",
|
||||
field: FieldHeader,
|
||||
args: []string{"key", "value"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "header invalid - missing value",
|
||||
field: FieldHeader,
|
||||
args: []string{"key"},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "response header valid",
|
||||
field: FieldResponseHeader,
|
||||
args: []string{"key", "value"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "query valid",
|
||||
field: FieldQuery,
|
||||
args: []string{"key", "value"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "cookie valid",
|
||||
field: FieldCookie,
|
||||
args: []string{"key", "value"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "body valid template",
|
||||
field: FieldBody,
|
||||
args: []string{"Hello {{ .Request.Method }}"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "body invalid template syntax",
|
||||
field: FieldBody,
|
||||
args: []string{"Hello {{ .InvalidField "},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "response body valid template",
|
||||
field: FieldResponseBody,
|
||||
args: []string{"Response: {{ .Request.Method }}"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "status code valid",
|
||||
field: FieldStatusCode,
|
||||
args: []string{"200"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "status code invalid - too low",
|
||||
field: FieldStatusCode,
|
||||
args: []string{"99"},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "status code invalid - too high",
|
||||
field: FieldStatusCode,
|
||||
args: []string{"600"},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "status code invalid - not a number",
|
||||
field: FieldStatusCode,
|
||||
args: []string{"not-a-number"},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
field, exists := modFields[tt.field]
|
||||
assert.True(t, exists, "Field %s does not exist", tt.field)
|
||||
|
||||
_, err := field.validate(tt.args)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err, "Expected error but got none")
|
||||
} else {
|
||||
assert.NoError(t, err, "Expected no error but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllFields(t *testing.T) {
|
||||
expectedFields := []string{
|
||||
FieldHeader,
|
||||
FieldResponseHeader,
|
||||
FieldQuery,
|
||||
FieldCookie,
|
||||
FieldBody,
|
||||
FieldResponseBody,
|
||||
FieldStatusCode,
|
||||
}
|
||||
|
||||
require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields))
|
||||
|
||||
for _, expected := range expectedFields {
|
||||
found := false
|
||||
for _, actual := range AllFields {
|
||||
if actual == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected field %s not found in AllFields", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModFields(t *testing.T) {
|
||||
for fieldName, field := range modFields {
|
||||
// Test that each field has required components
|
||||
assert.NotNil(t, field.validate, "Field %s has nil validate function", fieldName)
|
||||
assert.NotNil(t, field.builder, "Field %s has nil builder function", fieldName)
|
||||
assert.NotEmpty(t, field.help.command, "Field %s has empty help command", fieldName)
|
||||
}
|
||||
}
|
||||
@@ -99,10 +99,15 @@ func TestParseCommands(t *testing.T) {
|
||||
},
|
||||
// proxy directive tests
|
||||
{
|
||||
name: "proxy_valid",
|
||||
name: "proxy_valid_abs",
|
||||
input: "proxy http://localhost:8080",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "proxy_valid_rel",
|
||||
input: "proxy /foo/bar",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "proxy_missing_target",
|
||||
input: "proxy",
|
||||
|
||||
23
internal/route/rules/error_format_test.go
Normal file
23
internal/route/rules/error_format_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
func TestErrorFormat(t *testing.T) {
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: error 405
|
||||
do: error 405 error
|
||||
- on: header too many args
|
||||
do: error 405 error
|
||||
- name: missing do
|
||||
on: status 200
|
||||
- on: header X-Header
|
||||
do: set invalid_command
|
||||
- do: set resp_body "{{ .Request.Method {{ .Request.URL.Path }}"
|
||||
`, &rules)
|
||||
gperr.LogError("error", err)
|
||||
}
|
||||
@@ -7,15 +7,21 @@ import (
|
||||
var (
|
||||
ErrUnterminatedQuotes = gperr.New("unterminated quotes")
|
||||
ErrUnterminatedBrackets = gperr.New("unterminated brackets")
|
||||
ErrUnterminatedEnvVar = gperr.New("unterminated env var")
|
||||
ErrUnknownDirective = gperr.New("unknown directive")
|
||||
ErrUnknownModField = gperr.New("unknown field")
|
||||
ErrEnvVarNotFound = gperr.New("env variable not found")
|
||||
ErrInvalidArguments = gperr.New("invalid arguments")
|
||||
ErrInvalidOnTarget = gperr.New("invalid `rule.on` target")
|
||||
ErrInvalidCommandSequence = gperr.New("invalid command sequence")
|
||||
ErrInvalidSetTarget = gperr.New("invalid `rule.set` target")
|
||||
|
||||
ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg")
|
||||
ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg")
|
||||
ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args")
|
||||
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
|
||||
ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg")
|
||||
ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg")
|
||||
ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args")
|
||||
ErrExpectTwoOrThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 or 3 args")
|
||||
ErrExpectThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 3 args")
|
||||
ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args")
|
||||
ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'")
|
||||
|
||||
errTerminated = gperr.New("terminated")
|
||||
)
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type (
|
||||
FieldHandler struct {
|
||||
set, add, remove CommandHandler
|
||||
}
|
||||
FieldModifier string
|
||||
)
|
||||
|
||||
const (
|
||||
ModFieldSet FieldModifier = "set"
|
||||
ModFieldAdd FieldModifier = "add"
|
||||
ModFieldRemove FieldModifier = "remove"
|
||||
)
|
||||
|
||||
const (
|
||||
FieldHeader = "header"
|
||||
FieldQuery = "query"
|
||||
FieldCookie = "cookie"
|
||||
)
|
||||
|
||||
var modFields = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) *FieldHandler
|
||||
}{
|
||||
FieldHeader: {
|
||||
help: Help{
|
||||
command: FieldHeader,
|
||||
args: map[string]string{
|
||||
"key": "the header key",
|
||||
"value": "the header value",
|
||||
},
|
||||
},
|
||||
validate: toStrTuple,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
return &FieldHandler{
|
||||
set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header()[k] = []string{v}
|
||||
}),
|
||||
add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h[k] = append(h[k], v)
|
||||
}),
|
||||
remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) {
|
||||
delete(w.Header(), k)
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldQuery: {
|
||||
help: Help{
|
||||
command: FieldQuery,
|
||||
args: map[string]string{
|
||||
"key": "the query key",
|
||||
"value": "the query value",
|
||||
},
|
||||
},
|
||||
validate: toStrTuple,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
return &FieldHandler{
|
||||
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||
cached.UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Set(k, v)
|
||||
})
|
||||
return true
|
||||
}),
|
||||
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||
cached.UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Add(k, v)
|
||||
})
|
||||
return true
|
||||
}),
|
||||
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||
cached.UpdateQueries(r, func(queries url.Values) {
|
||||
queries.Del(k)
|
||||
})
|
||||
return true
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
FieldCookie: {
|
||||
help: Help{
|
||||
command: FieldCookie,
|
||||
args: map[string]string{
|
||||
"key": "the cookie key",
|
||||
"value": "the cookie value",
|
||||
},
|
||||
},
|
||||
validate: toStrTuple,
|
||||
builder: func(args any) *FieldHandler {
|
||||
k, v := args.(*StrTuple).Unpack()
|
||||
return &FieldHandler{
|
||||
set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
cookies[i].Value = v
|
||||
return cookies
|
||||
}
|
||||
}
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return true
|
||||
}),
|
||||
add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
return append(cookies, &http.Cookie{Name: k, Value: v})
|
||||
})
|
||||
return true
|
||||
}),
|
||||
remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool {
|
||||
cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie {
|
||||
index := -1
|
||||
for i, c := range cookies {
|
||||
if c.Name == k {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if index != -1 {
|
||||
if len(cookies) == 1 {
|
||||
return []*http.Cookie{}
|
||||
}
|
||||
return append(cookies[:index], cookies[index+1:]...)
|
||||
}
|
||||
return cookies
|
||||
})
|
||||
return true
|
||||
}),
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1,40 +1,134 @@
|
||||
package rules
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/strings/ansi"
|
||||
)
|
||||
|
||||
type Help struct {
|
||||
command string
|
||||
description string
|
||||
description []string
|
||||
args map[string]string // args[arg] -> description
|
||||
}
|
||||
|
||||
/*
|
||||
Generate help string, e.g.
|
||||
func makeLines(lines ...string) []string {
|
||||
return lines
|
||||
}
|
||||
|
||||
rewrite <from> <to>
|
||||
from: the path to rewrite, must start with /
|
||||
to: the path to rewrite to, must start with /
|
||||
*/
|
||||
func (h *Help) String() string {
|
||||
func helpExample(cmd string, args ...string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(h.command)
|
||||
sb.WriteString(" ")
|
||||
for arg := range h.args {
|
||||
sb.WriteString(strings.ToUpper(arg))
|
||||
sb.WriteRune(' ')
|
||||
}
|
||||
if h.description != "" {
|
||||
sb.WriteString("\n\t")
|
||||
sb.WriteString(h.description)
|
||||
sb.WriteRune('\n')
|
||||
}
|
||||
sb.WriteRune('\n')
|
||||
for arg, desc := range h.args {
|
||||
sb.WriteRune('\t')
|
||||
sb.WriteString(strings.ToUpper(arg))
|
||||
sb.WriteString(": ")
|
||||
sb.WriteString(desc)
|
||||
sb.WriteRune('\n')
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(ansi.WithANSI(cmd, ansi.HighlightGreen))
|
||||
for _, arg := range args {
|
||||
var out strings.Builder
|
||||
pos := 0
|
||||
for {
|
||||
start := strings.Index(arg[pos:], "{{")
|
||||
if start == -1 {
|
||||
if pos < len(arg) {
|
||||
// If no template at all (pos == 0), cyan highlight for whole-arg
|
||||
// Otherwise, for mixed strings containing templates, leave non-template text unhighlighted
|
||||
if pos == 0 {
|
||||
out.WriteString(ansi.WithANSI(arg[pos:], ansi.HighlightCyan))
|
||||
} else {
|
||||
out.WriteString(arg[pos:])
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
start += pos
|
||||
if start > pos {
|
||||
// Non-template text should not be highlighted
|
||||
out.WriteString(arg[pos:start])
|
||||
}
|
||||
end := strings.Index(arg[start+2:], "}}")
|
||||
if end == -1 {
|
||||
// Unmatched template start; write remainder without highlighting
|
||||
out.WriteString(arg[start:])
|
||||
break
|
||||
}
|
||||
end += start + 2
|
||||
inner := strings.TrimSpace(arg[start+2 : end])
|
||||
parts := strings.Split(inner, ".")
|
||||
out.WriteString(helpTemplateVar(parts...))
|
||||
pos = end + 2
|
||||
}
|
||||
fmt.Fprintf(&sb, ` "%s"`, out.String())
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func helpListItem(key string, value string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(ansi.WithANSI(key, ansi.HighlightYellow))
|
||||
sb.WriteString(": ")
|
||||
sb.WriteString(value)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// helpFuncCall generates a string like "fn(arg1, arg2, arg3)"
|
||||
func helpFuncCall(fn string, args ...string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(ansi.WithANSI(fn, ansi.HighlightRed))
|
||||
sb.WriteString("(")
|
||||
for i, arg := range args {
|
||||
fmt.Fprintf(&sb, `"%s"`, ansi.WithANSI(arg, ansi.HighlightCyan))
|
||||
if i < len(args)-1 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
}
|
||||
sb.WriteString(")")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// helpTemplateVar generates a string like "{{ .Request.Method }} {{ .Request.URL.Path }}"
|
||||
func helpTemplateVar(parts ...string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(ansi.WithANSI("{{ ", ansi.HighlightWhite))
|
||||
for i, part := range parts {
|
||||
sb.WriteString(ansi.WithANSI(part, ansi.HighlightCyan))
|
||||
if i < len(parts)-1 {
|
||||
sb.WriteString(".")
|
||||
}
|
||||
}
|
||||
sb.WriteString(ansi.WithANSI(" }}", ansi.HighlightWhite))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
/*
|
||||
Generate help string as error, e.g.
|
||||
|
||||
rewrite <from> <to>
|
||||
from: the path to rewrite, must start with /
|
||||
to: the path to rewrite to, must start with /
|
||||
*/
|
||||
func (h *Help) Error() gperr.Error {
|
||||
var lines gperr.MultilineError
|
||||
|
||||
lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen))
|
||||
lines.AddStrings(h.description...)
|
||||
lines.Adds(" args:")
|
||||
|
||||
argKeys := make([]string, 0, len(h.args))
|
||||
longestArg := 0
|
||||
for arg := range h.args {
|
||||
if len(arg) > longestArg {
|
||||
longestArg = len(arg)
|
||||
}
|
||||
argKeys = append(argKeys, arg)
|
||||
}
|
||||
|
||||
// sort argKeys alphabetically to make output stable
|
||||
slices.Sort(argKeys)
|
||||
for _, arg := range argKeys {
|
||||
desc := h.args[arg]
|
||||
lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc)
|
||||
}
|
||||
return &lines
|
||||
}
|
||||
|
||||
940
internal/route/rules/http_flow_test.go
Normal file
940
internal/route/rules/http_flow_test.go
Normal file
@@ -0,0 +1,940 @@
|
||||
package rules_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/internal/route"
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
. "github.com/yusing/godoxy/internal/route/rules"
|
||||
)
|
||||
|
||||
// mockUpstream creates a simple upstream handler for testing
|
||||
func mockUpstream(status int, body string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
|
||||
// mockUpstreamWithHeaders creates an upstream that returns specific headers
|
||||
func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
maps.Copy(w.Header(), headers)
|
||||
w.WriteHeader(status)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
}
|
||||
|
||||
func mockRoute(alias string) *route.FileServer {
|
||||
return &route.FileServer{Route: &route.Route{Alias: alias}}
|
||||
}
|
||||
|
||||
func parseRules(data string, target *Rules) gperr.Error {
|
||||
_, err := serialization.ConvertString(strings.TrimSpace(data), reflect.ValueOf(target))
|
||||
return err
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicPreRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header"))
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("upstream response"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: add-header
|
||||
on: path /
|
||||
do: set header X-Custom-Header test-value
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BypassRule(t *testing.T) {
|
||||
upstream := mockUpstream(200, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: bypass-condition
|
||||
on: path /bypass
|
||||
do: bypass
|
||||
- name: should-not-execute
|
||||
on: path /bypass
|
||||
do: error 500 "should not reach here"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/bypass", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_TerminatingCommand(t *testing.T) {
|
||||
upstream := mockUpstream(200, "should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: error-response
|
||||
on: path /error
|
||||
do: error 403 Forbidden
|
||||
- name: should-not-execute
|
||||
on: path /error
|
||||
do: set header X-Header ignored
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/error", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 403, w.Code)
|
||||
assert.Equal(t, "Forbidden\n", w.Body.String())
|
||||
assert.Empty(t, w.Header().Get("X-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RedirectFlow(t *testing.T) {
|
||||
upstream := mockUpstream(200, "should not be called")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: redirect-rule
|
||||
on: path /old-path
|
||||
do: redirect /new-path
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/old-path", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 307, w.Code) // TemporaryRedirect
|
||||
assert.Equal(t, "/new-path", w.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RewriteFlow(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("path: " + r.URL.Path))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: rewrite-rule
|
||||
on: path glob(/api/*)
|
||||
do: rewrite /api/ /v1/
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/users", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "path: /v1/users", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_MultiplePreRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id")))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: add-request-id
|
||||
on: path /
|
||||
do: set header X-Request-Id req-123
|
||||
- name: add-auth-header
|
||||
on: path /
|
||||
do: set header X-Auth-Token token-456
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "upstream: req-123", w.Body.String())
|
||||
assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_PostResponseRule(t *testing.T) {
|
||||
upstream := mockUpstreamWithHeaders(200, "success", http.Header{
|
||||
"X-Upstream": []string{"upstream-value"},
|
||||
})
|
||||
|
||||
tempFile, err := os.CreateTemp("", "test-log-*.txt")
|
||||
// Create a temporary file for logging
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-response
|
||||
on: path /test
|
||||
do: log info %s "{{ .Request.Method }} {{ .Response.StatusCode }}"
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "success", w.Body.String())
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream"))
|
||||
|
||||
// Check log file
|
||||
content, err := os.ReadFile(tempFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "GET 200\n", string(content))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/success" {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("success"))
|
||||
} else {
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("not found"))
|
||||
}
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
|
||||
// Create a temporary file for logging
|
||||
tempFile, err := os.CreateTemp("", "test-error-log-*.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Close()
|
||||
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: log-errors
|
||||
on: status 4xx
|
||||
do: log error %s "{{ .Request.URL }} returned {{ .Response.StatusCode }}"
|
||||
`, tempFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test successful request (should not log)
|
||||
req1 := httptest.NewRequest("GET", "/success", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
|
||||
// Test error request (should log)
|
||||
req2 := httptest.NewRequest("GET", "/notfound", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
|
||||
// Check log file
|
||||
content, err := os.ReadFile(tempFile.Name())
|
||||
require.NoError(t, err)
|
||||
lines := strings.Split(strings.TrimSpace(string(content)), "\n")
|
||||
require.Len(t, lines, 1, "only 4xx requests should be logged")
|
||||
assert.Equal(t, "/notfound returned 404", lines[0])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ConditionalRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("hello " + r.Header.Get("X-Username")))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: auth-required
|
||||
on: header Authorization
|
||||
do: |
|
||||
set header X-Username authenticated-user
|
||||
set resp_header X-Username authenticated-user
|
||||
- name: default
|
||||
do: |
|
||||
set header X-Username anonymous
|
||||
set resp_header X-Username anonymous
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test with Authorization header
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1.Header.Set("Authorization", "Bearer token")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "hello authenticated-user", w1.Body.String())
|
||||
assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username"))
|
||||
|
||||
// Test without Authorization header
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, "hello anonymous", w2.Body.String())
|
||||
assert.Equal(t, "anonymous", w2.Header().Get("X-Username"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate different responses based on path
|
||||
if r.URL.Path == "/protected" {
|
||||
if r.Header.Get("X-Auth") != "valid" {
|
||||
w.WriteHeader(401)
|
||||
w.Write([]byte("unauthorized"))
|
||||
return
|
||||
}
|
||||
}
|
||||
w.Header().Set("X-Response-Time", "100ms")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("success"))
|
||||
})
|
||||
|
||||
// Create temporary files for logging
|
||||
logFile, err := os.CreateTemp("", "test-access-log-*.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(logFile.Name())
|
||||
logFile.Close()
|
||||
|
||||
errorLogFile, err := os.CreateTemp("", "test-error-log-*.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(errorLogFile.Name())
|
||||
errorLogFile.Close()
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: add-correlation-id
|
||||
do: set resp_header X-Correlation-Id random_uuid
|
||||
- name: validate-auth
|
||||
on: path /protected
|
||||
do: require_basic_auth "Protected Area"
|
||||
- name: log-all-requests
|
||||
do: |
|
||||
log info %q "{{ .Request.Method }} {{ .Request.URL }} -> {{ .Response.StatusCode }}"
|
||||
- name: log-errors
|
||||
on: status 4xx
|
||||
do: |
|
||||
log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"
|
||||
`, logFile.Name(), errorLogFile.Name()), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test successful request
|
||||
req1 := httptest.NewRequest("GET", "/public", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "success", w1.Body.String())
|
||||
assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id"))
|
||||
assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time"))
|
||||
|
||||
// Test unauthorized protected request
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 401, w2.Code)
|
||||
assert.Equal(t, w2.Body.String(), "Unauthorized\n")
|
||||
|
||||
// Test authorized protected request
|
||||
req3 := httptest.NewRequest("GET", "/protected", nil)
|
||||
req3.SetBasicAuth("user", "pass")
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
// This should fail because our simple upstream expects X-Auth: valid header
|
||||
// but the basic auth requirement should add the appropriate header
|
||||
assert.Equal(t, 401, w3.Code)
|
||||
|
||||
// Check log files
|
||||
logContent, err := os.ReadFile(logFile.Name())
|
||||
require.NoError(t, err)
|
||||
lines := strings.Split(strings.TrimSpace(string(logContent)), "\n")
|
||||
require.Len(t, lines, 3, "all requests should be logged")
|
||||
assert.Equal(t, "GET /public -> 200", lines[0])
|
||||
assert.Equal(t, "GET /protected -> 401", lines[1])
|
||||
assert.Equal(t, "GET /protected -> 401", lines[2])
|
||||
|
||||
errorLogContent, err := os.ReadFile(errorLogFile.Name())
|
||||
require.NoError(t, err)
|
||||
// Should have at least one 401 error logged
|
||||
lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n")
|
||||
require.Len(t, lines, 2, "all errors should be logged")
|
||||
assert.Equal(t, "ERROR: GET /protected 401", lines[0])
|
||||
assert.Equal(t, "ERROR: GET /protected 401", lines[1])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_DefaultRule(t *testing.T) {
|
||||
upstream := mockUpstream(200, "upstream response")
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: default
|
||||
do: set resp_header X-Default-Applied true
|
||||
- name: special-rule
|
||||
on: path /special
|
||||
do: set resp_header X-Special-Handled true
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test default rule
|
||||
req1 := httptest.NewRequest("GET", "/regular", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "true", w1.Header().Get("X-Default-Applied"))
|
||||
assert.Empty(t, w1.Header().Get("X-Special-Handled"))
|
||||
|
||||
// Test special rule + default rule
|
||||
req2 := httptest.NewRequest("GET", "/special", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Default-Applied"))
|
||||
assert.Equal(t, "true", w2.Header().Get("X-Special-Handled"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_HeaderManipulation(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Echo back a header
|
||||
headerValue := r.Header.Get("X-Test-Header")
|
||||
w.Header().Set("X-Echoed-Header", headerValue)
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("header echoed"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: remove-sensitive-header
|
||||
do: remove resp_header X-Secret
|
||||
- name: add-custom-header
|
||||
do: add resp_header X-Custom-Header custom-value
|
||||
- name: modify-existing-header
|
||||
on: header X-Test-Header
|
||||
do: set header X-Test-Header modified-value
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.Header.Set("X-Secret", "secret-value")
|
||||
req.Header.Set("X-Test-Header", "original-value")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header"))
|
||||
assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header"))
|
||||
// Ensure the secret header was removed and not passed to upstream
|
||||
// (we can't directly test this, but the upstream shouldn't see it)
|
||||
}
|
||||
|
||||
func TestHTTPFlow_QueryParameterHandling(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("query: " + query.Get("param")))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: add-query-param
|
||||
on: query param
|
||||
do: set query param added-value
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/path?param=original", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
// The set command should have modified the query parameter
|
||||
assert.Equal(t, "query: added-value", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ServeCommand(t *testing.T) {
|
||||
// Create a temporary directory with test files
|
||||
tempDir, err := os.MkdirTemp("", "test-serve-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create test files directly in the temp directory
|
||||
testFile := filepath.Join(tempDir, "index.html")
|
||||
err = os.WriteFile(testFile, []byte("<h1>Test Page</h1>"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: serve-static
|
||||
on: path glob(/files/*)
|
||||
do: serve %s
|
||||
`, tempDir), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
|
||||
|
||||
// Test serving a file - serve command serves files relative to the root directory
|
||||
// The path /files/index.html gets mapped to tempDir + "/files/index.html"
|
||||
// We need to create the file at the expected path
|
||||
filesDir := filepath.Join(tempDir, "files")
|
||||
err = os.Mkdir(filesDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
filesIndexFile := filepath.Join(filesDir, "index.html")
|
||||
err = os.WriteFile(filesIndexFile, []byte("<h1>Test Page</h1>"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
req1 := httptest.NewRequest("GET", "/files/index.html", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
// The serve command should work, but might redirect
|
||||
// Let's just verify it doesn't call the upstream
|
||||
assert.NotEqual(t, "should not be called", w1.Body.String())
|
||||
|
||||
// Test file not found
|
||||
req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 404, w2.Code)
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ProxyCommand(t *testing.T) {
|
||||
// Create a mock upstream server
|
||||
upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream-Header", "upstream-value")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("upstream response"))
|
||||
}))
|
||||
defer upstreamServer.Close()
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(fmt.Sprintf(`
|
||||
- name: proxy-to-upstream
|
||||
on: path glob(/api/*)
|
||||
do: proxy %s
|
||||
`, upstreamServer.URL), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(mockUpstream(200, "should not be called"))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// The proxy command should forward the request to the upstream server
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "upstream response", w.Body.String())
|
||||
assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_NotifyCommand(t *testing.T) {
|
||||
// TODO:
|
||||
}
|
||||
|
||||
func TestHTTPFlow_FormConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("form processed"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: process-form
|
||||
on: form username
|
||||
do: set resp_header X-Username "{{ index .Request.Form.username 0 }}"
|
||||
- name: process-postform
|
||||
on: postform email
|
||||
do: set resp_header X-Email "{{ index .Request.PostForm.email 0 }}"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test form condition
|
||||
formData := url.Values{"username": {"john_doe"}}
|
||||
req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode()))
|
||||
req1.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "john_doe", w1.Header().Get("X-Username"))
|
||||
|
||||
// Test postform condition
|
||||
postFormData := url.Values{"email": {"john@example.com"}}
|
||||
req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode()))
|
||||
req2.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, "john@example.com", w2.Header().Get("X-Email"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RemoteConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("remote processed"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: allow-localhost
|
||||
on: remote 127.0.0.1
|
||||
do: set resp_header X-Access "local"
|
||||
- name: block-private
|
||||
on: remote 192.168.0.0/16
|
||||
do: error 403 "Private network blocked"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test localhost condition
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1.RemoteAddr = "127.0.0.1:12345"
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "local", w1.Header().Get("X-Access"))
|
||||
|
||||
// Test private network block
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2.RemoteAddr = "192.168.1.100:12345"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 403, w2.Code)
|
||||
assert.Equal(t, "Private network blocked\n", w2.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_BasicAuthConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("auth processed"))
|
||||
})
|
||||
|
||||
// Generate bcrypt hashes for passwords
|
||||
adminHash, err := bcrypt.GenerateFromPassword([]byte("adminpass"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
guestHash, err := bcrypt.GenerateFromPassword([]byte("guestpass"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
var rules Rules
|
||||
err = parseRules(fmt.Sprintf(`
|
||||
- name: check-auth
|
||||
on: basic_auth admin %s
|
||||
do: set resp_header X-Auth-Status "admin"
|
||||
- name: check-other-user
|
||||
on: basic_auth guest %s
|
||||
do: set resp_header X-Auth-Status "guest"
|
||||
`, string(adminHash), string(guestHash)), &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test admin user
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1.SetBasicAuth("admin", "adminpass")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status"))
|
||||
|
||||
// Test guest user
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2.SetBasicAuth("guest", "guestpass")
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_RouteConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("route processed"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: backend-route
|
||||
on: route backend
|
||||
do: set resp_header X-Route "backend"
|
||||
- name: frontend-route
|
||||
on: route frontend
|
||||
do: set resp_header X-Route "frontend"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test API route
|
||||
req1 := httptest.NewRequest("GET", "/", nil)
|
||||
req1 = routes.WithRouteContext(req1, mockRoute("backend"))
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "backend", w1.Header().Get("X-Route"))
|
||||
|
||||
// Test admin route
|
||||
req2 := httptest.NewRequest("GET", "/", nil)
|
||||
req2 = routes.WithRouteContext(req2, mockRoute("frontend"))
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, "frontend", w2.Header().Get("X-Route"))
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseStatusConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(405)
|
||||
w.Write([]byte("method not allowed"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: method-not-allowed
|
||||
on: status 405
|
||||
do: |
|
||||
error 405 'error'
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Response-Header", "response header")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("processed"))
|
||||
})
|
||||
|
||||
t.Run("any_value", func(t *testing.T) {
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: resp_header X-Response-Header
|
||||
do: |
|
||||
error 405 "error"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
})
|
||||
t.Run("with_value", func(t *testing.T) {
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: resp_header X-Response-Header "response header"
|
||||
do: |
|
||||
error 405 "error"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 405, w.Code)
|
||||
assert.Equal(t, "error\n", w.Body.String())
|
||||
})
|
||||
|
||||
t.Run("with_value_not_matched", func(t *testing.T) {
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- on: resp_header X-Response-Header "not-matched"
|
||||
do: |
|
||||
error 405 "error"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "processed", w.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("complex processed"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: admin-api
|
||||
on: |
|
||||
path glob(/api/admin/*)
|
||||
header Authorization
|
||||
method POST
|
||||
do: |
|
||||
set resp_header X-Access-Level "admin"
|
||||
set resp_header X-API-Version "v1"
|
||||
- name: user-api
|
||||
on: |
|
||||
path glob(/api/users/*) & method GET
|
||||
do: |
|
||||
set resp_header X-Access-Level "user"
|
||||
set resp_header X-API-Version "v1"
|
||||
- name: public-api
|
||||
on: |
|
||||
path glob(/api/public/*) & method GET
|
||||
do: |
|
||||
set resp_header X-Access-Level "public"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
// Test admin API (should match first rule)
|
||||
req1 := httptest.NewRequest("POST", "/api/admin/users", nil)
|
||||
req1.Header.Set("Authorization", "Bearer token")
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
assert.Equal(t, 200, w1.Code)
|
||||
assert.Equal(t, "admin", w1.Header().Get("X-Access-Level"))
|
||||
assert.Equal(t, "v1", w1.Header()["X-API-Version"][0])
|
||||
|
||||
// Test user API (should match second rule)
|
||||
req2 := httptest.NewRequest("GET", "/api/users/profile", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
assert.Equal(t, 200, w2.Code)
|
||||
assert.Equal(t, "user", w2.Header().Get("X-Access-Level"))
|
||||
assert.Equal(t, "v1", w2.Header()["X-API-Version"][0])
|
||||
|
||||
// Test public API (should match third rule)
|
||||
req3 := httptest.NewRequest("GET", "/api/public/info", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w3, req3)
|
||||
|
||||
assert.Equal(t, 200, w3.Code)
|
||||
assert.Equal(t, "public", w3.Header().Get("X-Access-Level"))
|
||||
assert.Empty(t, w3.Header()["X-API-Version"])
|
||||
}
|
||||
|
||||
func TestHTTPFlow_ResponseModifier(t *testing.T) {
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("original response"))
|
||||
})
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: modify-response
|
||||
do: |
|
||||
set resp_header X-Modified "true"
|
||||
set resp_body "Modified: {{ .Request.Method }} {{ .Request.URL.Path }}"
|
||||
`, &rules)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := rules.BuildHandler(upstream)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "true", w.Header().Get("X-Modified"))
|
||||
assert.Equal(t, "Modified: GET /test\n", w.Body.String())
|
||||
}
|
||||
36
internal/route/rules/io.go
Normal file
36
internal/route/rules/io.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/yusing/godoxy/internal/logging/accesslog"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type noopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (n noopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
stdout io.WriteCloser = noopWriteCloser{os.Stdout}
|
||||
stderr io.WriteCloser = noopWriteCloser{os.Stderr}
|
||||
)
|
||||
|
||||
func openFile(path string) (io.WriteCloser, gperr.Error) {
|
||||
switch path {
|
||||
case "/dev/stdout":
|
||||
return stdout, nil
|
||||
case "/dev/stderr":
|
||||
return stderr, nil
|
||||
}
|
||||
f, err := accesslog.NewFileIO(path)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
120
internal/route/rules/matcher.go
Normal file
120
internal/route/rules/matcher.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type (
|
||||
Matcher func(string) bool
|
||||
MatcherType string
|
||||
)
|
||||
|
||||
const (
|
||||
MatcherTypeString MatcherType = "string"
|
||||
MatcherTypeGlob MatcherType = "glob"
|
||||
MatcherTypeRegex MatcherType = "regex"
|
||||
)
|
||||
|
||||
func unquoteExpr(s string) (string, gperr.Error) {
|
||||
if s == "" {
|
||||
return "", nil
|
||||
}
|
||||
switch s[0] {
|
||||
case '"', '\'', '`':
|
||||
if s[0] != s[len(s)-1] {
|
||||
return "", ErrUnterminatedQuotes
|
||||
}
|
||||
return s[1 : len(s)-1], nil
|
||||
default:
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
|
||||
idx := strings.IndexByte(s, '(')
|
||||
if idx == -1 {
|
||||
return MatcherTypeString, s, nil
|
||||
}
|
||||
idxEnd := strings.LastIndexByte(s, ')')
|
||||
if idxEnd == -1 {
|
||||
return "", "", ErrUnterminatedBrackets
|
||||
}
|
||||
|
||||
expr, err = unquoteExpr(s[idx+1 : idxEnd])
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
matcherType = MatcherType(strings.ToLower(s[:idx]))
|
||||
|
||||
switch matcherType {
|
||||
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
|
||||
return
|
||||
default:
|
||||
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
|
||||
}
|
||||
}
|
||||
|
||||
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
negate := false
|
||||
if strings.HasPrefix(expr, "!") {
|
||||
negate = true
|
||||
expr = expr[1:]
|
||||
}
|
||||
|
||||
t, expr, err := ExtractExpr(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case MatcherTypeString:
|
||||
return StringMatcher(expr, negate)
|
||||
case MatcherTypeGlob:
|
||||
return GlobMatcher(expr, negate)
|
||||
case MatcherTypeRegex:
|
||||
return RegexMatcher(expr, negate)
|
||||
}
|
||||
// won't reach here
|
||||
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
||||
}
|
||||
|
||||
func StringMatcher(s string, negate bool) (Matcher, gperr.Error) {
|
||||
if negate {
|
||||
return func(s2 string) bool {
|
||||
return s != s2
|
||||
}, nil
|
||||
}
|
||||
return func(s2 string) bool {
|
||||
return s == s2
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GlobMatcher(expr string, negate bool) (Matcher, gperr.Error) {
|
||||
g, err := glob.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if negate {
|
||||
return func(s string) bool {
|
||||
return !g.Match(s)
|
||||
}, nil
|
||||
}
|
||||
return g.Match, nil
|
||||
}
|
||||
|
||||
func RegexMatcher(expr string, negate bool) (Matcher, gperr.Error) {
|
||||
re, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if negate {
|
||||
return func(s string) bool {
|
||||
return !re.MatchString(s)
|
||||
}, nil
|
||||
}
|
||||
return re.MatchString, nil
|
||||
}
|
||||
35
internal/route/rules/matcher_bench_test.go
Normal file
35
internal/route/rules/matcher_bench_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package rules
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkMatcher(b *testing.B) {
|
||||
b.Run("StringMatcher", func(b *testing.B) {
|
||||
matcher, err := StringMatcher("foo", false)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for b.Loop() {
|
||||
matcher("foo")
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GlobMatcher", func(b *testing.B) {
|
||||
matcher, err := GlobMatcher("foo*bar?baz*[abc]*.txt", false)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for b.Loop() {
|
||||
matcher("foooooobarzbazcb.txt")
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("RegexMatcher", func(b *testing.B) {
|
||||
matcher, err := RegexMatcher(`^(foo\d+|bar(_baz)?)[a-z]{3,}\.txt$`, false)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for b.Loop() {
|
||||
matcher("foo123abcd.txt")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -49,6 +49,18 @@ func TestExtractExpr(t *testing.T) {
|
||||
wantT: MatcherTypeRegex,
|
||||
wantExpr: "^[A-Z]+$",
|
||||
},
|
||||
{
|
||||
name: "regex with parentheses",
|
||||
in: "regex(test(group))",
|
||||
wantT: MatcherTypeRegex,
|
||||
wantExpr: "test(group)",
|
||||
},
|
||||
{
|
||||
name: "regex complex",
|
||||
in: `regex("^(_next/static|_next/image|favicon.ico).*$")`,
|
||||
wantT: MatcherTypeRegex,
|
||||
wantExpr: "^(_next/static|_next/image|favicon.ico).*$",
|
||||
},
|
||||
{
|
||||
name: "quoted expr",
|
||||
in: "glob(`'foo'`)",
|
||||
@@ -96,3 +108,62 @@ func TestExtractExprInvalid(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegated(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expr string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "negated_string_match",
|
||||
expr: "!string(`foo`)",
|
||||
in: "foo",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negated_string_no_match",
|
||||
expr: "!string(`foo`)",
|
||||
in: "bar",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "negated_glob_match",
|
||||
expr: "!glob(`foo`)",
|
||||
in: "foo",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negated_glob_no_match",
|
||||
expr: "!glob(`foo`)",
|
||||
in: "bar",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "negated_regex_match",
|
||||
expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)",
|
||||
in: "favicon.ico",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "negated_regex_no_match",
|
||||
expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)",
|
||||
in: "bar",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "negated_regex_no_match2",
|
||||
expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)",
|
||||
in: "/",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matcher, err := ParseMatcher(tt.expr)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, tt.want, matcher(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,16 +8,20 @@ import (
|
||||
|
||||
"github.com/yusing/godoxy/internal/route/routes"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
)
|
||||
|
||||
type RuleOn struct {
|
||||
raw string
|
||||
checker Checker
|
||||
raw string
|
||||
checker Checker
|
||||
isResponseChecker bool
|
||||
}
|
||||
|
||||
func (on *RuleOn) Check(cached Cache, r *http.Request) bool {
|
||||
return on.checker.Check(cached, r)
|
||||
func (on *RuleOn) IsResponseChecker() bool {
|
||||
return on.isResponseChecker
|
||||
}
|
||||
|
||||
func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
return on.checker.Check(w, r)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -32,20 +36,27 @@ const (
|
||||
OnRemote = "remote"
|
||||
OnBasicAuth = "basic_auth"
|
||||
OnRoute = "route"
|
||||
|
||||
// on response
|
||||
OnResponseHeader = "resp_header"
|
||||
OnStatus = "status"
|
||||
)
|
||||
|
||||
var checkers = map[string]struct {
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
help Help
|
||||
validate ValidateFunc
|
||||
builder func(args any) CheckFunc
|
||||
isResponseChecker bool
|
||||
}{
|
||||
OnHeader: {
|
||||
help: Help{
|
||||
command: OnHeader,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
header username "user"
|
||||
header username glob("user*")
|
||||
header username regex("user.*")`,
|
||||
description: makeLines(
|
||||
"Value supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnHeader, "username", "user"),
|
||||
helpExample(OnHeader, "username", helpFuncCall("glob", "user*")),
|
||||
helpExample(OnHeader, "username", helpFuncCall("regex", "user.*")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"key": "the header key",
|
||||
"[value]": "the header value",
|
||||
@@ -55,22 +66,52 @@ var checkers = map[string]struct {
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(r.Header[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(r.Header[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnResponseHeader: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnResponseHeader,
|
||||
description: makeLines(
|
||||
"Value supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnResponseHeader, "username", "user"),
|
||||
helpExample(OnResponseHeader, "username", helpFuncCall("glob", "user*")),
|
||||
helpExample(OnResponseHeader, "username", helpFuncCall("regex", "user.*")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"key": "the response header key",
|
||||
"[value]": "the response header value",
|
||||
},
|
||||
},
|
||||
validate: toKVOptionalVMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(GetInitResponseModifier(w).Header()[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(GetInitResponseModifier(w).Header()[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnQuery: {
|
||||
help: Help{
|
||||
command: OnQuery,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
query username "user"
|
||||
query username glob("user*")
|
||||
query username regex("user.*")`,
|
||||
description: makeLines(
|
||||
"Value supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnQuery, "username", "user"),
|
||||
helpExample(OnQuery, "username", helpFuncCall("glob", "user*")),
|
||||
helpExample(OnQuery, "username", helpFuncCall("regex", "user.*")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"key": "the query key",
|
||||
"[value]": "the query value",
|
||||
@@ -80,22 +121,24 @@ var checkers = map[string]struct {
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return len(cached.GetQueries(r)[k]) > 0
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return len(GetSharedData(w).GetQueries(r)[k]) > 0
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return slices.ContainsFunc(cached.GetQueries(r)[k], matcher)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return slices.ContainsFunc(GetSharedData(w).GetQueries(r)[k], matcher)
|
||||
}
|
||||
},
|
||||
},
|
||||
OnCookie: {
|
||||
help: Help{
|
||||
command: OnCookie,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
cookie username "user"
|
||||
cookie username glob("user*")
|
||||
cookie username regex("user.*")`,
|
||||
description: makeLines(
|
||||
"Value supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnCookie, "username", "user"),
|
||||
helpExample(OnCookie, "username", helpFuncCall("glob", "user*")),
|
||||
helpExample(OnCookie, "username", helpFuncCall("regex", "user.*")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"key": "the cookie key",
|
||||
"[value]": "the cookie value",
|
||||
@@ -105,8 +148,8 @@ var checkers = map[string]struct {
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
cookies := cached.GetCookies(r)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := GetSharedData(w).GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
return true
|
||||
@@ -115,8 +158,8 @@ var checkers = map[string]struct {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
cookies := cached.GetCookies(r)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
cookies := GetSharedData(w).GetCookies(r)
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == k {
|
||||
if matcher(cookie.Value) {
|
||||
@@ -131,10 +174,12 @@ var checkers = map[string]struct {
|
||||
OnForm: {
|
||||
help: Help{
|
||||
command: OnForm,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
form username "user"
|
||||
form username glob("user*")
|
||||
form username regex("user.*")`,
|
||||
description: makeLines(
|
||||
"Value supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnForm, "username", "user"),
|
||||
helpExample(OnForm, "username", helpFuncCall("glob", "user*")),
|
||||
helpExample(OnForm, "username", helpFuncCall("regex", "user.*")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"key": "the form key",
|
||||
"[value]": "the form value",
|
||||
@@ -144,11 +189,11 @@ var checkers = map[string]struct {
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return r.FormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return matcher(r.FormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -156,10 +201,12 @@ var checkers = map[string]struct {
|
||||
OnPostForm: {
|
||||
help: Help{
|
||||
command: OnPostForm,
|
||||
description: `Value supports string, glob pattern, or regex pattern, e.g.:
|
||||
postform username "user"
|
||||
postform username glob("user*")
|
||||
postform username regex("user.*")`,
|
||||
description: makeLines(
|
||||
"Value supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnPostForm, "username", "user"),
|
||||
helpExample(OnPostForm, "username", helpFuncCall("glob", "user*")),
|
||||
helpExample(OnPostForm, "username", helpFuncCall("regex", "user.*")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"key": "the form key",
|
||||
"[value]": "the form value",
|
||||
@@ -169,11 +216,11 @@ var checkers = map[string]struct {
|
||||
builder: func(args any) CheckFunc {
|
||||
k, matcher := args.(*MapValueMatcher).Unpack()
|
||||
if matcher == nil {
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return r.PostFormValue(k) != ""
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return matcher(r.PostFormValue(k))
|
||||
}
|
||||
},
|
||||
@@ -188,7 +235,7 @@ var checkers = map[string]struct {
|
||||
validate: validateMethod,
|
||||
builder: func(args any) CheckFunc {
|
||||
method := args.(string)
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return r.Method == method
|
||||
}
|
||||
},
|
||||
@@ -196,11 +243,13 @@ var checkers = map[string]struct {
|
||||
OnHost: {
|
||||
help: Help{
|
||||
command: OnHost,
|
||||
description: `Supports string, glob pattern, or regex pattern, e.g.:
|
||||
host example.com
|
||||
host glob(example*.com)
|
||||
host regex(example\w+\.com)
|
||||
host regex(example\.com$)`,
|
||||
description: makeLines(
|
||||
"Supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnHost, "example.com"),
|
||||
helpExample(OnHost, helpFuncCall("glob", "example*.com")),
|
||||
helpExample(OnHost, helpFuncCall("regex", `(example\w+\.com)`)),
|
||||
helpExample(OnHost, helpFuncCall("regex", `example\.com$`)),
|
||||
),
|
||||
args: map[string]string{
|
||||
"host": "the host name",
|
||||
},
|
||||
@@ -208,7 +257,7 @@ var checkers = map[string]struct {
|
||||
validate: validateSingleMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return matcher(r.Host)
|
||||
}
|
||||
},
|
||||
@@ -216,11 +265,13 @@ var checkers = map[string]struct {
|
||||
OnPath: {
|
||||
help: Help{
|
||||
command: OnPath,
|
||||
description: `Supports string, glob pattern, or regex pattern, e.g.:
|
||||
path /path/to
|
||||
path glob(/path/to/*)
|
||||
path regex(^/path/to/.*$)
|
||||
path regex(/path/[A-Z]+/)`,
|
||||
description: makeLines(
|
||||
"Supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnPath, "/path/to"),
|
||||
helpExample(OnPath, helpFuncCall("glob", "/path/to/*")),
|
||||
helpExample(OnPath, helpFuncCall("regex", `^/path/to/.*$`)),
|
||||
helpExample(OnPath, helpFuncCall("regex", `/path/[A-Z]+/`)),
|
||||
),
|
||||
args: map[string]string{
|
||||
"path": "the request path",
|
||||
},
|
||||
@@ -228,7 +279,7 @@ var checkers = map[string]struct {
|
||||
validate: validateURLPathMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
reqPath := r.URL.Path
|
||||
if len(reqPath) > 0 && reqPath[0] != '/' {
|
||||
reqPath = "/" + reqPath
|
||||
@@ -250,16 +301,16 @@ var checkers = map[string]struct {
|
||||
// for /32 (IPv4) or /128 (IPv6), just compare the IP
|
||||
if ones, bits := ipnet.Mask.Size(); ones == bits {
|
||||
wantIP := ipnet.IP
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
ip := cached.GetRemoteIP(r)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := GetSharedData(w).GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.Equal(wantIP)
|
||||
}
|
||||
}
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
ip := cached.GetRemoteIP(r)
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
ip := GetSharedData(w).GetRemoteIP(r)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
@@ -278,18 +329,20 @@ var checkers = map[string]struct {
|
||||
validate: validateUserBCryptPassword,
|
||||
builder: func(args any) CheckFunc {
|
||||
cred := args.(*HashedCrendentials)
|
||||
return func(cached Cache, r *http.Request) bool {
|
||||
return cred.Match(cached.GetBasicAuth(r))
|
||||
return func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return cred.Match(GetSharedData(w).GetBasicAuth(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
OnRoute: {
|
||||
help: Help{
|
||||
command: OnRoute,
|
||||
description: `Supports string, glob pattern, or regex pattern, e.g.:
|
||||
route example
|
||||
route glob(example*)
|
||||
route regex(example\w+)`,
|
||||
description: makeLines(
|
||||
"Supports string, glob pattern, or regex pattern, e.g.:",
|
||||
helpExample(OnRoute, "example"),
|
||||
helpExample(OnRoute, helpFuncCall("glob", "example*")),
|
||||
helpExample(OnRoute, helpFuncCall("regex", "example\\w+")),
|
||||
),
|
||||
args: map[string]string{
|
||||
"route": "the route name",
|
||||
},
|
||||
@@ -297,11 +350,43 @@ var checkers = map[string]struct {
|
||||
validate: validateSingleMatcher,
|
||||
builder: func(args any) CheckFunc {
|
||||
matcher := args.(Matcher)
|
||||
return func(_ Cache, r *http.Request) bool {
|
||||
return func(_ http.ResponseWriter, r *http.Request) bool {
|
||||
return matcher(routes.TryGetUpstreamName(r))
|
||||
}
|
||||
},
|
||||
},
|
||||
OnStatus: {
|
||||
isResponseChecker: true,
|
||||
help: Help{
|
||||
command: OnStatus,
|
||||
description: makeLines(
|
||||
"Supported formats are:",
|
||||
helpExample(OnStatus, "<status>"),
|
||||
helpExample(OnStatus, "<status>-<status>"),
|
||||
helpExample(OnStatus, "1xx"),
|
||||
helpExample(OnStatus, "2xx"),
|
||||
helpExample(OnStatus, "3xx"),
|
||||
helpExample(OnStatus, "4xx"),
|
||||
helpExample(OnStatus, "5xx"),
|
||||
),
|
||||
args: map[string]string{
|
||||
"status": "the status code range",
|
||||
},
|
||||
},
|
||||
validate: validateStatusRange,
|
||||
builder: func(args any) CheckFunc {
|
||||
beg, end := args.(*IntTuple).Unpack()
|
||||
if beg == end {
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
return GetInitResponseModifier(w).StatusCode() == beg
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, _ *http.Request) bool {
|
||||
statusCode := GetInitResponseModifier(w).StatusCode()
|
||||
return statusCode >= beg && statusCode <= end
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -367,6 +452,66 @@ func splitAnd(s string) []string {
|
||||
return a[:i]
|
||||
}
|
||||
|
||||
// splitPipe splits a string by "|" but respects quotes, brackets, and escaped characters.
|
||||
// It's similar to the parser.go logic but specifically for pipe splitting.
|
||||
func splitPipe(s string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var result []string
|
||||
var current strings.Builder
|
||||
escaped := false
|
||||
quote := rune(0)
|
||||
brackets := 0
|
||||
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
escaped = true
|
||||
current.WriteRune(r)
|
||||
case '"', '\'', '`':
|
||||
if quote == 0 && brackets == 0 {
|
||||
quote = r
|
||||
} else if r == quote {
|
||||
quote = 0
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '(':
|
||||
brackets++
|
||||
current.WriteRune(r)
|
||||
case ')':
|
||||
if brackets > 0 {
|
||||
brackets--
|
||||
}
|
||||
current.WriteRune(r)
|
||||
case '|':
|
||||
if quote == 0 && brackets == 0 {
|
||||
// Found a pipe outside quotes/brackets, split here
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
current.Reset()
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last part
|
||||
if current.Len() > 0 {
|
||||
result = append(result, strings.TrimSpace(current.String()))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Parse implements strutils.Parser.
|
||||
func (on *RuleOn) Parse(v string) error {
|
||||
on.raw = v
|
||||
@@ -375,19 +520,24 @@ func (on *RuleOn) Parse(v string) error {
|
||||
checkAnd := make(CheckMatchAll, 0, len(rules))
|
||||
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
isResponseChecker := false
|
||||
for i, rule := range rules {
|
||||
if rule == "" {
|
||||
continue
|
||||
}
|
||||
parsed, err := parseOn(rule)
|
||||
parsed, isResp, err := parseOn(rule)
|
||||
if err != nil {
|
||||
errs.Add(err.Subjectf("line %d", i+1))
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
}
|
||||
checkAnd = append(checkAnd, parsed)
|
||||
}
|
||||
|
||||
on.checker = checkAnd
|
||||
on.isResponseChecker = isResponseChecker
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
@@ -399,40 +549,57 @@ func (on *RuleOn) MarshalText() ([]byte, error) {
|
||||
return []byte(on.String()), nil
|
||||
}
|
||||
|
||||
func parseOn(line string) (Checker, gperr.Error) {
|
||||
ors := strutils.SplitRune(line, '|')
|
||||
func parseOn(line string) (Checker, bool, gperr.Error) {
|
||||
ors := splitPipe(line)
|
||||
|
||||
if len(ors) > 1 {
|
||||
errs := gperr.NewBuilder("rule.on syntax errors")
|
||||
checkOr := make(CheckMatchSingle, len(ors))
|
||||
isResponseChecker := false
|
||||
for i, or := range ors {
|
||||
curCheckers, err := parseOn(or)
|
||||
curCheckers, isResp, err := parseOn(or)
|
||||
if err != nil {
|
||||
errs.Add(err)
|
||||
continue
|
||||
}
|
||||
if isResp {
|
||||
isResponseChecker = true
|
||||
}
|
||||
checkOr[i] = curCheckers.(CheckFunc)
|
||||
}
|
||||
if err := errs.Error(); err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
return checkOr, nil
|
||||
return checkOr, isResponseChecker, nil
|
||||
}
|
||||
|
||||
subject, args, err := parse(line)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
negate := false
|
||||
if strings.HasPrefix(subject, "!") {
|
||||
negate = true
|
||||
subject = subject[1:]
|
||||
}
|
||||
|
||||
checker, ok := checkers[subject]
|
||||
if !ok {
|
||||
return nil, ErrInvalidOnTarget.Subject(subject)
|
||||
return nil, false, ErrInvalidOnTarget.Subject(subject)
|
||||
}
|
||||
|
||||
validArgs, err := checker.validate(args)
|
||||
if err != nil {
|
||||
return nil, err.Subject(subject).Withf("%s", checker.help.String())
|
||||
return nil, false, err.Subject(subject).With(checker.help.Error())
|
||||
}
|
||||
|
||||
return checker.builder(validArgs), nil
|
||||
checkFunc := checker.builder(validArgs)
|
||||
if negate {
|
||||
origCheckFunc := checkFunc
|
||||
checkFunc = func(w http.ResponseWriter, r *http.Request) bool {
|
||||
return !origCheckFunc(w, r)
|
||||
}
|
||||
}
|
||||
return checkFunc, checker.isResponseChecker, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,86 @@ import (
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
func TestSplitPipe(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
input: "",
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
input: "rule",
|
||||
want: []string{"rule"},
|
||||
},
|
||||
{
|
||||
name: "simple_pipe",
|
||||
input: "rule1 | rule2",
|
||||
want: []string{"rule1", "rule2"},
|
||||
},
|
||||
{
|
||||
name: "multiple_pipes",
|
||||
input: "rule1 | rule2 | rule3",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "pipe_in_quotes",
|
||||
input: `path regex("^(_next/static|_next/image|favicon.ico).*$")`,
|
||||
want: []string{`path regex("^(_next/static|_next/image|favicon.ico).*$")`},
|
||||
},
|
||||
{
|
||||
name: "pipe_in_single_quotes",
|
||||
input: `path regex('^(_next/static|_next/image|favicon.ico).*$')`,
|
||||
want: []string{`path regex('^(_next/static|_next/image|favicon.ico).*$')`},
|
||||
},
|
||||
{
|
||||
name: "pipe_in_backticks",
|
||||
input: "path regex(`^(_next/static|_next/image|favicon.ico).*$`)",
|
||||
want: []string{"path regex(`^(_next/static|_next/image|favicon.ico).*$`)"},
|
||||
},
|
||||
{
|
||||
name: "pipe_in_brackets",
|
||||
input: "path regex(^(_next/static|_next/image|favicon.ico).*$)",
|
||||
want: []string{"path regex(^(_next/static|_next/image|favicon.ico).*$)"},
|
||||
},
|
||||
{
|
||||
name: "escaped_pipe",
|
||||
input: `path regex("^(_next/static\|_next/image\|favicon.ico).*$")`,
|
||||
want: []string{`path regex("^(_next/static\|_next/image\|favicon.ico).*$")`},
|
||||
},
|
||||
{
|
||||
name: "mixed_quotes_and_pipes",
|
||||
input: `rule1 | path regex("^(_next/static|_next/image|favicon.ico).*$") | rule3`,
|
||||
want: []string{"rule1", `path regex("^(_next/static|_next/image|favicon.ico).*$")`, "rule3"},
|
||||
},
|
||||
{
|
||||
name: "nested_brackets",
|
||||
input: "path regex(^(foo|bar(baz|qux)).*$)",
|
||||
want: []string{"path regex(^(foo|bar(baz|qux)).*$)"},
|
||||
},
|
||||
{
|
||||
name: "spaces_around",
|
||||
input: " rule1 | rule2 | rule3 ",
|
||||
want: []string{"rule1", "rule2", "rule3"},
|
||||
},
|
||||
{
|
||||
name: "empty_segments",
|
||||
input: "rule1 || rule2 | | rule3",
|
||||
want: []string{"rule1", "", "rule2", "", "rule3"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitPipe(tt.input)
|
||||
expect.Equal(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitAnd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -179,6 +259,27 @@ func TestParseOn(t *testing.T) {
|
||||
input: "route example1 example2",
|
||||
wantErr: ErrExpectOneArg,
|
||||
},
|
||||
// pipe splitting tests
|
||||
{
|
||||
name: "pipe_simple",
|
||||
input: "method GET | method POST",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "pipe_in_quotes",
|
||||
input: `path regex("^(_next/static|_next/image|favicon.ico).*$")`,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "pipe_in_brackets",
|
||||
input: "path regex(^(_next/static|_next/image|favicon.ico).*$)",
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "pipe_mixed",
|
||||
input: `method GET | path regex("^(_next/static|_next/image|favicon.ico).*$") | header Authorization`,
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@@ -47,6 +48,18 @@ func genCorrectnessTestCases(field string, genRequest func(k, v string) *http.Re
|
||||
input: genRequest("bar", "abcd"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: field + "_negated_match",
|
||||
checker: "!" + field + " foo",
|
||||
input: genRequest("foo", "bar"),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: field + "_negated_no_match",
|
||||
checker: "!" + field + " foo",
|
||||
input: genRequest("bar", "foo"),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,6 +77,18 @@ func TestOnCorrectness(t *testing.T) {
|
||||
input: &http.Request{Method: http.MethodPost},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "method_negated_match",
|
||||
checker: "!method GET",
|
||||
input: &http.Request{Method: http.MethodGet},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "method_negated_no_match",
|
||||
checker: "!method GET",
|
||||
input: &http.Request{Method: http.MethodPost},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "host_match",
|
||||
checker: "host example.com",
|
||||
@@ -80,6 +105,22 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "host_negated_match",
|
||||
checker: "!host example.com",
|
||||
input: &http.Request{
|
||||
Host: "example.com",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "host_negated_no_match",
|
||||
checker: "!host example.com",
|
||||
input: &http.Request{
|
||||
Host: "example.org",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "path_exact_match",
|
||||
checker: "path /example",
|
||||
@@ -88,6 +129,22 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "path_negated_match",
|
||||
checker: "!path /example",
|
||||
input: &http.Request{
|
||||
URL: &url.URL{Path: "/example"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "path_negated_no_match",
|
||||
checker: "!path /example",
|
||||
input: &http.Request{
|
||||
URL: &url.URL{Path: "/example/foo"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "remote_match",
|
||||
checker: "remote 192.168.1.0/24",
|
||||
@@ -96,6 +153,22 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "remote_negated_match",
|
||||
checker: "!remote 192.168.1.0/24",
|
||||
input: &http.Request{
|
||||
RemoteAddr: "192.168.1.5",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "remote_negated_no_match",
|
||||
checker: "!remote 192.168.1.0/24",
|
||||
input: &http.Request{
|
||||
RemoteAddr: "192.168.2.5",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "remote_no_match",
|
||||
checker: "remote 192.168.1.0/24",
|
||||
@@ -124,6 +197,26 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "basic_auth_negated_match",
|
||||
checker: "!basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
|
||||
input: &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password"
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "basic_auth_negated_no_match",
|
||||
checker: "!basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))),
|
||||
input: &http.Request{
|
||||
Header: http.Header{
|
||||
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong"
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "route_match",
|
||||
checker: "route example",
|
||||
@@ -141,6 +234,23 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "route_negated_match",
|
||||
checker: "!route example",
|
||||
input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{
|
||||
Alias: "example",
|
||||
Root: "/",
|
||||
}))),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "route_negated_no_match",
|
||||
checker: "!route example",
|
||||
input: &http.Request{
|
||||
Header: http.Header{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex_match",
|
||||
checker: `host regex(example\w+\.com)`,
|
||||
@@ -157,6 +267,22 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "regex_negated_match",
|
||||
checker: `!host regex(example\w+\.com)`,
|
||||
input: &http.Request{
|
||||
Host: "example.org",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex_negated_no_match",
|
||||
checker: `!host regex(example\w+\.com)`,
|
||||
input: &http.Request{
|
||||
Host: "exampleabc.com",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "glob match",
|
||||
checker: `host glob(*.example.com)`,
|
||||
@@ -181,6 +307,22 @@ func TestOnCorrectness(t *testing.T) {
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "glob negated_match",
|
||||
checker: `!host glob(*.example.com)`,
|
||||
input: &http.Request{
|
||||
Host: "example.com",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "glob negated_no_match",
|
||||
checker: `!host glob(*.example.com)`,
|
||||
input: &http.Request{
|
||||
Host: "a.example.com",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request {
|
||||
@@ -219,10 +361,11 @@ func TestOnCorrectness(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
var on RuleOn
|
||||
err := on.Parse(tt.checker)
|
||||
expect.NoError(t, err)
|
||||
got := on.Check(Cache{}, tt.input)
|
||||
got := on.Check(w, tt.input)
|
||||
expect.Equal(t, tt.want, got, fmt.Sprintf("expect %s to %v", tt.checker, tt.want))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package rules
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"unicode"
|
||||
|
||||
"github.com/yusing/goutils/env"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
brackets := 0
|
||||
|
||||
var envVar bytes.Buffer
|
||||
var missingEnvVars bytes.Buffer
|
||||
var missingEnvVars []string
|
||||
inEnvVar := false
|
||||
expectingBrace := false
|
||||
|
||||
@@ -70,6 +70,10 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if expectingBrace && r != '{' && r != '$' { // not escaped and not env var
|
||||
buf.WriteRune('$')
|
||||
expectingBrace = false
|
||||
}
|
||||
switch r {
|
||||
case '\\':
|
||||
escaped = true
|
||||
@@ -90,9 +94,11 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
}
|
||||
case '}':
|
||||
if inEnvVar {
|
||||
envValue, ok := os.LookupEnv(envVar.String())
|
||||
// NOTE: use env.LookupEnv instead of os.LookupEnv to support environment variable prefixes
|
||||
// like ${API_ADDR} will lookup for GODOXY_API_ADDR, GOPROXY_API_ADDR and API_ADDR.
|
||||
envValue, ok := env.LookupEnv(envVar.String())
|
||||
if !ok {
|
||||
fmt.Fprintf(&missingEnvVars, "%q, ", envVar.String())
|
||||
missingEnvVars = append(missingEnvVars, envVar.String())
|
||||
} else {
|
||||
buf.WriteString(envValue)
|
||||
}
|
||||
@@ -140,15 +146,21 @@ func parse(v string) (subject string, args []string, err gperr.Error) {
|
||||
}
|
||||
}
|
||||
|
||||
if expectingBrace {
|
||||
buf.WriteRune('$')
|
||||
}
|
||||
|
||||
if quote != 0 {
|
||||
err = ErrUnterminatedQuotes
|
||||
} else if brackets != 0 {
|
||||
err = ErrUnterminatedBrackets
|
||||
} else if inEnvVar {
|
||||
err = ErrUnterminatedEnvVar
|
||||
} else {
|
||||
flush(false)
|
||||
}
|
||||
if missingEnvVars.Len() > 0 {
|
||||
err = gperr.Join(err, ErrEnvVarNotFound.Subject(missingEnvVars.String()))
|
||||
if len(missingEnvVars) > 0 {
|
||||
err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...)))
|
||||
}
|
||||
return subject, args, err
|
||||
}
|
||||
|
||||
@@ -49,9 +49,9 @@ func TestParser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "regex_escaped",
|
||||
input: `foo regex(\b\B\s\S\w\W\d\D\$\.)`,
|
||||
input: `foo regex(\b\B\s\S\w\W\d\D\$\.\(\)\{\}\|\?\"\')`,
|
||||
subject: "foo",
|
||||
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.)`},
|
||||
args: []string{`regex(\b\B\s\S\w\W\d\D\$\.\(\)\{\}\|\?"')`},
|
||||
},
|
||||
{
|
||||
name: "quote inside argument",
|
||||
@@ -71,6 +71,12 @@ func TestParser(t *testing.T) {
|
||||
subject: "foo",
|
||||
args: []string{"glob(\"`/**/to/path`\")"},
|
||||
},
|
||||
{
|
||||
name: "complex_regex",
|
||||
input: `path !regex("^(_next/static|_next/image|favicon.ico).*$")`,
|
||||
subject: "path",
|
||||
args: []string{`!regex("^(_next/static|_next/image|favicon.ico).*$")`},
|
||||
},
|
||||
{
|
||||
name: "chaos",
|
||||
input: `error 403 "Forbidden "foo" "bar""`,
|
||||
@@ -170,6 +176,53 @@ func TestParser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("negated", func(t *testing.T) {
|
||||
test := `!error 403 "Forbidden"`
|
||||
subject, args, err := parse(test)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, subject, "!error")
|
||||
expect.Equal(t, args, []string{"403", "Forbidden"})
|
||||
})
|
||||
}
|
||||
|
||||
func TestFullParse(t *testing.T) {
|
||||
input := `
|
||||
- name: login page
|
||||
on: path /login
|
||||
do: pass
|
||||
- name: require auth
|
||||
on: path !regex("^(_next/static|_next/image|favicon.ico).*$")
|
||||
do: require_auth
|
||||
- name: redirect to login
|
||||
on: status 401 | status 403
|
||||
do: proxy /login
|
||||
- name: proxy to backend
|
||||
on: path glob("/api/v1/*")
|
||||
do: proxy http://localhost:8999/
|
||||
- name: proxy to backend (old /auth)
|
||||
on: path glob("/auth/*")
|
||||
do: proxy http://localhost:8999/api/v1/`
|
||||
|
||||
var rules Rules
|
||||
err := parseRules(input, &rules)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, len(rules), 5)
|
||||
expect.Equal(t, rules[0].Name, "login page")
|
||||
expect.Equal(t, rules[0].On.String(), "path /login")
|
||||
expect.Equal(t, rules[0].Do.String(), "pass")
|
||||
expect.Equal(t, rules[1].Name, "require auth")
|
||||
expect.Equal(t, rules[1].On.String(), `path !regex("^(_next/static|_next/image|favicon.ico).*$")`)
|
||||
expect.Equal(t, rules[1].Do.String(), "require_auth")
|
||||
expect.Equal(t, rules[2].Name, "redirect to login")
|
||||
expect.Equal(t, rules[2].On.String(), "status 401 | status 403")
|
||||
expect.Equal(t, rules[2].Do.String(), "proxy /login")
|
||||
expect.Equal(t, rules[3].Name, "proxy to backend")
|
||||
expect.Equal(t, rules[3].On.String(), `path glob("/api/v1/*")`)
|
||||
expect.Equal(t, rules[3].Do.String(), "proxy http://localhost:8999/")
|
||||
expect.Equal(t, rules[4].Name, "proxy to backend (old /auth)")
|
||||
expect.Equal(t, rules[4].On.String(), `path glob("/auth/*")`)
|
||||
expect.Equal(t, rules[4].Do.String(), "proxy http://localhost:8999/api/v1/")
|
||||
}
|
||||
|
||||
func BenchmarkParser(b *testing.B) {
|
||||
|
||||
48
internal/route/rules/presets/embed.go
Normal file
48
internal/route/rules/presets/embed.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rulepresets
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/internal/route/rules"
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
)
|
||||
|
||||
//go:embed *.yml
|
||||
var fs embed.FS
|
||||
|
||||
var rulePresets = make(map[string]rules.Rules)
|
||||
|
||||
var once sync.Once
|
||||
|
||||
func GetRulePreset(name string) (rules.Rules, bool) {
|
||||
once.Do(initPresets)
|
||||
rules, ok := rulePresets[name]
|
||||
return rules, ok
|
||||
}
|
||||
|
||||
// init all rule presetsl lazily
|
||||
func initPresets() {
|
||||
files, err := fs.ReadDir(".")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to read rule presets")
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
var rules rules.Rules
|
||||
content, err := fs.ReadFile(file.Name())
|
||||
if err != nil {
|
||||
log.Error().Str("name", file.Name()).Err(err).Msg("failed to read rule preset")
|
||||
continue
|
||||
}
|
||||
_, err = serialization.ConvertString(string(content), reflect.ValueOf(&rules))
|
||||
if err != nil {
|
||||
log.Error().Str("name", file.Name()).Err(err).Msg("failed to unmarshal rule preset")
|
||||
continue
|
||||
}
|
||||
rulePresets[file.Name()] = rules
|
||||
log.Debug().Str("name", file.Name()).Msg("loaded rule preset")
|
||||
}
|
||||
}
|
||||
17
internal/route/rules/presets/webui.yml
Normal file
17
internal/route/rules/presets/webui.yml
Normal file
@@ -0,0 +1,17 @@
|
||||
- name: login page
|
||||
on: path /login
|
||||
do: pass
|
||||
- name: protected
|
||||
on: |
|
||||
!path regex("(_next/static|_next/image|favicon.ico).*")
|
||||
!path glob("/api/v1/auth/*")
|
||||
!path /api/v1/version
|
||||
do: require_auth
|
||||
- name: proxy to backend
|
||||
on: path glob("/api/v1/*")
|
||||
do: proxy http://${API_ADDR}/
|
||||
- name: proxy to auth api
|
||||
on: path glob("/auth/*")
|
||||
do: |
|
||||
rewrite /auth /api/v1/auth
|
||||
proxy http://${API_ADDR}/
|
||||
173
internal/route/rules/response_modifier.go
Normal file
173
internal/route/rules/response_modifier.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/synk"
|
||||
)
|
||||
|
||||
type ResponseModifier struct {
|
||||
w http.ResponseWriter
|
||||
b []byte // the bytes got from pool
|
||||
buf *bytes.Buffer
|
||||
statusCode int
|
||||
shared Cache
|
||||
|
||||
hijacked bool
|
||||
|
||||
errs gperr.Builder
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
StatusCode int
|
||||
Header http.Header
|
||||
}
|
||||
|
||||
var pool = synk.GetBytesPoolWithUniqueMemory()
|
||||
|
||||
func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
||||
for {
|
||||
switch ww := w.(type) {
|
||||
case *ResponseModifier:
|
||||
return ww
|
||||
case interface{ Unwrap() http.ResponseWriter }:
|
||||
w = ww.Unwrap()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetInitResponseModifier returns the response modifier for the given response writer.
|
||||
// If the response writer is already wrapped, it will return the wrapped response modifier.
|
||||
// Otherwise, it will return a new response modifier.
|
||||
func GetInitResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
||||
if rm := unwrapResponseModifier(w); rm != nil {
|
||||
return rm
|
||||
}
|
||||
return NewResponseModifier(w)
|
||||
}
|
||||
|
||||
// GetSharedData returns the shared data for the given response writer.
|
||||
// It will initialize the shared data if not initialized.
|
||||
func GetSharedData(w http.ResponseWriter) Cache {
|
||||
rm := GetInitResponseModifier(w)
|
||||
if rm.shared == nil {
|
||||
rm.shared = NewCache()
|
||||
}
|
||||
return rm.shared
|
||||
}
|
||||
|
||||
// NewResponseModifier returns a new response modifier for the given response writer.
|
||||
//
|
||||
// It should only be called once, at the very beginning of the request.
|
||||
func NewResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
||||
b := pool.Get()
|
||||
return &ResponseModifier{
|
||||
w: w,
|
||||
buf: bytes.NewBuffer(b),
|
||||
b: b,
|
||||
}
|
||||
}
|
||||
|
||||
// func (rm *ResponseModifier) Unwrap() http.ResponseWriter {
|
||||
// return rm.w
|
||||
// }
|
||||
|
||||
func (rm *ResponseModifier) WriteHeader(code int) {
|
||||
rm.statusCode = code
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) ResetBody() {
|
||||
rm.buf.Reset()
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) ContentLength() int {
|
||||
return rm.buf.Len()
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) StatusCode() int {
|
||||
if rm.statusCode == 0 {
|
||||
return http.StatusOK
|
||||
}
|
||||
return rm.statusCode
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) Header() http.Header {
|
||||
return rm.w.Header()
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) Response() Response {
|
||||
return Response{StatusCode: rm.StatusCode(), Header: rm.Header()}
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) Write(b []byte) (int, error) {
|
||||
return rm.buf.Write(b)
|
||||
}
|
||||
|
||||
// AppendError appends an error to the response modifier
|
||||
// the error will be formatted as "rule <rule.Name> error: <err>"
|
||||
//
|
||||
// It will be aggregated and returned in FlushRelease.
|
||||
func (rm *ResponseModifier) AppendError(rule Rule, err error) {
|
||||
rm.errs.Addf("rule %q error: %w", rule.Name, err)
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := rm.w.(http.Hijacker); ok {
|
||||
rm.hijacked = true
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return nil, nil, errors.New("hijack not supported")
|
||||
}
|
||||
|
||||
// FlushRelease flushes the response modifier and releases the resources
|
||||
// it returns the number of bytes written and the aggregated error
|
||||
// if there is any error (rule errors or write error), it will be returned
|
||||
func (rm *ResponseModifier) FlushRelease() (int, error) {
|
||||
n := 0
|
||||
if !rm.hijacked {
|
||||
h := rm.w.Header()
|
||||
// for k := range h {
|
||||
// if strings.EqualFold(k, "content-length") {
|
||||
// h.Del(k)
|
||||
// }
|
||||
// }
|
||||
h.Set("Content-Length", strconv.Itoa(rm.buf.Len()))
|
||||
rm.w.WriteHeader(rm.StatusCode())
|
||||
nn, werr := rm.w.Write(rm.buf.Bytes())
|
||||
n += nn
|
||||
if werr != nil {
|
||||
rm.errs.Addf("write error: %w", werr)
|
||||
}
|
||||
|
||||
// flush the response writer
|
||||
if flusher, ok := rm.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
} else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok {
|
||||
ferr := errFlusher.Flush()
|
||||
if ferr != nil {
|
||||
rm.errs.Addf("flush error: %w", ferr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// release the buffer and reset the pointers
|
||||
pool.Put(rm.b)
|
||||
rm.b = nil
|
||||
rm.buf = nil
|
||||
|
||||
// release the shared data
|
||||
if rm.shared != nil {
|
||||
rm.shared.Release()
|
||||
rm.shared = nil
|
||||
}
|
||||
|
||||
return n, rm.errs.Error()
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -46,15 +49,16 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (rule *Rule) IsResponseRule() bool {
|
||||
return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler()
|
||||
}
|
||||
|
||||
// BuildHandler returns a http.HandlerFunc that implements the rules.
|
||||
//
|
||||
// if a bypass rule matches,
|
||||
// the request is passed to the upstream and no more rules are executed.
|
||||
//
|
||||
// if no rule matches, the default rule is executed
|
||||
// if no rule matches and default rule is not set,
|
||||
// the request is passed to the upstream.
|
||||
func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
|
||||
func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
|
||||
if len(rules) == 0 {
|
||||
return up
|
||||
}
|
||||
|
||||
defaultRule := Rule{
|
||||
Name: "default",
|
||||
Do: Command{
|
||||
@@ -63,55 +67,168 @@ func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc {
|
||||
},
|
||||
}
|
||||
|
||||
nonDefaultRules := make(Rules, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
var nonDefaultRules Rules
|
||||
hasDefaultRule := false
|
||||
for i, rule := range rules {
|
||||
if rule.Name == "default" {
|
||||
defaultRule = rule
|
||||
hasDefaultRule = true
|
||||
} else {
|
||||
// set name to index if name is empty
|
||||
if rule.Name == "" {
|
||||
rule.Name = fmt.Sprintf("rule[%d]", i)
|
||||
}
|
||||
nonDefaultRules = append(nonDefaultRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonDefaultRules) == 0 {
|
||||
if defaultRule.Do.isBypass() {
|
||||
return up.ServeHTTP
|
||||
return up
|
||||
}
|
||||
if defaultRule.IsResponseRule() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rm := NewResponseModifier(w)
|
||||
w = rm
|
||||
up(w, r)
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cache := NewCache()
|
||||
defer cache.Release()
|
||||
if defaultRule.Do.exec.Handle(cache, w, r) {
|
||||
up.ServeHTTP(w, r)
|
||||
rm := NewResponseModifier(w)
|
||||
w = rm
|
||||
err := defaultRule.Do.exec.Handle(w, r)
|
||||
if err == nil {
|
||||
up(w, r)
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(nonDefaultRules) == 0 {
|
||||
nonDefaultRules = rules
|
||||
preRules := make(Rules, 0, len(nonDefaultRules))
|
||||
postRules := make(Rules, 0, len(nonDefaultRules))
|
||||
for _, rule := range nonDefaultRules {
|
||||
if rule.IsResponseRule() {
|
||||
postRules = append(postRules, rule)
|
||||
} else {
|
||||
preRules = append(preRules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule()
|
||||
defaultTerminates := isTerminatingHandler(defaultRule.Do.exec)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cache := NewCache()
|
||||
defer cache.Release()
|
||||
rm := NewResponseModifier(w)
|
||||
defer func() {
|
||||
if _, err := rm.FlushRelease(); err != nil {
|
||||
gperr.LogError("error executing rules", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, rule := range nonDefaultRules {
|
||||
if rule.Check(cache, r) {
|
||||
if rule.Do.isBypass() {
|
||||
up.ServeHTTP(w, r)
|
||||
return
|
||||
w = rm
|
||||
|
||||
shouldCallUpstream := true
|
||||
preMatched := false
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && !defaultTerminates {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(defaultRule, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
}
|
||||
if !rule.Handle(cache, w, r) {
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
for _, rule := range preRules {
|
||||
if rule.Check(w, r) {
|
||||
preMatched = true
|
||||
if rule.Do.isBypass() {
|
||||
break // post rules should still execute
|
||||
}
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(rule, err)
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched {
|
||||
if defaultRule.Do.isBypass() {
|
||||
// continue to upstream
|
||||
} else {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(defaultRule, err)
|
||||
return
|
||||
}
|
||||
shouldCallUpstream = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldCallUpstream {
|
||||
up(w, r)
|
||||
}
|
||||
|
||||
// if no post rules, we are done here
|
||||
if len(postRules) == 0 && !isDefaultRulePost {
|
||||
return
|
||||
}
|
||||
|
||||
for _, rule := range postRules {
|
||||
if rule.Check(w, r) {
|
||||
err := rule.Handle(w, r)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(rule, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// bypass or proceed
|
||||
if defaultRule.Do.isBypass() || defaultRule.Handle(cache, w, r) {
|
||||
up.ServeHTTP(w, r)
|
||||
if isDefaultRulePost {
|
||||
err := defaultRule.Handle(w, r)
|
||||
if err != nil && !errors.Is(err, errTerminated) {
|
||||
rm.AppendError(defaultRule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isTerminatingHandler(handler CommandHandler) bool {
|
||||
switch h := handler.(type) {
|
||||
case TerminatingCommand:
|
||||
return true
|
||||
case Commands:
|
||||
if len(h) == 0 {
|
||||
return false
|
||||
}
|
||||
return isTerminatingHandler(h[len(h)-1])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (rules Rules) MarshalJSON() ([]byte, error) {
|
||||
names := make([]string, len(rules))
|
||||
for i, rule := range rules {
|
||||
@@ -124,11 +241,14 @@ func (rule *Rule) String() string {
|
||||
return rule.Name
|
||||
}
|
||||
|
||||
func (rule *Rule) Check(cached Cache, r *http.Request) bool {
|
||||
return rule.On.checker.Check(cached, r)
|
||||
func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
|
||||
if rule.On.checker == nil {
|
||||
return true
|
||||
}
|
||||
v := rule.On.checker.Check(w, r)
|
||||
return v
|
||||
}
|
||||
|
||||
func (rule *Rule) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
proceed = rule.Do.exec.Handle(cached, w, r)
|
||||
return proceed
|
||||
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
|
||||
return rule.Do.exec.Handle(w, r)
|
||||
}
|
||||
|
||||
74
internal/route/rules/rules_bench_test.go
Normal file
74
internal/route/rules/rules_bench_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkRules(b *testing.B) {
|
||||
var rules Rules
|
||||
err := parseRules(`
|
||||
- name: admin-api
|
||||
on: |
|
||||
path glob(/api/admin/*)
|
||||
header Authorization
|
||||
method POST
|
||||
do: |
|
||||
set resp_header X-Access-Level "admin"
|
||||
set resp_header X-API-Version "v1"
|
||||
- name: user-api
|
||||
on: |
|
||||
path glob(/api/users/*) & method GET
|
||||
do: |
|
||||
set resp_header X-Access-Level "user"
|
||||
set resp_header X-API-Version "v1"
|
||||
- name: public-api
|
||||
on: |
|
||||
path glob(/api/public/*) & method GET
|
||||
do: |
|
||||
set resp_header X-Access-Level "public"
|
||||
`, &rules)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
b.Run("BuildHandler", func(b *testing.B) {
|
||||
for b.Loop() {
|
||||
rules.BuildHandler(upstream)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("RunHandler", func(b *testing.B) {
|
||||
var r = &http.Request{
|
||||
Body: io.NopCloser(bytes.NewReader([]byte(""))),
|
||||
URL: &url.URL{Path: "/api/users/"},
|
||||
}
|
||||
var w noopResponseWriter
|
||||
handler := rules.BuildHandler(upstream)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type noopResponseWriter struct {
|
||||
}
|
||||
|
||||
func (w noopResponseWriter) Header() http.Header {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func (w noopResponseWriter) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (w noopResponseWriter) WriteHeader(int) {
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/godoxy/internal/serialization"
|
||||
expect "github.com/yusing/goutils/testing"
|
||||
)
|
||||
|
||||
func TestParseRule(t *testing.T) {
|
||||
test := []map[string]any{
|
||||
{
|
||||
"name": "test",
|
||||
"on": "method POST",
|
||||
"do": "error 403 Forbidden",
|
||||
},
|
||||
{
|
||||
"name": "auth",
|
||||
"on": `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`,
|
||||
"do": "bypass",
|
||||
},
|
||||
{
|
||||
"name": "default",
|
||||
"do": "require_basic_auth any_realm",
|
||||
},
|
||||
}
|
||||
|
||||
var rules struct {
|
||||
Rules Rules
|
||||
}
|
||||
err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules)
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, len(rules.Rules), len(test))
|
||||
expect.Equal(t, rules.Rules[0].Name, "test")
|
||||
expect.Equal(t, rules.Rules[0].On.String(), "method POST")
|
||||
expect.Equal(t, rules.Rules[0].Do.String(), "error 403 Forbidden")
|
||||
|
||||
expect.Equal(t, rules.Rules[1].Name, "auth")
|
||||
expect.Equal(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`)
|
||||
expect.Equal(t, rules.Rules[1].Do.String(), "bypass")
|
||||
|
||||
expect.Equal(t, rules.Rules[2].Name, "default")
|
||||
expect.Equal(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm")
|
||||
}
|
||||
|
||||
// TODO: real tests.
|
||||
43
internal/route/rules/template.go
Normal file
43
internal/route/rules/template.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package rules
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type templateOrStr interface {
|
||||
Execute(w io.Writer, data any) error
|
||||
}
|
||||
|
||||
type strTemplate string
|
||||
|
||||
func (t strTemplate) Execute(w io.Writer, _ any) error {
|
||||
n, err := w.Write([]byte(t))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != len(t) {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type keyValueTemplate = Tuple[string, templateOrStr]
|
||||
|
||||
func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, reqResponseTemplateData{Request: r})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func executeRequestTemplateTo(tmpl templateOrStr, o io.Writer, r *http.Request) error {
|
||||
return tmpl.Execute(o, reqResponseTemplateData{Request: r})
|
||||
}
|
||||
|
||||
func executeReqRespTemplateTo(tmpl templateOrStr, o io.Writer, w http.ResponseWriter, r *http.Request) error {
|
||||
return tmpl.Execute(o, reqResponseTemplateData{Request: r, Response: GetInitResponseModifier(w).Response()})
|
||||
}
|
||||
@@ -6,10 +6,11 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"github.com/rs/zerolog"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
@@ -21,6 +22,17 @@ type (
|
||||
First T1
|
||||
Second T2
|
||||
}
|
||||
Tuple3[T1, T2, T3 any] struct {
|
||||
First T1
|
||||
Second T2
|
||||
Third T3
|
||||
}
|
||||
Tuple4[T1, T2, T3, T4 any] struct {
|
||||
First T1
|
||||
Second T2
|
||||
Third T3
|
||||
Fourth T4
|
||||
}
|
||||
StrTuple = Tuple[string, string]
|
||||
IntTuple = Tuple[int, int]
|
||||
MapValueMatcher = Tuple[string, Matcher]
|
||||
@@ -30,97 +42,24 @@ func (t *Tuple[T1, T2]) Unpack() (T1, T2) {
|
||||
return t.First, t.Second
|
||||
}
|
||||
|
||||
func (t *Tuple3[T1, T2, T3]) Unpack() (T1, T2, T3) {
|
||||
return t.First, t.Second, t.Third
|
||||
}
|
||||
|
||||
func (t *Tuple4[T1, T2, T3, T4]) Unpack() (T1, T2, T3, T4) {
|
||||
return t.First, t.Second, t.Third, t.Fourth
|
||||
}
|
||||
|
||||
func (t *Tuple[T1, T2]) String() string {
|
||||
return fmt.Sprintf("%v:%v", t.First, t.Second)
|
||||
}
|
||||
|
||||
type (
|
||||
Matcher func(string) bool
|
||||
MatcherType string
|
||||
)
|
||||
|
||||
const (
|
||||
MatcherTypeString MatcherType = "string"
|
||||
MatcherTypeGlob MatcherType = "glob"
|
||||
MatcherTypeRegex MatcherType = "regex"
|
||||
)
|
||||
|
||||
func unquoteExpr(s string) (string, gperr.Error) {
|
||||
if s == "" {
|
||||
return "", nil
|
||||
}
|
||||
switch s[0] {
|
||||
case '"', '\'', '`':
|
||||
if s[0] != s[len(s)-1] {
|
||||
return "", ErrUnterminatedQuotes
|
||||
}
|
||||
return s[1 : len(s)-1], nil
|
||||
default:
|
||||
return s, nil
|
||||
}
|
||||
func (t *Tuple3[T1, T2, T3]) String() string {
|
||||
return fmt.Sprintf("%v:%v:%v", t.First, t.Second, t.Third)
|
||||
}
|
||||
|
||||
func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) {
|
||||
idx := strings.IndexByte(s, '(')
|
||||
if idx == -1 {
|
||||
return MatcherTypeString, s, nil
|
||||
}
|
||||
idxEnd := strings.LastIndexByte(s, ')')
|
||||
if idxEnd == -1 {
|
||||
return "", "", ErrUnterminatedBrackets
|
||||
}
|
||||
|
||||
expr, err = unquoteExpr(s[idx+1 : idxEnd])
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
matcherType = MatcherType(strings.ToLower(s[:idx]))
|
||||
|
||||
switch matcherType {
|
||||
case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString:
|
||||
return
|
||||
default:
|
||||
return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType)
|
||||
}
|
||||
}
|
||||
|
||||
func ParseMatcher(expr string) (Matcher, gperr.Error) {
|
||||
t, expr, err := ExtractExpr(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch t {
|
||||
case MatcherTypeString:
|
||||
return StringMatcher(expr)
|
||||
case MatcherTypeGlob:
|
||||
return GlobMatcher(expr)
|
||||
case MatcherTypeRegex:
|
||||
return RegexMatcher(expr)
|
||||
}
|
||||
// won't reach here
|
||||
return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t)
|
||||
}
|
||||
|
||||
func StringMatcher(s string) (Matcher, gperr.Error) {
|
||||
return func(s2 string) bool {
|
||||
return s == s2
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GlobMatcher(expr string) (Matcher, gperr.Error) {
|
||||
g, err := glob.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return g.Match, nil
|
||||
}
|
||||
|
||||
func RegexMatcher(expr string) (Matcher, gperr.Error) {
|
||||
re, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return re.MatchString, nil
|
||||
func (t *Tuple4[T1, T2, T3, T4]) String() string {
|
||||
return fmt.Sprintf("%v:%v:%v:%v", t.First, t.Second, t.Third, t.Fourth)
|
||||
}
|
||||
|
||||
// validateSingleMatcher returns Matcher with the matcher validated.
|
||||
@@ -131,14 +70,6 @@ func validateSingleMatcher(args []string) (any, gperr.Error) {
|
||||
return ParseMatcher(args[0])
|
||||
}
|
||||
|
||||
// toStrTuple returns *StrTuple.
|
||||
func toStrTuple(args []string) (any, gperr.Error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
return &StrTuple{args[0], args[1]}, nil
|
||||
}
|
||||
|
||||
// toKVOptionalVMatcher returns *MapValueMatcher that value is optional.
|
||||
func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
switch len(args) {
|
||||
@@ -155,6 +86,18 @@ func toKVOptionalVMatcher(args []string) (any, gperr.Error) {
|
||||
}
|
||||
}
|
||||
|
||||
func toKeyValueTemplate(args []string) (any, gperr.Error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
|
||||
tmpl, err := validateTemplate(args[1], false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &keyValueTemplate{args[0], tmpl}, nil
|
||||
}
|
||||
|
||||
// validateURL returns types.URL with the URL validated.
|
||||
func validateURL(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
@@ -164,6 +107,12 @@ func validateURL(args []string) (any, gperr.Error) {
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
// expect relative URL, must starts with /
|
||||
if !strings.HasPrefix(u.Path, "/") {
|
||||
return nil, ErrInvalidArguments.Withf("relative URL must starts with /")
|
||||
}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
@@ -250,6 +199,57 @@ func validateMethod(args []string) (any, gperr.Error) {
|
||||
return method, nil
|
||||
}
|
||||
|
||||
func validateStatusCode(status string) (int, error) {
|
||||
statusCode, err := strconv.Atoi(status)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if statusCode < 100 || statusCode > 599 {
|
||||
return 0, fmt.Errorf("status code out of range: %s", status)
|
||||
}
|
||||
return statusCode, nil
|
||||
}
|
||||
|
||||
// validateStatusRange returns Tuple[int, int] with the status range validated.
|
||||
// accepted formats are:
|
||||
// - <status>
|
||||
// - <status>-<status>
|
||||
// - 1xx
|
||||
// - 2xx
|
||||
// - 3xx
|
||||
// - 4xx
|
||||
// - 5xx
|
||||
func validateStatusRange(args []string) (any, gperr.Error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrExpectOneArg
|
||||
}
|
||||
|
||||
beg, end, ok := strings.Cut(args[0], "-")
|
||||
if !ok { // <status>
|
||||
end = beg
|
||||
}
|
||||
|
||||
switch beg {
|
||||
case "1xx":
|
||||
return &IntTuple{100, 199}, nil
|
||||
case "2xx":
|
||||
return &IntTuple{200, 299}, nil
|
||||
case "3xx":
|
||||
return &IntTuple{300, 399}, nil
|
||||
case "4xx":
|
||||
return &IntTuple{400, 499}, nil
|
||||
case "5xx":
|
||||
return &IntTuple{500, 599}, nil
|
||||
}
|
||||
|
||||
begInt, begErr := validateStatusCode(beg)
|
||||
endInt, endErr := validateStatusCode(end)
|
||||
if begErr != nil || endErr != nil {
|
||||
return nil, ErrInvalidArguments.With(gperr.Join(begErr, endErr))
|
||||
}
|
||||
return &IntTuple{begInt, endInt}, nil
|
||||
}
|
||||
|
||||
// validateUserBCryptPassword returns *HashedCrendential with the password validated.
|
||||
func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
if len(args) != 2 {
|
||||
@@ -260,20 +260,77 @@ func validateUserBCryptPassword(args []string) (any, gperr.Error) {
|
||||
|
||||
// validateModField returns CommandHandler with the field validated.
|
||||
func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.Error) {
|
||||
if len(args) == 0 {
|
||||
return nil, ErrExpectTwoOrThreeArgs
|
||||
}
|
||||
setField, ok := modFields[args[0]]
|
||||
if !ok {
|
||||
return nil, ErrInvalidSetTarget.Subject(args[0])
|
||||
return nil, ErrUnknownModField.Subject(args[0])
|
||||
}
|
||||
if mod == ModFieldRemove {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrExpectTwoArgs
|
||||
}
|
||||
// setField expect validateStrTuple
|
||||
args = append(args, "")
|
||||
}
|
||||
validArgs, err := setField.validate(args[1:])
|
||||
if err != nil {
|
||||
return nil, err.Withf(setField.help.String())
|
||||
return nil, err.With(setField.help.Error())
|
||||
}
|
||||
modder := setField.builder(validArgs)
|
||||
switch mod {
|
||||
case ModFieldAdd:
|
||||
return modder.add, nil
|
||||
add := modder.add
|
||||
if add == nil {
|
||||
return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod)
|
||||
}
|
||||
return add, nil
|
||||
case ModFieldRemove:
|
||||
return modder.remove, nil
|
||||
remove := modder.remove
|
||||
if remove == nil {
|
||||
return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod)
|
||||
}
|
||||
return remove, nil
|
||||
}
|
||||
return modder.set, nil
|
||||
set := modder.set
|
||||
if set == nil {
|
||||
return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod)
|
||||
}
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func isTemplate(tmplStr string) bool {
|
||||
return strings.Contains(tmplStr, "{{")
|
||||
}
|
||||
|
||||
func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) {
|
||||
if newline && !strings.HasSuffix(tmplStr, "\n") {
|
||||
tmplStr += "\n"
|
||||
}
|
||||
|
||||
if !isTemplate(tmplStr) {
|
||||
return strTemplate(tmplStr), nil
|
||||
}
|
||||
|
||||
tmpl, err := template.New("template").Parse(tmplStr)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
func validateLevel(level string) (zerolog.Level, gperr.Error) {
|
||||
l, err := zerolog.ParseLevel(level)
|
||||
if err != nil {
|
||||
return zerolog.NoLevel, ErrInvalidArguments.With(err)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// func validateNotifProvider(provider string) gperr.Error {
|
||||
// if !notif.HasProvider(provider) {
|
||||
// return ErrInvalidArguments.Subject(provider)
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
Reference in New Issue
Block a user