diff --git a/goutils b/goutils index 26146bd5..e78e3c2d 160000 --- a/goutils +++ b/goutils @@ -1 +1 @@ -Subproject commit 26146bd560f1ce384ae568088e36e2217e2de02b +Subproject commit e78e3c2d35afc8173273d5075412d15b17889794 diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 7935fcc2..876f67ba 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -58,3 +58,13 @@ func AuthCheckHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } } + +func AuthOrProceed(w http.ResponseWriter, r *http.Request) (proceed bool) { + err := defaultAuth.CheckToken(r) + if err != nil { + defaultAuth.LoginHandler(w, r) + return false + } else { + return true + } +} diff --git a/internal/logging/accesslog/config.go b/internal/logging/accesslog/config.go index 20acfa9b..e90df847 100644 --- a/internal/logging/accesslog/config.go +++ b/internal/logging/accesslog/config.go @@ -69,7 +69,7 @@ func (cfg *ConfigBase) Validate() gperr.Error { // If only stdout is enabled, it returns nil, nil. func (cfg *ConfigBase) IO() (WriterWithName, error) { if cfg.Path != "" { - io, err := newFileIO(cfg.Path) + io, err := NewFileIO(cfg.Path) if err != nil { return nil, err } diff --git a/internal/logging/accesslog/file_logger.go b/internal/logging/accesslog/file_logger.go index 418e8f41..355a013f 100644 --- a/internal/logging/accesslog/file_logger.go +++ b/internal/logging/accesslog/file_logger.go @@ -26,7 +26,10 @@ var ( openedFilesMu sync.Mutex ) -func newFileIO(path string) (WriterWithName, error) { +// NewFileIO creates a new file writer with cleaned path. +// +// If the file is already opened, it will be returned. +func NewFileIO(path string) (WriterWithName, error) { openedFilesMu.Lock() defer openedFilesMu.Unlock() diff --git a/internal/logging/accesslog/file_logger_test.go b/internal/logging/accesslog/file_logger_test.go index 3d4d9781..a54df5c8 100644 --- a/internal/logging/accesslog/file_logger_test.go +++ b/internal/logging/accesslog/file_logger_test.go @@ -31,7 +31,7 @@ func TestConcurrentFileLoggersShareSameAccessLogIO(t *testing.T) { wg.Add(1) go func(index int) { defer wg.Done() - file, err := newFileIO(cfg.Path) + file, err := NewFileIO(cfg.Path) expect.NoError(t, err) accessLogIOs[index] = file }(i) diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 284c4c4d..1a9fa85f 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -61,6 +61,26 @@ func NewLogger(out ...io.Writer) zerolog.Logger { ).Level(level).With().Timestamp().Logger() } +func NewLoggerWithFixedLevel(level zerolog.Level, out ...io.Writer) zerolog.Logger { + levelStr := level.String() + writer := zerolog.ConsoleWriter{ + Out: zerolog.MultiLevelWriter(out...), + TimeFormat: timeFmt, + FormatMessage: func(msgI interface{}) string { // pad spaces for each line + if msgI == nil { + return "" + } + return fmtMessage(msgI.(string)) + }, + FormatLevel: func(_ any) string { + return levelStr + }, + } + return zerolog.New( + writer, + ).Level(level).With().Timestamp().Logger() +} + func InitLogger(out ...io.Writer) { logger = NewLogger(out...) log.SetOutput(logger) diff --git a/internal/net/gphttp/middleware/bypass.go b/internal/net/gphttp/middleware/bypass.go index 5858d05b..04a88ce3 100644 --- a/internal/net/gphttp/middleware/bypass.go +++ b/internal/net/gphttp/middleware/bypass.go @@ -8,11 +8,9 @@ import ( type Bypass []rules.RuleOn -func (b Bypass) ShouldBypass(r *http.Request) bool { - cached := rules.NewCache() - defer cached.Release() +func (b Bypass) ShouldBypass(w http.ResponseWriter, r *http.Request) bool { for _, rule := range b { - if rule.Check(cached, r) { + if rule.Check(w, r) { return true } } @@ -26,14 +24,14 @@ type checkBypass struct { } func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNext bool) { - if c.modReq == nil || c.bypass.ShouldBypass(r) { + if c.modReq == nil || c.bypass.ShouldBypass(w, r) { return true } return c.modReq.before(w, r) } -func (c *checkBypass) modifyResponse(resp *http.Response) error { - if c.modRes == nil || c.bypass.ShouldBypass(resp.Request) { +func (c *checkBypass) modifyResponse(w http.ResponseWriter, resp *http.Response) error { + if c.modRes == nil || c.bypass.ShouldBypass(w, resp.Request) { return nil } return c.modRes.modifyResponse(resp) diff --git a/internal/notif/body.go b/internal/notif/body.go index e68b9d8b..d87bddbe 100644 --- a/internal/notif/body.go +++ b/internal/notif/body.go @@ -20,10 +20,11 @@ type ( ) type ( - FieldsBody []LogField - ListBody []string - MessageBody string - errorBody struct { + FieldsBody []LogField + ListBody []string + MessageBody string + MessageBodyBytes []byte + errorBody struct { Error error } ) @@ -98,7 +99,15 @@ func (m MessageBody) Format(format LogFormat) ([]byte, error) { case LogFormatRawJSON: return sonic.Marshal(m) } - return m.Format(LogFormatMarkdown) + return []byte(m), nil +} + +func (m MessageBodyBytes) Format(format LogFormat) ([]byte, error) { + switch format { + case LogFormatRawJSON: + return sonic.Marshal(string(m)) + } + return m, nil } func (e errorBody) Format(format LogFormat) ([]byte, error) { diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index 60423072..601e9a9c 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -128,7 +128,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { } if len(r.Rules) > 0 { - r.handler = r.Rules.BuildHandler(r.handler) + r.handler = r.Rules.BuildHandler(r.handler.ServeHTTP) } if r.HealthMon != nil { diff --git a/internal/route/route.go b/internal/route/route.go index ee8c8bdd..9c2e283d 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -2,7 +2,11 @@ package route import ( "context" + "errors" "fmt" + "net/url" + "os" + "reflect" "runtime" "strings" "sync" @@ -17,6 +21,7 @@ import ( netutils "github.com/yusing/godoxy/internal/net" nettypes "github.com/yusing/godoxy/internal/net/types" "github.com/yusing/godoxy/internal/proxmox" + "github.com/yusing/godoxy/internal/serialization" "github.com/yusing/godoxy/internal/types" gperr "github.com/yusing/goutils/errs" strutils "github.com/yusing/goutils/strings" @@ -25,6 +30,7 @@ import ( "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/logging/accesslog" "github.com/yusing/godoxy/internal/route/rules" + rulepresets "github.com/yusing/godoxy/internal/route/rules/presets" route "github.com/yusing/godoxy/internal/route/types" "github.com/yusing/godoxy/internal/utils" ) @@ -41,7 +47,8 @@ type ( route.HTTPConfig PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"` - Rules rules.Rules `json:"rules,omitempty" validate:"omitempty,unique=Name" extension:"x-nullable"` + Rules rules.Rules `json:"rules,omitempty" extension:"x-nullable"` + RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"` HealthCheck *types.HealthCheckConfig `json:"healthcheck"` LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"` Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"` @@ -212,7 +219,10 @@ func (r *Route) Validate() gperr.Error { } } - errs := gperr.NewBuilder("entry validation failed") + var errs gperr.Builder + if err := r.validateRules(); err != nil { + errs.Add(err) + } var impl types.Route var err gperr.Error @@ -267,6 +277,39 @@ func (r *Route) Validate() gperr.Error { return nil } +func (r *Route) validateRules() error { + if r.RuleFile != "" && len(r.Rules) > 0 { + return errors.New("`rule_file` and `rules` cannot be used together") + } else if r.RuleFile != "" { + src, err := url.Parse(r.RuleFile) + if err != nil { + return fmt.Errorf("failed to parse rule file url %q: %w", r.RuleFile, err) + } + switch src.Scheme { + case "embed": // embed:// + rules, ok := rulepresets.GetRulePreset(src.Host) + if !ok { + return fmt.Errorf("rule preset %q not found", src.Host) + } else { + r.Rules = rules + } + case "file", "": + content, err := os.ReadFile(src.Path) + if err != nil { + return fmt.Errorf("failed to read rule file %q: %w", src.Path, err) + } else { + _, err = serialization.ConvertString(string(content), reflect.ValueOf(&r.Rules)) + if err != nil { + return fmt.Errorf("failed to unmarshal rule file %q: %w", src.Path, err) + } + } + default: + return fmt.Errorf("unsupported rule file scheme %q", src.Scheme) + } + } + return nil +} + func (r *Route) Impl() types.Route { return r.impl } diff --git a/internal/route/routes/context.go b/internal/route/routes/context.go index a5bead0a..0bb60dc8 100644 --- a/internal/route/routes/context.go +++ b/internal/route/routes/context.go @@ -86,6 +86,13 @@ func TryGetUpstreamPort(r *http.Request) string { return "" } +func TryGetUpstreamHostPort(r *http.Request) string { + if u := tryGetURL(r); u != nil { + return u.Host + } + return "" +} + func TryGetUpstreamAddr(r *http.Request) string { if u := tryGetURL(r); u != nil { return u.Host diff --git a/internal/route/rules/cache.go b/internal/route/rules/cache.go index 711e19ba..afb8ac3b 100644 --- a/internal/route/rules/cache.go +++ b/internal/route/rules/cache.go @@ -15,13 +15,13 @@ type ( ) const ( - CacheKeyQueries = "queries" - CacheKeyCookies = "cookies" - CacheKeyRemoteIP = "remote_ip" - CacheKeyBasicAuth = "basic_auth" + cacheKeyQueries = "queries" + cacheKeyCookies = "cookies" + cacheKeyRemoteIP = "remote_ip" + cacheKeyBasicAuth = "basic_auth" ) -var cachePool = &sync.Pool{ +var cachePool = sync.Pool{ New: func() any { return make(Cache) }, @@ -41,10 +41,10 @@ func (c Cache) Release() { // GetQueries returns the queries. // If r does not have queries, an empty map is returned. func (c Cache) GetQueries(r *http.Request) url.Values { - v, ok := c[CacheKeyQueries] + v, ok := c[cacheKeyQueries] if !ok { v = r.URL.Query() - c[CacheKeyQueries] = v + c[cacheKeyQueries] = v } return v.(url.Values) } @@ -58,17 +58,17 @@ func (c Cache) UpdateQueries(r *http.Request, update func(url.Values)) { // GetCookies returns the cookies. // If r does not have cookies, an empty slice is returned. func (c Cache) GetCookies(r *http.Request) []*http.Cookie { - v, ok := c[CacheKeyCookies] + v, ok := c[cacheKeyCookies] if !ok { v = r.Cookies() - c[CacheKeyCookies] = v + c[cacheKeyCookies] = v } return v.([]*http.Cookie) } func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) { cookies := update(c.GetCookies(r)) - c[CacheKeyCookies] = cookies + c[cacheKeyCookies] = cookies r.Header.Del("Cookie") for _, cookie := range cookies { r.AddCookie(cookie) @@ -78,14 +78,14 @@ func (c Cache) UpdateCookies(r *http.Request, update UpdateFunc[[]*http.Cookie]) // GetRemoteIP returns the remote ip address. // If r.RemoteAddr is not a valid ip address, nil is returned. func (c Cache) GetRemoteIP(r *http.Request) net.IP { - v, ok := c[CacheKeyRemoteIP] + v, ok := c[cacheKeyRemoteIP] if !ok { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { host = r.RemoteAddr } v = net.ParseIP(host) - c[CacheKeyRemoteIP] = v + c[cacheKeyRemoteIP] = v } return v.(net.IP) } @@ -93,14 +93,14 @@ func (c Cache) GetRemoteIP(r *http.Request) net.IP { // GetBasicAuth returns *Credentials the basic auth username and password. // If r does not have basic auth, nil is returned. func (c Cache) GetBasicAuth(r *http.Request) *Credentials { - v, ok := c[CacheKeyBasicAuth] + v, ok := c[cacheKeyBasicAuth] if !ok { u, p, ok := r.BasicAuth() if ok { v = &Credentials{u, []byte(p)} - c[CacheKeyBasicAuth] = v + c[cacheKeyBasicAuth] = v } else { - c[CacheKeyBasicAuth] = nil + c[cacheKeyBasicAuth] = nil return nil } } diff --git a/internal/route/rules/check_on.go b/internal/route/rules/check_on.go index 389c9aa1..750a6211 100644 --- a/internal/route/rules/check_on.go +++ b/internal/route/rules/check_on.go @@ -3,30 +3,30 @@ package rules import "net/http" type ( - CheckFunc func(cached Cache, r *http.Request) bool + CheckFunc func(w http.ResponseWriter, r *http.Request) bool Checker interface { - Check(cached Cache, r *http.Request) bool + Check(w http.ResponseWriter, r *http.Request) bool } CheckMatchSingle []Checker CheckMatchAll []Checker ) -func (checker CheckFunc) Check(cached Cache, r *http.Request) bool { - return checker(cached, r) +func (checker CheckFunc) Check(w http.ResponseWriter, r *http.Request) bool { + return checker(w, r) } -func (checkers CheckMatchSingle) Check(cached Cache, r *http.Request) bool { +func (checkers CheckMatchSingle) Check(w http.ResponseWriter, r *http.Request) bool { for _, check := range checkers { - if check.Check(cached, r) { + if check.Check(w, r) { return true } } return false } -func (checkers CheckMatchAll) Check(cached Cache, r *http.Request) bool { +func (checkers CheckMatchAll) Check(w http.ResponseWriter, r *http.Request) bool { for _, check := range checkers { - if !check.Check(cached, r) { + if !check.Check(w, r) { return false } } diff --git a/internal/route/rules/command.go b/internal/route/rules/command.go index 2b58fdbc..a856edaa 100644 --- a/internal/route/rules/command.go +++ b/internal/route/rules/command.go @@ -3,19 +3,21 @@ package rules import "net/http" type ( + handlerFunc func(w http.ResponseWriter, r *http.Request) error + CommandHandler interface { // CommandHandler can read and modify the values // then handle the request // finally proceed to next command (or return) base on situation - Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) + Handle(w http.ResponseWriter, r *http.Request) error + IsResponseHandler() bool } // NonTerminatingCommand will run then proceed to next command or reverse proxy. - NonTerminatingCommand http.HandlerFunc + NonTerminatingCommand handlerFunc // TerminatingCommand will run then return immediately. - TerminatingCommand http.HandlerFunc - // DynamicCommand will return base on the request - // and can read or modify the values. - DynamicCommand func(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) + TerminatingCommand handlerFunc + // OnResponseCommand will run then return based on the response. + OnResponseCommand handlerFunc // BypassCommand will skip all the following commands // and directly return to reverse proxy. BypassCommand struct{} @@ -23,29 +25,55 @@ type ( Commands []CommandHandler ) -func (c NonTerminatingCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { - c(w, r) - return true +func (c NonTerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error { + return c(w, r) } -func (c TerminatingCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { - c(w, r) +func (c NonTerminatingCommand) IsResponseHandler() bool { return false } -func (c DynamicCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { - return c(cached, w, r) +func (c TerminatingCommand) Handle(w http.ResponseWriter, r *http.Request) error { + if err := c(w, r); err != nil { + return err + } + return errTerminated } -func (c BypassCommand) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { +func (c TerminatingCommand) IsResponseHandler() bool { + return false +} + +func (c OnResponseCommand) Handle(w http.ResponseWriter, r *http.Request) error { + return c(w, r) +} + +func (c OnResponseCommand) IsResponseHandler() bool { return true } -func (c Commands) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { +func (c BypassCommand) Handle(w http.ResponseWriter, r *http.Request) error { + return errTerminated +} + +func (c BypassCommand) IsResponseHandler() bool { + return false +} + +func (c Commands) Handle(w http.ResponseWriter, r *http.Request) error { for _, cmd := range c { - if !cmd.Handle(cached, w, r) { - return false + if err := cmd.Handle(w, r); err != nil { + return err } } - return true + return nil +} + +func (c Commands) IsResponseHandler() bool { + for _, cmd := range c { + if cmd.IsResponseHandler() { + return true + } + } + return false } diff --git a/internal/route/rules/do.go b/internal/route/rules/do.go index 034460b8..c86b07e5 100644 --- a/internal/route/rules/do.go +++ b/internal/route/rules/do.go @@ -1,27 +1,41 @@ package rules import ( + "bytes" + "fmt" + "io" "net/http" "path" "strconv" "strings" + "github.com/rs/zerolog" + "github.com/yusing/godoxy/internal/auth" + "github.com/yusing/godoxy/internal/logging" gphttp "github.com/yusing/godoxy/internal/net/gphttp" nettypes "github.com/yusing/godoxy/internal/net/types" + "github.com/yusing/godoxy/internal/notif" + "github.com/yusing/godoxy/internal/route/routes" gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" "github.com/yusing/goutils/http/reverseproxy" - strutils "github.com/yusing/goutils/strings" + "github.com/yusing/goutils/synk" ) type ( Command struct { - raw string - exec CommandHandler + raw string + exec CommandHandler + isResponseHandler bool } ) +func (cmd *Command) IsResponseHandler() bool { + return cmd.isResponseHandler +} + const ( + CommandRequireAuth = "require_auth" CommandRewrite = "rewrite" CommandServe = "serve" CommandProxy = "proxy" @@ -31,18 +45,46 @@ const ( CommandSet = "set" CommandAdd = "add" CommandRemove = "remove" + CommandLog = "log" + CommandNotify = "notify" CommandPass = "pass" CommandPassAlt = "bypass" ) var commands = map[string]struct { - help Help - validate ValidateFunc - build func(args any) CommandHandler + help Help + validate ValidateFunc + build func(args any) CommandHandler + isResponseHandler bool }{ + CommandRequireAuth: { + help: Help{ + command: CommandRequireAuth, + description: makeLines("Require HTTP authentication for incoming requests"), + args: map[string]string{}, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 0 { + return nil, ErrExpectNoArg + } + return nil, nil + }, + build: func(args any) CommandHandler { + return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + if !auth.AuthOrProceed(w, r) { + return errTerminated + } + return nil + }) + }, + }, CommandRewrite: { help: Help{ command: CommandRewrite, + description: makeLines( + "Rewrite a request path from one prefix to another, e.g.:", + helpExample(CommandRewrite, "/foo", "/bar"), + ), args: map[string]string{ "from": "the path to rewrite, must start with /", "to": "the path to rewrite to, must start with /", @@ -67,24 +109,29 @@ var commands = map[string]struct { }, build: func(args any) CommandHandler { orig, repl := args.(*StrTuple).Unpack() - return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) { + return NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { path := r.URL.Path if len(path) > 0 && path[0] != '/' { path = "/" + path } if !strings.HasPrefix(path, orig) { - return + return nil } path = repl + path[len(orig):] r.URL.Path = path r.URL.RawPath = r.URL.EscapedPath() r.RequestURI = r.URL.RequestURI() + return nil }) }, }, CommandServe: { help: Help{ command: CommandServe, + description: makeLines( + "Serve static files from a local file system path, e.g.:", + helpExample(CommandServe, "/var/www"), + ), args: map[string]string{ "root": "the file system path to serve, must be an existing directory", }, @@ -92,14 +139,19 @@ var commands = map[string]struct { validate: validateFSPath, build: func(args any) CommandHandler { root := args.(string) - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) { + return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { http.ServeFile(w, r, path.Join(root, path.Clean(r.URL.Path))) + return nil }) }, }, CommandRedirect: { help: Help{ command: CommandRedirect, + description: makeLines( + "Redirect request to another URL, e.g.:", + helpExample(CommandRedirect, "https://example.com"), + ), args: map[string]string{ "to": "the url to redirect to, can be relative or absolute URL", }, @@ -107,14 +159,19 @@ var commands = map[string]struct { validate: validateURL, build: func(args any) CommandHandler { target := args.(*nettypes.URL).String() - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) { + return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { http.Redirect(w, r, target, http.StatusTemporaryRedirect) + return nil }) }, }, CommandError: { help: Help{ command: CommandError, + description: makeLines( + "Send an HTTP error response and terminate processing, e.g.:", + helpExample(CommandError, "400", "bad request"), + ), args: map[string]string{ "code": "the http status code to return", "text": "the error message to return", @@ -136,14 +193,21 @@ var commands = map[string]struct { }, build: func(args any) CommandHandler { code, text := args.(*Tuple[int, string]).Unpack() - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) { + return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + // error command should overwrite the response body + GetInitResponseModifier(w).ResetBody() http.Error(w, text, code) + return nil }) }, }, CommandRequireBasicAuth: { help: Help{ command: CommandRequireBasicAuth, + description: makeLines( + "Require HTTP basic authentication for incoming requests, e.g.:", + helpExample(CommandRequireBasicAuth, "Restricted Area"), + ), args: map[string]string{ "realm": "the authentication realm", }, @@ -156,35 +220,63 @@ var commands = map[string]struct { }, build: func(args any) CommandHandler { realm := args.(string) - return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) { + return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("WWW-Authenticate", `Basic realm="`+realm+`"`) http.Error(w, "Unauthorized", http.StatusUnauthorized) + return nil }) }, }, CommandProxy: { help: Help{ command: CommandProxy, + description: makeLines( + "Proxy the request to the specified absolute URL, e.g.:", + helpExample(CommandProxy, "http://upstream:8080"), + ), args: map[string]string{ "to": "the url to proxy to, must be an absolute URL", }, }, - validate: validateAbsoluteURL, + validate: validateURL, build: func(args any) CommandHandler { target := args.(*nettypes.URL) if target.Scheme == "" { target.Scheme = "http" } + if target.Host == "" { + return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + url := target.URL + url.Host = routes.TryGetUpstreamHostPort(r) + if url.Host == "" { + return fmt.Errorf("no upstream host: %s", r.URL.String()) + } + rp := reverseproxy.NewReverseProxy(url.Host, &url, gphttp.NewTransport()) + r.URL.Path = target.Path + r.URL.RawPath = r.URL.EscapedPath() + r.RequestURI = r.URL.RequestURI() + rp.ServeHTTP(w, r) + return nil + }) + } rp := reverseproxy.NewReverseProxy("", &target.URL, gphttp.NewTransport()) - return TerminatingCommand(rp.ServeHTTP) + return TerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + rp.ServeHTTP(w, r) + return nil + }) }, }, CommandSet: { help: Help{ command: CommandSet, + description: makeLines( + "Set a field in the request or response, e.g.:", + helpExample(CommandSet, "header", "User-Agent", "godoxy"), + ), args: map[string]string{ - "field": "the field to set", - "value": "the value to set", + "target": fmt.Sprintf("the target to set, can be %s", strings.Join(AllFields, ", ")), + "field": "the field to set", + "value": "the value to set", }, }, validate: func(args []string) (any, gperr.Error) { @@ -197,9 +289,14 @@ var commands = map[string]struct { CommandAdd: { help: Help{ command: CommandAdd, + description: makeLines( + "Add a value to a field in the request or response, e.g.:", + helpExample(CommandAdd, "header", "X-Foo", "bar"), + ), args: map[string]string{ - "field": "the field to add", - "value": "the value to add", + "target": fmt.Sprintf("the target to add, can be %s", strings.Join(AllFields, ", ")), + "field": "the field to add", + "value": "the value to add", }, }, validate: func(args []string) (any, gperr.Error) { @@ -212,8 +309,13 @@ var commands = map[string]struct { CommandRemove: { help: Help{ command: CommandRemove, + description: makeLines( + "Remove a field from the request or response, e.g.:", + helpExample(CommandRemove, "header", "User-Agent"), + ), args: map[string]string{ - "field": "the field to remove", + "target": fmt.Sprintf("the target to remove, can be %s", strings.Join(AllFields, ", ")), + "field": "the field to remove", }, }, validate: func(args []string) (any, gperr.Error) { @@ -223,17 +325,157 @@ var commands = map[string]struct { return args.(CommandHandler) }, }, + CommandLog: { + isResponseHandler: true, + help: Help{ + command: CommandLog, + description: makeLines( + "The template supports the following variables:", + helpListItem("Request", "the request object"), + helpListItem("Response", "the response object"), + "", + "Example:", + helpExample(CommandLog, "info", "/dev/stdout", "{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}"), + ), + args: map[string]string{ + "level": "the log level", + "path": "the log path (/dev/stdout for stdout, /dev/stderr for stderr)", + "template": "the template to log", + }, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 3 { + return nil, ErrExpectThreeArgs + } + tmpl, err := validateTemplate(args[2], true) + if err != nil { + return nil, err + } + level, err := validateLevel(args[0]) + if err != nil { + return nil, err + } + // NOTE: file will stay opened forever + // it leverages accesslog.NewFileIO so + // it will be opened only once for the same path + f, err := openFile(args[1]) + if err != nil { + return nil, err + } + return &onLogArgs{level, f, tmpl}, nil + }, + build: func(args any) CommandHandler { + level, f, tmpl := args.(*onLogArgs).Unpack() + var logger io.Writer + if f == stdout || f == stderr { + logger = logging.NewLoggerWithFixedLevel(level, f) + } else { + logger = f + } + return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { + err := executeReqRespTemplateTo(tmpl, logger, w, r) + if err != nil { + return err + } + return nil + }) + }, + }, + CommandNotify: { + isResponseHandler: true, + help: Help{ + command: CommandNotify, + description: makeLines( + "The template supports the following variables:", + helpListItem("Request", "the request object"), + helpListItem("Response", "the response object"), + "", + "Example:", + helpExample(CommandNotify, "info", "ntfy", "Received request to {{ .Request.URL }}", "{{ .Request.Method }} {{ .Response.StatusCode }}"), + ), + args: map[string]string{ + "level": "the log level", + "provider": "the notification provider (must be defined in config `providers.notification`)", + "title": "the title of the notification", + "body": "the body of the notification", + }, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 4 { + return nil, ErrExpectFourArgs + } + titleTmpl, err := validateTemplate(args[2], false) + if err != nil { + return nil, err + } + bodyTmpl, err := validateTemplate(args[3], false) + if err != nil { + return nil, err + } + level, err := validateLevel(args[0]) + if err != nil { + return nil, err + } + // TODO: validate provider + // currently it is not possible, because rule validation happens on UnmarshalYAMLValidate + // and we cannot call config.ActiveConfig.Load() because it will cause import cycle + + // err = validateNotifProvider(args[1]) + // if err != nil { + // return nil, err + // } + return &onNotifyArgs{level, args[1], titleTmpl, bodyTmpl}, nil + }, + build: func(args any) CommandHandler { + level, provider, titleTmpl, bodyTmpl := args.(*onNotifyArgs).Unpack() + to := []string{provider} + + return OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { + buf := bufPool.Get() + defer bufPool.Put(buf) + + respBuf := bytes.NewBuffer(buf) + + err := executeReqRespTemplateTo(titleTmpl, respBuf, w, r) + if err != nil { + return err + } + titleLen := respBuf.Len() + err = executeReqRespTemplateTo(bodyTmpl, respBuf, w, r) + if err != nil { + return err + } + + notif.Notify(¬if.LogMessage{ + Level: level, + Title: string(buf[:titleLen]), + Body: notif.MessageBodyBytes(buf[titleLen:]), + To: to, + }) + return nil + }) + }, + }, } +type reqResponseTemplateData struct { + Request *http.Request + Response struct { + StatusCode int + Header http.Header + } +} + +var bufPool = synk.GetBytesPoolWithUniqueMemory() + +type onLogArgs = Tuple3[zerolog.Level, io.WriteCloser, templateOrStr] +type onNotifyArgs = Tuple4[zerolog.Level, string, templateOrStr, templateOrStr] + // Parse implements strutils.Parser. func (cmd *Command) Parse(v string) error { - lines := strutils.SplitLine(v) - if len(lines) == 0 { - return nil - } - - executors := make([]CommandHandler, 0, len(lines)) - for _, line := range lines { + executors := make([]CommandHandler, 0) + isResponseHandler := false + for line := range strings.SplitSeq(v, "\n") { if line == "" { continue } @@ -257,13 +499,21 @@ func (cmd *Command) Parse(v string) error { } validArgs, err := builder.validate(args) if err != nil { - return err.Subject(directive).Withf("%s", builder.help.String()) + // Only attach help for the directive that failed, avoid bringing in unrelated KV errors + return err.Subject(directive).With(builder.help.Error()) } - executors = append(executors, builder.build(validArgs)) + handler := builder.build(validArgs) + executors = append(executors, handler) + if builder.isResponseHandler || handler.IsResponseHandler() { + isResponseHandler = true + } } if len(executors) == 0 { + cmd.raw = v + cmd.exec = nil + cmd.isResponseHandler = false return nil } @@ -274,10 +524,14 @@ func (cmd *Command) Parse(v string) error { cmd.raw = v cmd.exec = exec + if exec.IsResponseHandler() { + isResponseHandler = true + } + cmd.isResponseHandler = isResponseHandler return nil } -func buildCmd(executors []CommandHandler) (CommandHandler, error) { +func buildCmd(executors []CommandHandler) (cmd CommandHandler, err error) { for i, exec := range executors { switch exec.(type) { case TerminatingCommand, BypassCommand: @@ -308,6 +562,10 @@ func (cmd *Command) isBypass() bool { } } +func (cmd *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) error { + return cmd.exec.Handle(w, r) +} + func (cmd *Command) String() string { return cmd.raw } diff --git a/internal/route/rules/do_log_test.go b/internal/route/rules/do_log_test.go new file mode 100644 index 00000000..272f6a31 --- /dev/null +++ b/internal/route/rules/do_log_test.go @@ -0,0 +1,400 @@ +package rules + +import ( + "fmt" + "maps" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/serialization" + gperr "github.com/yusing/goutils/errs" +) + +// mockUpstream creates a simple upstream handler for testing +func mockUpstream(status int, body string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(status) + w.Write([]byte(body)) + } +} + +// mockUpstreamWithHeaders creates an upstream that returns specific headers +func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + maps.Copy(w.Header(), headers) + w.WriteHeader(status) + w.Write([]byte(body)) + } +} + +func parseRules(data string, target *Rules) gperr.Error { + _, err := serialization.ConvertString(data, reflect.ValueOf(target)) + return err +} + +func TestLogCommand_TemporaryFile(t *testing.T) { + upstream := mockUpstreamWithHeaders(200, "success response", http.Header{ + "Content-Type": []string{"application/json"}, + }) + + // Create a temporary file for logging + tempFile, err := os.CreateTemp("", "test-log-*.log") + require.NoError(t, err) + tempFile.Close() + defer os.Remove(tempFile.Name()) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: log-request-response + do: | + log info %q '{{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }} {{ index (index .Response.Header "Content-Type") 0 }}' +`, tempFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("POST", "/api/users", nil) + req.Header.Set("User-Agent", "test-agent") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "success response", w.Body.String()) + + // Read and verify log content + content, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + logContent := string(content) + + assert.Equal(t, "POST /api/users 200 application/json\n", logContent) +} + +func TestLogCommand_StdoutAndStderr(t *testing.T) { + upstream := mockUpstream(200, "success") + + var rules Rules + err := parseRules(` +- name: log-stdout + do: | + log info /dev/stdout "stdout: {{ .Request.Method }} {{ .Response.StatusCode }}" +- name: log-stderr + do: | + log error /dev/stderr "stderr: {{ .Request.URL.Path }} {{ .Response.StatusCode }}" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + // Note: We can't easily capture stdout/stderr in unit tests, + // but we can verify no errors occurred and the handler completed +} + +func TestLogCommand_DifferentLogLevels(t *testing.T) { + upstream := mockUpstream(404, "not found") + + // Create temporary files for different log levels + infoFile, err := os.CreateTemp("", "test-info-*.log") + require.NoError(t, err) + infoFile.Close() + defer os.Remove(infoFile.Name()) + + warnFile, err := os.CreateTemp("", "test-warn-*.log") + require.NoError(t, err) + warnFile.Close() + defer os.Remove(warnFile.Name()) + + errorFile, err := os.CreateTemp("", "test-error-*.log") + require.NoError(t, err) + errorFile.Close() + defer os.Remove(errorFile.Name()) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: log-info + do: | + log info %s "INFO: {{ .Request.Method }} {{ .Response.StatusCode }}" +- name: log-warn + do: | + log warn %s "WARN: {{ .Request.URL.Path }} {{ .Response.StatusCode }}" +- name: log-error + do: | + log error %s "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}" +`, infoFile.Name(), warnFile.Name(), errorFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("DELETE", "/api/resource/123", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 404, w.Code) + + // Verify each log file + infoContent, err := os.ReadFile(infoFile.Name()) + require.NoError(t, err) + assert.Equal(t, "INFO: DELETE 404", strings.TrimSpace(string(infoContent))) + + warnContent, err := os.ReadFile(warnFile.Name()) + require.NoError(t, err) + assert.Equal(t, "WARN: /api/resource/123 404", strings.TrimSpace(string(warnContent))) + + errorContent, err := os.ReadFile(errorFile.Name()) + require.NoError(t, err) + assert.Equal(t, "ERROR: DELETE /api/resource/123 404", strings.TrimSpace(string(errorContent))) +} + +func TestLogCommand_TemplateVariables(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Header", "custom-value") + w.Header().Set("Content-Length", "42") + w.WriteHeader(201) + w.Write([]byte("created")) + }) + + // Create temporary file + tempFile, err := os.CreateTemp("", "test-template-*.log") + require.NoError(t, err) + tempFile.Close() + defer os.Remove(tempFile.Name()) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: log-with-templates + do: | + log info %s 'Request: {{ .Request.Method }} {{ .Request.URL }} Host: {{ .Request.Host }} User-Agent: {{ index .Request.Header "User-Agent" 0 }} Response: {{ .Response.StatusCode }} Custom-Header: {{ index .Response.Header "X-Custom-Header" 0 }} Content-Length: {{ index .Response.Header "Content-Length" 0 }}' +`, tempFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("PUT", "/api/resource", nil) + req.Header.Set("User-Agent", "test-client/1.0") + req.Host = "example.com" + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 201, w.Code) + + // Verify log content + content, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + logContent := strings.TrimSpace(string(content)) + + assert.Equal(t, "Request: PUT /api/resource Host: example.com User-Agent: test-client/1.0 Response: 201 Custom-Header: custom-value Content-Length: 42", logContent) +} + +func TestLogCommand_ConditionalLogging(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/error": + w.WriteHeader(500) + w.Write([]byte("internal server error")) + case "/notfound": + w.WriteHeader(404) + w.Write([]byte("not found")) + default: + w.WriteHeader(200) + w.Write([]byte("success")) + } + }) + + // Create temporary files + successFile, err := os.CreateTemp("", "test-success-*.log") + require.NoError(t, err) + successFile.Close() + defer os.Remove(successFile.Name()) + + errorFile, err := os.CreateTemp("", "test-error-*.log") + require.NoError(t, err) + errorFile.Close() + defer os.Remove(errorFile.Name()) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: log-success + on: status 2xx + do: | + log info %q "SUCCESS: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}" +- name: log-error + on: status 4xx | status 5xx + do: | + log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}" +`, successFile.Name(), errorFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test success request + req1 := httptest.NewRequest("GET", "/success", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + assert.Equal(t, 200, w1.Code) + + // Test not found request + req2 := httptest.NewRequest("GET", "/notfound", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, 404, w2.Code) + + // Test server error request + req3 := httptest.NewRequest("POST", "/error", nil) + w3 := httptest.NewRecorder() + handler.ServeHTTP(w3, req3) + assert.Equal(t, 500, w3.Code) + + // Verify success log + successContent, err := os.ReadFile(successFile.Name()) + require.NoError(t, err) + successLines := strings.Split(strings.TrimSpace(string(successContent)), "\n") + assert.Len(t, successLines, 1) + assert.Equal(t, "SUCCESS: GET /success 200", successLines[0]) + + // Verify error log + errorContent, err := os.ReadFile(errorFile.Name()) + require.NoError(t, err) + errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n") + assert.Len(t, errorLines, 2) + assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0]) + assert.Equal(t, "ERROR: POST /error 500", errorLines[1]) +} + +func TestLogCommand_MultipleLogEntries(t *testing.T) { + upstream := mockUpstream(200, "response") + + // Create temporary file + tempFile, err := os.CreateTemp("", "test-multiple-*.log") + require.NoError(t, err) + tempFile.Close() + defer os.Remove(tempFile.Name()) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: log-multiple + do: | + log info %q "{{ .Request.Method }} {{ .Request.URL.Path }} {{ .Response.StatusCode }}"`, tempFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Make multiple requests + requests := []struct { + method string + path string + }{ + {"GET", "/users"}, + {"POST", "/users"}, + {"PUT", "/users/1"}, + {"DELETE", "/users/1"}, + } + + for _, reqInfo := range requests { + req := httptest.NewRequest(reqInfo.method, reqInfo.path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + assert.Equal(t, 200, w.Code) + } + + // Verify all requests were logged + content, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + logContent := strings.TrimSpace(string(content)) + lines := strings.Split(logContent, "\n") + + assert.Len(t, lines, len(requests)) + + for i, reqInfo := range requests { + expectedLog := reqInfo.method + " " + reqInfo.path + " 200" + assert.Equal(t, expectedLog, lines[i]) + } +} + +func TestLogCommand_FilePermissions(t *testing.T) { + upstream := mockUpstream(200, "success") + + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "test-log-dir") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create a log file path within the temp directory + logFilePath := filepath.Join(tempDir, "test.log") + + var rules Rules + err = parseRules(fmt.Sprintf(` +- on: status 2xx + do: log info %q "{{ .Request.Method }} {{ .Response.StatusCode }}"`, logFilePath), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + + // Verify file was created and is writable + _, err = os.Stat(logFilePath) + require.NoError(t, err) + + // Test writing to the file again to ensure it's not closed + req2 := httptest.NewRequest("POST", "/test2", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 200, w2.Code) + + // Verify both entries are in the file + content, err := os.ReadFile(logFilePath) + require.NoError(t, err) + logContent := strings.TrimSpace(string(content)) + lines := strings.Split(logContent, "\n") + + assert.Len(t, lines, 2) + assert.Equal(t, "GET 200", lines[0]) + assert.Equal(t, "POST 200", lines[1]) +} + +func TestLogCommand_InvalidTemplate(t *testing.T) { + upstream := mockUpstream(200, "success") + + var rules Rules + + // Test with invalid template syntax + err := parseRules(` +- name: log-invalid + do: | + log info /dev/stdout "{{ .Invalid.Field }}"`, &rules) + // Should not error during parsing, but template execution will fail gracefully + assert.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + // Should not panic + assert.NotPanics(t, func() { + handler.ServeHTTP(w, req) + }) + + assert.Equal(t, 200, w.Code) +} diff --git a/internal/route/rules/do_set.go b/internal/route/rules/do_set.go new file mode 100644 index 00000000..b64e83ff --- /dev/null +++ b/internal/route/rules/do_set.go @@ -0,0 +1,328 @@ +package rules + +import ( + "bytes" + "io" + "net/http" + "net/url" + "strconv" + + gperr "github.com/yusing/goutils/errs" + ioutils "github.com/yusing/goutils/io" +) + +type ( + FieldHandler struct { + set, add, remove CommandHandler + } + FieldModifier string +) + +const ( + ModFieldSet FieldModifier = "set" + ModFieldAdd FieldModifier = "add" + ModFieldRemove FieldModifier = "remove" +) + +const ( + FieldHeader = "header" + FieldResponseHeader = "resp_header" + FieldQuery = "query" + FieldCookie = "cookie" + FieldBody = "body" + FieldResponseBody = "resp_body" + FieldStatusCode = "status" +) + +var AllFields = []string{FieldHeader, FieldResponseHeader, FieldQuery, FieldCookie, FieldBody, FieldResponseBody, FieldStatusCode} + +// NOTE: should not use canonicalized header keys, respect to user's input +var modFields = map[string]struct { + help Help + validate ValidateFunc + builder func(args any) *FieldHandler +}{ + FieldHeader: { + help: Help{ + command: FieldHeader, + args: map[string]string{ + "key": "the header key", + "value": "the header template", + }, + }, + validate: toKeyValueTemplate, + builder: func(args any) *FieldHandler { + k, tmpl := args.(*keyValueTemplate).Unpack() + return &FieldHandler{ + set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + r.Header[k] = []string{v} + return nil + }), + add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + r.Header[k] = append(r.Header[k], v) + return nil + }), + remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + delete(r.Header, k) + return nil + }), + } + }, + }, + FieldResponseHeader: { + help: Help{ + command: FieldResponseHeader, + args: map[string]string{ + "key": "the response header key", + "value": "the response header template", + }, + }, + validate: toKeyValueTemplate, + builder: func(args any) *FieldHandler { + k, tmpl := args.(*keyValueTemplate).Unpack() + return &FieldHandler{ + set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + w.Header()[k] = []string{v} + return nil + }), + add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + w.Header()[k] = append(w.Header()[k], v) + return nil + }), + remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + delete(w.Header(), k) + return nil + }), + } + }, + }, + FieldQuery: { + help: Help{ + command: FieldQuery, + args: map[string]string{ + "key": "the query key", + "value": "the query template", + }, + }, + validate: toKeyValueTemplate, + builder: func(args any) *FieldHandler { + k, tmpl := args.(*keyValueTemplate).Unpack() + return &FieldHandler{ + set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + queries.Set(k, v) + }) + return nil + }), + add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + queries.Add(k, v) + }) + return nil + }), + remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + GetSharedData(w).UpdateQueries(r, func(queries url.Values) { + queries.Del(k) + }) + return nil + }), + } + }, + }, + FieldCookie: { + help: Help{ + command: FieldCookie, + args: map[string]string{ + "key": "the cookie key", + "value": "the cookie value", + }, + }, + validate: toKeyValueTemplate, + builder: func(args any) *FieldHandler { + k, tmpl := args.(*keyValueTemplate).Unpack() + return &FieldHandler{ + set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + for i, c := range cookies { + if c.Name == k { + cookies[i].Value = v + return cookies + } + } + return append(cookies, &http.Cookie{Name: k, Value: v}) + }) + return nil + }), + add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + v, err := executeRequestTemplateString(tmpl, r) + if err != nil { + return err + } + GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + return append(cookies, &http.Cookie{Name: k, Value: v}) + }) + return nil + }), + remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + GetSharedData(w).UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { + index := -1 + for i, c := range cookies { + if c.Name == k { + index = i + break + } + } + if index != -1 { + if len(cookies) == 1 { + return []*http.Cookie{} + } + return append(cookies[:index], cookies[index+1:]...) + } + return cookies + }) + return nil + }), + } + }, + }, + FieldBody: { + help: Help{ + command: FieldBody, + description: makeLines( + "Override the request body that will be sent to the upstream", + "The template supports the following variables:", + helpListItem("Request", "the request object"), + "", + "Example:", + helpExample(FieldBody, "HTTP STATUS: {{ .Request.Method }} {{ .Request.URL.Path }}"), + ), + args: map[string]string{ + "template": "the body template", + }, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + return validateTemplate(args[0], true) + }, + builder: func(args any) *FieldHandler { + tmpl := args.(templateOrStr) + return &FieldHandler{ + set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + if r.Body != nil { + r.Body.Close() + r.Body = nil + } + + buf := pool.Get() + b := bytes.NewBuffer(buf) + + err := executeRequestTemplateTo(tmpl, b, r) + if err != nil { + return err + } + r.Body = ioutils.NewHookReadCloser(io.NopCloser(b), func() { + pool.Put(buf) + }) + return nil + }), + } + }, + }, + FieldResponseBody: { + help: Help{ + command: FieldResponseBody, + description: makeLines( + "Override the response body that will be sent to the client", + "The template supports the following variables:", + helpListItem("Request", "the request object"), + helpListItem("Response", "the response object"), + "", + "Example:", + helpExample(FieldResponseBody, "HTTP STATUS: {{ .Request.Method }} {{ .Response.StatusCode }}"), + ), + args: map[string]string{ + "template": "the response body template", + }, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + return validateTemplate(args[0], true) + }, + builder: func(args any) *FieldHandler { + tmpl := args.(templateOrStr) + return &FieldHandler{ + set: OnResponseCommand(func(w http.ResponseWriter, r *http.Request) error { + rm := GetInitResponseModifier(w) + rm.ResetBody() + return executeReqRespTemplateTo(tmpl, rm, rm, r) + }), + } + }, + }, + FieldStatusCode: { + help: Help{ + command: FieldStatusCode, + description: makeLines( + "Override the status code that will be sent to the client, e.g.:", + helpExample(FieldStatusCode, "200"), + ), + args: map[string]string{ + "code": "the status code", + }, + }, + validate: func(args []string) (any, gperr.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + status, err := strconv.Atoi(args[0]) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + if status < 100 || status > 599 { + return nil, ErrInvalidArguments.Withf("status code must be between 100 and 599, got %d", status) + } + return status, nil + }, + builder: func(args any) *FieldHandler { + status := args.(int) + return &FieldHandler{ + set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) error { + GetInitResponseModifier(w).WriteHeader(status) + return nil + }), + } + }, + }, +} diff --git a/internal/route/rules/do_set_test.go b/internal/route/rules/do_set_test.go new file mode 100644 index 00000000..3f007d6e --- /dev/null +++ b/internal/route/rules/do_set_test.go @@ -0,0 +1,643 @@ +package rules + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFieldHandler_Header(t *testing.T) { + tests := []struct { + name string + key string + value string + modifier FieldModifier + setup func(*http.Request) + verify func(*http.Request, *httptest.ResponseRecorder) + }{ + { + name: "set header", + key: "X-Test", + value: "test-value", + modifier: ModFieldSet, + setup: func(r *http.Request) { + r.Header.Set("X-Test", "old-value") + }, + verify: func(r *http.Request, w *httptest.ResponseRecorder) { + got := r.Header.Get("X-Test") + assert.Equal(t, "test-value", got, "Expected header X-Test to be 'test-value'") + }, + }, + { + name: "add header", + key: "X-Test", + value: "new-value", + modifier: ModFieldAdd, + setup: func(r *http.Request) { + r.Header.Set("X-Test", "existing-value") + }, + verify: func(r *http.Request, w *httptest.ResponseRecorder) { + values := r.Header["X-Test"] + require.Len(t, values, 2, "Expected 2 header values") + assert.Equal(t, "existing-value", values[0], "Expected first value of X-Test header to be 'existing-value'") + assert.Equal(t, "new-value", values[1], "Expected second value of X-Test header to be 'new-value'") + }, + }, + { + name: "remove header", + key: "X-Test", + value: "", + modifier: ModFieldRemove, + setup: func(r *http.Request) { + r.Header.Set("X-Test", "to-be-removed") + }, + verify: func(r *http.Request, w *httptest.ResponseRecorder) { + got := r.Header.Get("X-Test") + assert.Empty(t, got, "Expected header X-Test to be removed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + tt.setup(req) + w := httptest.NewRecorder() + + tmpl, tErr := validateTemplate(tt.value, false) + if tErr != nil { + t.Fatalf("Failed to validate template: %v", tErr) + } + handler := modFields[FieldHeader].builder(&keyValueTemplate{tt.key, tmpl}) + var cmd CommandHandler + switch tt.modifier { + case ModFieldSet: + cmd = handler.set + case ModFieldAdd: + cmd = handler.add + case ModFieldRemove: + cmd = handler.remove + } + + err := cmd.Handle(w, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + + tt.verify(req, w) + }) + } +} + +func TestFieldHandler_ResponseHeader(t *testing.T) { + tests := []struct { + name string + key string + value string + modifier FieldModifier + setup func(*httptest.ResponseRecorder) + verify func(*httptest.ResponseRecorder) + }{ + { + name: "set response header", + key: "X-Response-Test", + value: "response-value", + modifier: ModFieldSet, + verify: func(w *httptest.ResponseRecorder) { + got := w.Header().Get("X-Response-Test") + assert.Equal(t, "response-value", got, "Expected response header X-Response-Test to be 'response-value'") + }, + }, + { + name: "add response header", + key: "X-Response-Test", + value: "additional-value", + modifier: ModFieldAdd, + setup: func(w *httptest.ResponseRecorder) { + w.Header().Set("X-Response-Test", "existing-value") + }, + verify: func(w *httptest.ResponseRecorder) { + values := w.Header()["X-Response-Test"] + require.Len(t, values, 2) + assert.Equal(t, values[0], "existing-value") + assert.Equal(t, values[1], "additional-value") + }, + }, + { + name: "remove response header", + key: "X-Response-Test", + value: "", + modifier: ModFieldRemove, + verify: func(w *httptest.ResponseRecorder) { + assert.Empty(t, w.Header().Get("X-Response-Test")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + if tt.setup != nil { + tt.setup(w) + } + + tmpl, tErr := validateTemplate(tt.value, false) + if tErr != nil { + t.Fatalf("Failed to validate template: %v", tErr) + } + handler := modFields[FieldResponseHeader].builder(&keyValueTemplate{tt.key, tmpl}) + var cmd CommandHandler + switch tt.modifier { + case ModFieldSet: + cmd = handler.set + case ModFieldAdd: + cmd = handler.add + case ModFieldRemove: + cmd = handler.remove + } + + err := cmd.Handle(w, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + + tt.verify(w) + }) + } +} + +func TestFieldHandler_Query(t *testing.T) { + tests := []struct { + name string + key string + value string + modifier FieldModifier + setup func(*http.Request) + verify func(*http.Request) + }{ + { + name: "set query", + key: "test", + value: "new-value", + modifier: ModFieldSet, + setup: func(r *http.Request) { + r.URL.RawQuery = "test=old-value&other=keep" + }, + verify: func(r *http.Request) { + got := r.URL.Query().Get("test") + assert.Equal(t, "new-value", got, "Expected query 'test' to be 'new-value'") + gotOther := r.URL.Query().Get("other") + assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'") + }, + }, + { + name: "add query", + key: "test", + value: "additional-value", + modifier: ModFieldAdd, + setup: func(r *http.Request) { + r.URL.RawQuery = "test=existing-value" + }, + verify: func(r *http.Request) { + values := r.URL.Query()["test"] + require.Len(t, values, 2, "Expected 2 query values") + assert.Equal(t, "existing-value", values[0], "Expected first value of test query param to be 'existing-value'") + assert.Equal(t, "additional-value", values[1], "Expected second value of test query param to be 'additional-value'") + }, + }, + { + name: "remove query", + key: "test", + value: "", + modifier: ModFieldRemove, + setup: func(r *http.Request) { + r.URL.RawQuery = "test=to-be-removed&other=keep" + }, + verify: func(r *http.Request) { + got := r.URL.Query().Get("test") + assert.Empty(t, got, "Expected query 'test' to be removed") + gotOther := r.URL.Query().Get("other") + assert.Equal(t, "keep", gotOther, "Expected query 'other' to be 'keep'") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + tt.setup(req) + w := httptest.NewRecorder() + + tmpl, tErr := validateTemplate(tt.value, false) + if tErr != nil { + t.Fatalf("Failed to validate template: %v", tErr) + } + handler := modFields[FieldQuery].builder(&keyValueTemplate{tt.key, tmpl}) + var cmd CommandHandler + switch tt.modifier { + case ModFieldSet: + cmd = handler.set + case ModFieldAdd: + cmd = handler.add + case ModFieldRemove: + cmd = handler.remove + } + + err := cmd.Handle(w, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + + tt.verify(req) + }) + } +} + +func TestFieldHandler_Cookie(t *testing.T) { + tests := []struct { + name string + key string + value string + modifier FieldModifier + setup func(*http.Request) + verify func(*http.Request) + }{ + { + name: "set cookie", + key: "test", + value: "new-value", + modifier: ModFieldSet, + setup: func(r *http.Request) { + r.AddCookie(&http.Cookie{Name: "test", Value: "old-value"}) + }, + verify: func(r *http.Request) { + cookie, err := r.Cookie("test") + assert.NoError(t, err, "Expected cookie 'test' to exist") + if err == nil { + assert.Equal(t, "new-value", cookie.Value, "Expected cookie 'test' to be 'new-value'") + } + }, + }, + { + name: "add cookie", + key: "test", + value: "additional-value", + modifier: ModFieldAdd, + setup: func(r *http.Request) { + r.AddCookie(&http.Cookie{Name: "test", Value: "existing-value"}) + }, + verify: func(r *http.Request) { + cookies := r.Cookies() + testCookies := make([]string, 0) + for _, c := range cookies { + if c.Name == "test" { + testCookies = append(testCookies, c.Value) + } + } + require.Len(t, testCookies, 2, "Expected 2 cookies with name 'test'") + assert.Equal(t, "existing-value", testCookies[0], "Expected first value of 'test' cookie to be 'existing-value'") + assert.Equal(t, "additional-value", testCookies[1], "Expected second value of 'test' cookie to be 'additional-value'") + }, + }, + { + name: "remove cookie", + key: "test", + value: "", + modifier: ModFieldRemove, + setup: func(r *http.Request) { + r.AddCookie(&http.Cookie{Name: "test", Value: "to-be-removed"}) + r.AddCookie(&http.Cookie{Name: "other", Value: "keep"}) + }, + verify: func(r *http.Request) { + _, err := r.Cookie("test") + assert.Error(t, err, "Expected cookie 'test' to be removed") + cookie, err := r.Cookie("other") + assert.NoError(t, err, "Expected cookie 'other' to exist") + if err == nil { + assert.Equal(t, "keep", cookie.Value, "Expected cookie 'other' to be 'keep'") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + tt.setup(req) + w := httptest.NewRecorder() + + tmpl, tErr := validateTemplate(tt.value, false) + if tErr != nil { + t.Fatalf("Failed to validate template: %v", tErr) + } + handler := modFields[FieldCookie].builder(&keyValueTemplate{tt.key, tmpl}) + var cmd CommandHandler + switch tt.modifier { + case ModFieldSet: + cmd = handler.set + case ModFieldAdd: + cmd = handler.add + case ModFieldRemove: + cmd = handler.remove + } + + err := cmd.Handle(w, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + + tt.verify(req) + }) + } +} + +func TestFieldHandler_Body(t *testing.T) { + tests := []struct { + name string + template string + setup func(*http.Request) + verify func(*http.Request) + }{ + { + name: "set body with template", + template: "Hello {{ .Request.Method }} {{ .Request.URL.Path }}", + setup: func(r *http.Request) { + r.Method = "POST" + r.URL.Path = "/test" + }, + verify: func(r *http.Request) { + body, err := io.ReadAll(r.Body) + assert.NoError(t, err, "Failed to read body") + expected := "Hello POST /test" + assert.Equal(t, expected, string(body), "Expected body content") + }, + }, + { + name: "set body with existing body", + template: "Overridden", + setup: func(r *http.Request) { + r.Body = io.NopCloser(strings.NewReader("original body")) + }, + verify: func(r *http.Request) { + body, err := io.ReadAll(r.Body) + assert.NoError(t, err, "Failed to read body") + assert.Equal(t, "Overridden", string(body), "Expected body to be 'Overridden'") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + tt.setup(req) + w := httptest.NewRecorder() + + tmpl, tErr := validateTemplate(tt.template, false) + if tErr != nil { + t.Fatalf("Failed to parse template: %v", tErr) + } + + handler := modFields[FieldBody].builder(tmpl) + err := handler.set.Handle(w, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + + tt.verify(req) + }) + } +} + +func TestFieldHandler_ResponseBody(t *testing.T) { + tests := []struct { + name string + template string + setup func(*http.Request) + verify func(*ResponseModifier) + }{ + { + name: "set response body with template", + template: "Response: {{ .Request.Method }} {{ .Request.URL.Path }}", + setup: func(r *http.Request) { + r.Method = "GET" + r.URL.Path = "/api/test" + }, + verify: func(rm *ResponseModifier) { + content := rm.buf.String() + expected := "Response: GET /api/test" + assert.Equal(t, expected, content, "Expected response body") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + tt.setup(req) + w := httptest.NewRecorder() + + // Create ResponseModifier wrapper + rm := NewResponseModifier(w) + + tmpl, tErr := validateTemplate(tt.template, false) + if tErr != nil { + t.Fatalf("Failed to parse template: %v", tErr) + } + + handler := modFields[FieldResponseBody].builder(tmpl) + err := handler.set.Handle(rm, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + + tt.verify(rm) + }) + } +} + +func TestFieldHandler_StatusCode(t *testing.T) { + tests := []struct { + name string + status int + verify func(*httptest.ResponseRecorder) + }{ + { + name: "set status code 200", + status: 200, + verify: func(w *httptest.ResponseRecorder) { + assert.Equal(t, 200, w.Code, "Expected status code 200") + }, + }, + { + name: "set status code 404", + status: 404, + verify: func(w *httptest.ResponseRecorder) { + assert.Equal(t, 404, w.Code, "Expected status code 404") + }, + }, + { + name: "set status code 500", + status: 500, + verify: func(w *httptest.ResponseRecorder) { + assert.Equal(t, 500, w.Code, "Expected status code 500") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + rm := NewResponseModifier(w) + var cmd Command + err := cmd.Parse(fmt.Sprintf("set %s %d", FieldStatusCode, tt.status)) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + err = cmd.ServeHTTP(rm, req) + if err != nil { + t.Fatalf("Handler returned error: %v", err) + } + rm.FlushRelease() + + tt.verify(w) + }) + } +} + +func TestFieldValidation(t *testing.T) { + tests := []struct { + name string + field string + args []string + wantError bool + }{ + { + name: "header valid", + field: FieldHeader, + args: []string{"key", "value"}, + wantError: false, + }, + { + name: "header invalid - missing value", + field: FieldHeader, + args: []string{"key"}, + wantError: true, + }, + { + name: "response header valid", + field: FieldResponseHeader, + args: []string{"key", "value"}, + wantError: false, + }, + { + name: "query valid", + field: FieldQuery, + args: []string{"key", "value"}, + wantError: false, + }, + { + name: "cookie valid", + field: FieldCookie, + args: []string{"key", "value"}, + wantError: false, + }, + { + name: "body valid template", + field: FieldBody, + args: []string{"Hello {{ .Request.Method }}"}, + wantError: false, + }, + { + name: "body invalid template syntax", + field: FieldBody, + args: []string{"Hello {{ .InvalidField "}, + wantError: true, + }, + { + name: "response body valid template", + field: FieldResponseBody, + args: []string{"Response: {{ .Request.Method }}"}, + wantError: false, + }, + { + name: "status code valid", + field: FieldStatusCode, + args: []string{"200"}, + wantError: false, + }, + { + name: "status code invalid - too low", + field: FieldStatusCode, + args: []string{"99"}, + wantError: true, + }, + { + name: "status code invalid - too high", + field: FieldStatusCode, + args: []string{"600"}, + wantError: true, + }, + { + name: "status code invalid - not a number", + field: FieldStatusCode, + args: []string{"not-a-number"}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + field, exists := modFields[tt.field] + assert.True(t, exists, "Field %s does not exist", tt.field) + + _, err := field.validate(tt.args) + if tt.wantError { + assert.Error(t, err, "Expected error but got none") + } else { + assert.NoError(t, err, "Expected no error but got: %v", err) + } + }) + } +} + +func TestAllFields(t *testing.T) { + expectedFields := []string{ + FieldHeader, + FieldResponseHeader, + FieldQuery, + FieldCookie, + FieldBody, + FieldResponseBody, + FieldStatusCode, + } + + require.Len(t, AllFields, len(expectedFields), "Expected %d fields", len(expectedFields)) + + for _, expected := range expectedFields { + found := false + for _, actual := range AllFields { + if actual == expected { + found = true + break + } + } + assert.True(t, found, "Expected field %s not found in AllFields", expected) + } +} + +func TestModFields(t *testing.T) { + for fieldName, field := range modFields { + // Test that each field has required components + assert.NotNil(t, field.validate, "Field %s has nil validate function", fieldName) + assert.NotNil(t, field.builder, "Field %s has nil builder function", fieldName) + assert.NotEmpty(t, field.help.command, "Field %s has empty help command", fieldName) + } +} diff --git a/internal/route/rules/do_test.go b/internal/route/rules/do_test.go index 16b84efc..034e1742 100644 --- a/internal/route/rules/do_test.go +++ b/internal/route/rules/do_test.go @@ -99,10 +99,15 @@ func TestParseCommands(t *testing.T) { }, // proxy directive tests { - name: "proxy_valid", + name: "proxy_valid_abs", input: "proxy http://localhost:8080", wantErr: nil, }, + { + name: "proxy_valid_rel", + input: "proxy /foo/bar", + wantErr: nil, + }, { name: "proxy_missing_target", input: "proxy", diff --git a/internal/route/rules/error_format_test.go b/internal/route/rules/error_format_test.go new file mode 100644 index 00000000..9a2846b6 --- /dev/null +++ b/internal/route/rules/error_format_test.go @@ -0,0 +1,23 @@ +package rules + +import ( + "testing" + + gperr "github.com/yusing/goutils/errs" +) + +func TestErrorFormat(t *testing.T) { + var rules Rules + err := parseRules(` +- on: error 405 + do: error 405 error +- on: header too many args + do: error 405 error +- name: missing do + on: status 200 +- on: header X-Header + do: set invalid_command +- do: set resp_body "{{ .Request.Method {{ .Request.URL.Path }}" +`, &rules) + gperr.LogError("error", err) +} diff --git a/internal/route/rules/errors.go b/internal/route/rules/errors.go index b8ba7a88..b15efa7f 100644 --- a/internal/route/rules/errors.go +++ b/internal/route/rules/errors.go @@ -7,15 +7,21 @@ import ( var ( ErrUnterminatedQuotes = gperr.New("unterminated quotes") ErrUnterminatedBrackets = gperr.New("unterminated brackets") + ErrUnterminatedEnvVar = gperr.New("unterminated env var") ErrUnknownDirective = gperr.New("unknown directive") + ErrUnknownModField = gperr.New("unknown field") ErrEnvVarNotFound = gperr.New("env variable not found") ErrInvalidArguments = gperr.New("invalid arguments") ErrInvalidOnTarget = gperr.New("invalid `rule.on` target") ErrInvalidCommandSequence = gperr.New("invalid command sequence") - ErrInvalidSetTarget = gperr.New("invalid `rule.set` target") - ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg") - ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg") - ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args") - ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'") + ErrExpectNoArg = gperr.Wrap(ErrInvalidArguments, "expect no arg") + ErrExpectOneArg = gperr.Wrap(ErrInvalidArguments, "expect 1 arg") + ErrExpectTwoArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 args") + ErrExpectTwoOrThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 2 or 3 args") + ErrExpectThreeArgs = gperr.Wrap(ErrInvalidArguments, "expect 3 args") + ErrExpectFourArgs = gperr.Wrap(ErrInvalidArguments, "expect 4 args") + ErrExpectKVOptionalV = gperr.Wrap(ErrInvalidArguments, "expect 'key' or 'key value'") + + errTerminated = gperr.New("terminated") ) diff --git a/internal/route/rules/fields.go b/internal/route/rules/fields.go deleted file mode 100644 index 3abcff6f..00000000 --- a/internal/route/rules/fields.go +++ /dev/null @@ -1,142 +0,0 @@ -package rules - -import ( - "net/http" - "net/url" -) - -type ( - FieldHandler struct { - set, add, remove CommandHandler - } - FieldModifier string -) - -const ( - ModFieldSet FieldModifier = "set" - ModFieldAdd FieldModifier = "add" - ModFieldRemove FieldModifier = "remove" -) - -const ( - FieldHeader = "header" - FieldQuery = "query" - FieldCookie = "cookie" -) - -var modFields = map[string]struct { - help Help - validate ValidateFunc - builder func(args any) *FieldHandler -}{ - FieldHeader: { - help: Help{ - command: FieldHeader, - args: map[string]string{ - "key": "the header key", - "value": "the header value", - }, - }, - validate: toStrTuple, - builder: func(args any) *FieldHandler { - k, v := args.(*StrTuple).Unpack() - return &FieldHandler{ - set: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) { - w.Header()[k] = []string{v} - }), - add: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) { - h := w.Header() - h[k] = append(h[k], v) - }), - remove: NonTerminatingCommand(func(w http.ResponseWriter, r *http.Request) { - delete(w.Header(), k) - }), - } - }, - }, - FieldQuery: { - help: Help{ - command: FieldQuery, - args: map[string]string{ - "key": "the query key", - "value": "the query value", - }, - }, - validate: toStrTuple, - builder: func(args any) *FieldHandler { - k, v := args.(*StrTuple).Unpack() - return &FieldHandler{ - set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { - cached.UpdateQueries(r, func(queries url.Values) { - queries.Set(k, v) - }) - return true - }), - add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { - cached.UpdateQueries(r, func(queries url.Values) { - queries.Add(k, v) - }) - return true - }), - remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { - cached.UpdateQueries(r, func(queries url.Values) { - queries.Del(k) - }) - return true - }), - } - }, - }, - FieldCookie: { - help: Help{ - command: FieldCookie, - args: map[string]string{ - "key": "the cookie key", - "value": "the cookie value", - }, - }, - validate: toStrTuple, - builder: func(args any) *FieldHandler { - k, v := args.(*StrTuple).Unpack() - return &FieldHandler{ - set: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { - cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { - for i, c := range cookies { - if c.Name == k { - cookies[i].Value = v - return cookies - } - } - return append(cookies, &http.Cookie{Name: k, Value: v}) - }) - return true - }), - add: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { - cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { - return append(cookies, &http.Cookie{Name: k, Value: v}) - }) - return true - }), - remove: DynamicCommand(func(cached Cache, w http.ResponseWriter, r *http.Request) bool { - cached.UpdateCookies(r, func(cookies []*http.Cookie) []*http.Cookie { - index := -1 - for i, c := range cookies { - if c.Name == k { - index = i - break - } - } - if index != -1 { - if len(cookies) == 1 { - return []*http.Cookie{} - } - return append(cookies[:index], cookies[index+1:]...) - } - return cookies - }) - return true - }), - } - }, - }, -} diff --git a/internal/route/rules/help.go b/internal/route/rules/help.go index 9b222b70..1f92f4e4 100644 --- a/internal/route/rules/help.go +++ b/internal/route/rules/help.go @@ -1,40 +1,134 @@ package rules -import "strings" +import ( + "fmt" + "slices" + "strconv" + "strings" + + gperr "github.com/yusing/goutils/errs" + "github.com/yusing/goutils/strings/ansi" +) type Help struct { command string - description string + description []string args map[string]string // args[arg] -> description } -/* -Generate help string, e.g. +func makeLines(lines ...string) []string { + return lines +} - rewrite - from: the path to rewrite, must start with / - to: the path to rewrite to, must start with / -*/ -func (h *Help) String() string { +func helpExample(cmd string, args ...string) string { var sb strings.Builder - sb.WriteString(h.command) - sb.WriteString(" ") - for arg := range h.args { - sb.WriteString(strings.ToUpper(arg)) - sb.WriteRune(' ') - } - if h.description != "" { - sb.WriteString("\n\t") - sb.WriteString(h.description) - sb.WriteRune('\n') - } - sb.WriteRune('\n') - for arg, desc := range h.args { - sb.WriteRune('\t') - sb.WriteString(strings.ToUpper(arg)) - sb.WriteString(": ") - sb.WriteString(desc) - sb.WriteRune('\n') + sb.WriteString(" ") + sb.WriteString(ansi.WithANSI(cmd, ansi.HighlightGreen)) + for _, arg := range args { + var out strings.Builder + pos := 0 + for { + start := strings.Index(arg[pos:], "{{") + if start == -1 { + if pos < len(arg) { + // If no template at all (pos == 0), cyan highlight for whole-arg + // Otherwise, for mixed strings containing templates, leave non-template text unhighlighted + if pos == 0 { + out.WriteString(ansi.WithANSI(arg[pos:], ansi.HighlightCyan)) + } else { + out.WriteString(arg[pos:]) + } + } + break + } + start += pos + if start > pos { + // Non-template text should not be highlighted + out.WriteString(arg[pos:start]) + } + end := strings.Index(arg[start+2:], "}}") + if end == -1 { + // Unmatched template start; write remainder without highlighting + out.WriteString(arg[start:]) + break + } + end += start + 2 + inner := strings.TrimSpace(arg[start+2 : end]) + parts := strings.Split(inner, ".") + out.WriteString(helpTemplateVar(parts...)) + pos = end + 2 + } + fmt.Fprintf(&sb, ` "%s"`, out.String()) } return sb.String() } + +func helpListItem(key string, value string) string { + var sb strings.Builder + sb.WriteString(" ") + sb.WriteString(ansi.WithANSI(key, ansi.HighlightYellow)) + sb.WriteString(": ") + sb.WriteString(value) + return sb.String() +} + +// helpFuncCall generates a string like "fn(arg1, arg2, arg3)" +func helpFuncCall(fn string, args ...string) string { + var sb strings.Builder + sb.WriteString(ansi.WithANSI(fn, ansi.HighlightRed)) + sb.WriteString("(") + for i, arg := range args { + fmt.Fprintf(&sb, `"%s"`, ansi.WithANSI(arg, ansi.HighlightCyan)) + if i < len(args)-1 { + sb.WriteString(", ") + } + } + sb.WriteString(")") + return sb.String() +} + +// helpTemplateVar generates a string like "{{ .Request.Method }} {{ .Request.URL.Path }}" +func helpTemplateVar(parts ...string) string { + var sb strings.Builder + sb.WriteString(ansi.WithANSI("{{ ", ansi.HighlightWhite)) + for i, part := range parts { + sb.WriteString(ansi.WithANSI(part, ansi.HighlightCyan)) + if i < len(parts)-1 { + sb.WriteString(".") + } + } + sb.WriteString(ansi.WithANSI(" }}", ansi.HighlightWhite)) + return sb.String() +} + +/* +Generate help string as error, e.g. + + rewrite + from: the path to rewrite, must start with / + to: the path to rewrite to, must start with / +*/ +func (h *Help) Error() gperr.Error { + var lines gperr.MultilineError + + lines.Adds(ansi.WithANSI(h.command, ansi.HighlightGreen)) + lines.AddStrings(h.description...) + lines.Adds(" args:") + + argKeys := make([]string, 0, len(h.args)) + longestArg := 0 + for arg := range h.args { + if len(arg) > longestArg { + longestArg = len(arg) + } + argKeys = append(argKeys, arg) + } + + // sort argKeys alphabetically to make output stable + slices.Sort(argKeys) + for _, arg := range argKeys { + desc := h.args[arg] + lines.Addf(" %-"+strconv.Itoa(longestArg)+"s: %s", ansi.WithANSI(arg, ansi.HighlightCyan), desc) + } + return &lines +} diff --git a/internal/route/rules/http_flow_test.go b/internal/route/rules/http_flow_test.go new file mode 100644 index 00000000..375a8992 --- /dev/null +++ b/internal/route/rules/http_flow_test.go @@ -0,0 +1,940 @@ +package rules_test + +import ( + "fmt" + "maps" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/route" + "github.com/yusing/godoxy/internal/route/routes" + "github.com/yusing/godoxy/internal/serialization" + gperr "github.com/yusing/goutils/errs" + "golang.org/x/crypto/bcrypt" + + . "github.com/yusing/godoxy/internal/route/rules" +) + +// mockUpstream creates a simple upstream handler for testing +func mockUpstream(status int, body string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(status) + w.Write([]byte(body)) + } +} + +// mockUpstreamWithHeaders creates an upstream that returns specific headers +func mockUpstreamWithHeaders(status int, body string, headers http.Header) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + maps.Copy(w.Header(), headers) + w.WriteHeader(status) + w.Write([]byte(body)) + } +} + +func mockRoute(alias string) *route.FileServer { + return &route.FileServer{Route: &route.Route{Alias: alias}} +} + +func parseRules(data string, target *Rules) gperr.Error { + _, err := serialization.ConvertString(strings.TrimSpace(data), reflect.ValueOf(target)) + return err +} + +func TestHTTPFlow_BasicPreRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Header", r.Header.Get("X-Custom-Header")) + w.WriteHeader(200) + w.Write([]byte("upstream response")) + }) + + var rules Rules + err := parseRules(` +- name: add-header + on: path / + do: set header X-Custom-Header test-value +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "upstream response", w.Body.String()) + assert.Equal(t, "test-value", w.Header().Get("X-Custom-Header")) +} + +func TestHTTPFlow_BypassRule(t *testing.T) { + upstream := mockUpstream(200, "upstream response") + + var rules Rules + err := parseRules(` +- name: bypass-condition + on: path /bypass + do: bypass +- name: should-not-execute + on: path /bypass + do: error 500 "should not reach here" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/bypass", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "upstream response", w.Body.String()) +} + +func TestHTTPFlow_TerminatingCommand(t *testing.T) { + upstream := mockUpstream(200, "should not be called") + + var rules Rules + err := parseRules(` +- name: error-response + on: path /error + do: error 403 Forbidden +- name: should-not-execute + on: path /error + do: set header X-Header ignored +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/error", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 403, w.Code) + assert.Equal(t, "Forbidden\n", w.Body.String()) + assert.Empty(t, w.Header().Get("X-Header")) +} + +func TestHTTPFlow_RedirectFlow(t *testing.T) { + upstream := mockUpstream(200, "should not be called") + + var rules Rules + err := parseRules(` +- name: redirect-rule + on: path /old-path + do: redirect /new-path +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/old-path", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 307, w.Code) // TemporaryRedirect + assert.Equal(t, "/new-path", w.Header().Get("Location")) +} + +func TestHTTPFlow_RewriteFlow(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("path: " + r.URL.Path)) + }) + + var rules Rules + err := parseRules(` +- name: rewrite-rule + on: path glob(/api/*) + do: rewrite /api/ /v1/ +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/api/users", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "path: /v1/users", w.Body.String()) +} + +func TestHTTPFlow_MultiplePreRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("upstream: " + r.Header.Get("X-Request-Id"))) + }) + + var rules Rules + err := parseRules(` +- name: add-request-id + on: path / + do: set header X-Request-Id req-123 +- name: add-auth-header + on: path / + do: set header X-Auth-Token token-456 +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "upstream: req-123", w.Body.String()) + assert.Equal(t, "token-456", req.Header.Get("X-Auth-Token")) +} + +func TestHTTPFlow_PostResponseRule(t *testing.T) { + upstream := mockUpstreamWithHeaders(200, "success", http.Header{ + "X-Upstream": []string{"upstream-value"}, + }) + + tempFile, err := os.CreateTemp("", "test-log-*.txt") + // Create a temporary file for logging + require.NoError(t, err) + defer os.Remove(tempFile.Name()) + tempFile.Close() + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: log-response + on: path /test + do: log info %s "{{ .Request.Method }} {{ .Response.StatusCode }}" +`, tempFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "success", w.Body.String()) + assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream")) + + // Check log file + content, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + assert.Equal(t, "GET 200\n", string(content)) +} + +func TestHTTPFlow_ResponseRuleWithStatusCondition(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/success" { + w.WriteHeader(200) + w.Write([]byte("success")) + } else { + w.WriteHeader(404) + w.Write([]byte("not found")) + } + }) + + var rules Rules + + // Create a temporary file for logging + tempFile, err := os.CreateTemp("", "test-error-log-*.txt") + require.NoError(t, err) + defer os.Remove(tempFile.Name()) + tempFile.Close() + + err = parseRules(fmt.Sprintf(` +- name: log-errors + on: status 4xx + do: log error %s "{{ .Request.URL }} returned {{ .Response.StatusCode }}" +`, tempFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test successful request (should not log) + req1 := httptest.NewRequest("GET", "/success", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + + // Test error request (should log) + req2 := httptest.NewRequest("GET", "/notfound", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 404, w2.Code) + + // Check log file + content, err := os.ReadFile(tempFile.Name()) + require.NoError(t, err) + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + require.Len(t, lines, 1, "only 4xx requests should be logged") + assert.Equal(t, "/notfound returned 404", lines[0]) +} + +func TestHTTPFlow_ConditionalRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("hello " + r.Header.Get("X-Username"))) + }) + + var rules Rules + err := parseRules(` +- name: auth-required + on: header Authorization + do: | + set header X-Username authenticated-user + set resp_header X-Username authenticated-user +- name: default + do: | + set header X-Username anonymous + set resp_header X-Username anonymous +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test with Authorization header + req1 := httptest.NewRequest("GET", "/", nil) + req1.Header.Set("Authorization", "Bearer token") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "hello authenticated-user", w1.Body.String()) + assert.Equal(t, "authenticated-user", w1.Header().Get("X-Username")) + + // Test without Authorization header + req2 := httptest.NewRequest("GET", "/", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + assert.Equal(t, 200, w2.Code) + assert.Equal(t, "hello anonymous", w2.Body.String()) + assert.Equal(t, "anonymous", w2.Header().Get("X-Username")) +} + +func TestHTTPFlow_ComplexFlowWithPreAndPostRules(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate different responses based on path + if r.URL.Path == "/protected" { + if r.Header.Get("X-Auth") != "valid" { + w.WriteHeader(401) + w.Write([]byte("unauthorized")) + return + } + } + w.Header().Set("X-Response-Time", "100ms") + w.WriteHeader(200) + w.Write([]byte("success")) + }) + + // Create temporary files for logging + logFile, err := os.CreateTemp("", "test-access-log-*.txt") + require.NoError(t, err) + defer os.Remove(logFile.Name()) + logFile.Close() + + errorLogFile, err := os.CreateTemp("", "test-error-log-*.txt") + require.NoError(t, err) + defer os.Remove(errorLogFile.Name()) + errorLogFile.Close() + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: add-correlation-id + do: set resp_header X-Correlation-Id random_uuid +- name: validate-auth + on: path /protected + do: require_basic_auth "Protected Area" +- name: log-all-requests + do: | + log info %q "{{ .Request.Method }} {{ .Request.URL }} -> {{ .Response.StatusCode }}" +- name: log-errors + on: status 4xx + do: | + log error %q "ERROR: {{ .Request.Method }} {{ .Request.URL }} {{ .Response.StatusCode }}" +`, logFile.Name(), errorLogFile.Name()), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test successful request + req1 := httptest.NewRequest("GET", "/public", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "success", w1.Body.String()) + assert.Equal(t, "random_uuid", w1.Header().Get("X-Correlation-Id")) + assert.Equal(t, "100ms", w1.Header().Get("X-Response-Time")) + + // Test unauthorized protected request + req2 := httptest.NewRequest("GET", "/protected", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 401, w2.Code) + assert.Equal(t, w2.Body.String(), "Unauthorized\n") + + // Test authorized protected request + req3 := httptest.NewRequest("GET", "/protected", nil) + req3.SetBasicAuth("user", "pass") + w3 := httptest.NewRecorder() + handler.ServeHTTP(w3, req3) + + // This should fail because our simple upstream expects X-Auth: valid header + // but the basic auth requirement should add the appropriate header + assert.Equal(t, 401, w3.Code) + + // Check log files + logContent, err := os.ReadFile(logFile.Name()) + require.NoError(t, err) + lines := strings.Split(strings.TrimSpace(string(logContent)), "\n") + require.Len(t, lines, 3, "all requests should be logged") + assert.Equal(t, "GET /public -> 200", lines[0]) + assert.Equal(t, "GET /protected -> 401", lines[1]) + assert.Equal(t, "GET /protected -> 401", lines[2]) + + errorLogContent, err := os.ReadFile(errorLogFile.Name()) + require.NoError(t, err) + // Should have at least one 401 error logged + lines = strings.Split(strings.TrimSpace(string(errorLogContent)), "\n") + require.Len(t, lines, 2, "all errors should be logged") + assert.Equal(t, "ERROR: GET /protected 401", lines[0]) + assert.Equal(t, "ERROR: GET /protected 401", lines[1]) +} + +func TestHTTPFlow_DefaultRule(t *testing.T) { + upstream := mockUpstream(200, "upstream response") + + var rules Rules + err := parseRules(` +- name: default + do: set resp_header X-Default-Applied true +- name: special-rule + on: path /special + do: set resp_header X-Special-Handled true +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test default rule + req1 := httptest.NewRequest("GET", "/regular", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "true", w1.Header().Get("X-Default-Applied")) + assert.Empty(t, w1.Header().Get("X-Special-Handled")) + + // Test special rule + default rule + req2 := httptest.NewRequest("GET", "/special", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 200, w2.Code) + assert.Equal(t, "true", w2.Header().Get("X-Default-Applied")) + assert.Equal(t, "true", w2.Header().Get("X-Special-Handled")) +} + +func TestHTTPFlow_HeaderManipulation(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Echo back a header + headerValue := r.Header.Get("X-Test-Header") + w.Header().Set("X-Echoed-Header", headerValue) + w.WriteHeader(200) + w.Write([]byte("header echoed")) + }) + + var rules Rules + err := parseRules(` +- name: remove-sensitive-header + do: remove resp_header X-Secret +- name: add-custom-header + do: add resp_header X-Custom-Header custom-value +- name: modify-existing-header + on: header X-Test-Header + do: set header X-Test-Header modified-value +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Secret", "secret-value") + req.Header.Set("X-Test-Header", "original-value") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "modified-value", w.Header().Get("X-Echoed-Header")) + assert.Equal(t, "custom-value", w.Header().Get("X-Custom-Header")) + // Ensure the secret header was removed and not passed to upstream + // (we can't directly test this, but the upstream shouldn't see it) +} + +func TestHTTPFlow_QueryParameterHandling(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + w.WriteHeader(200) + w.Write([]byte("query: " + query.Get("param"))) + }) + + var rules Rules + err := parseRules(` +- name: add-query-param + on: query param + do: set query param added-value +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/path?param=original", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + // The set command should have modified the query parameter + assert.Equal(t, "query: added-value", w.Body.String()) +} + +func TestHTTPFlow_ServeCommand(t *testing.T) { + // Create a temporary directory with test files + tempDir, err := os.MkdirTemp("", "test-serve-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Create test files directly in the temp directory + testFile := filepath.Join(tempDir, "index.html") + err = os.WriteFile(testFile, []byte("

Test Page

"), 0644) + require.NoError(t, err) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: serve-static + on: path glob(/files/*) + do: serve %s +`, tempDir), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(mockUpstream(200, "should not be called")) + + // Test serving a file - serve command serves files relative to the root directory + // The path /files/index.html gets mapped to tempDir + "/files/index.html" + // We need to create the file at the expected path + filesDir := filepath.Join(tempDir, "files") + err = os.Mkdir(filesDir, 0755) + require.NoError(t, err) + + filesIndexFile := filepath.Join(filesDir, "index.html") + err = os.WriteFile(filesIndexFile, []byte("

Test Page

"), 0644) + require.NoError(t, err) + + req1 := httptest.NewRequest("GET", "/files/index.html", nil) + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + // The serve command should work, but might redirect + // Let's just verify it doesn't call the upstream + assert.NotEqual(t, "should not be called", w1.Body.String()) + + // Test file not found + req2 := httptest.NewRequest("GET", "/files/nonexistent.html", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 404, w2.Code) +} + +func TestHTTPFlow_ProxyCommand(t *testing.T) { + // Create a mock upstream server + upstreamServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream-Header", "upstream-value") + w.WriteHeader(200) + w.Write([]byte("upstream response")) + })) + defer upstreamServer.Close() + + var rules Rules + err := parseRules(fmt.Sprintf(` +- name: proxy-to-upstream + on: path glob(/api/*) + do: proxy %s +`, upstreamServer.URL), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(mockUpstream(200, "should not be called")) + + req := httptest.NewRequest("GET", "/api/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // The proxy command should forward the request to the upstream server + assert.Equal(t, 200, w.Code) + assert.Equal(t, "upstream response", w.Body.String()) + assert.Equal(t, "upstream-value", w.Header().Get("X-Upstream-Header")) +} + +func TestHTTPFlow_NotifyCommand(t *testing.T) { + // TODO: +} + +func TestHTTPFlow_FormConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("form processed")) + }) + + var rules Rules + err := parseRules(` +- name: process-form + on: form username + do: set resp_header X-Username "{{ index .Request.Form.username 0 }}" +- name: process-postform + on: postform email + do: set resp_header X-Email "{{ index .Request.PostForm.email 0 }}" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test form condition + formData := url.Values{"username": {"john_doe"}} + req1 := httptest.NewRequest("POST", "/", strings.NewReader(formData.Encode())) + req1.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "john_doe", w1.Header().Get("X-Username")) + + // Test postform condition + postFormData := url.Values{"email": {"john@example.com"}} + req2 := httptest.NewRequest("POST", "/", strings.NewReader(postFormData.Encode())) + req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 200, w2.Code) + assert.Equal(t, "john@example.com", w2.Header().Get("X-Email")) +} + +func TestHTTPFlow_RemoteConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("remote processed")) + }) + + var rules Rules + err := parseRules(` +- name: allow-localhost + on: remote 127.0.0.1 + do: set resp_header X-Access "local" +- name: block-private + on: remote 192.168.0.0/16 + do: error 403 "Private network blocked" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test localhost condition + req1 := httptest.NewRequest("GET", "/", nil) + req1.RemoteAddr = "127.0.0.1:12345" + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "local", w1.Header().Get("X-Access")) + + // Test private network block + req2 := httptest.NewRequest("GET", "/", nil) + req2.RemoteAddr = "192.168.1.100:12345" + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 403, w2.Code) + assert.Equal(t, "Private network blocked\n", w2.Body.String()) +} + +func TestHTTPFlow_BasicAuthConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("auth processed")) + }) + + // Generate bcrypt hashes for passwords + adminHash, err := bcrypt.GenerateFromPassword([]byte("adminpass"), bcrypt.DefaultCost) + require.NoError(t, err) + guestHash, err := bcrypt.GenerateFromPassword([]byte("guestpass"), bcrypt.DefaultCost) + require.NoError(t, err) + + var rules Rules + err = parseRules(fmt.Sprintf(` +- name: check-auth + on: basic_auth admin %s + do: set resp_header X-Auth-Status "admin" +- name: check-other-user + on: basic_auth guest %s + do: set resp_header X-Auth-Status "guest" +`, string(adminHash), string(guestHash)), &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test admin user + req1 := httptest.NewRequest("GET", "/", nil) + req1.SetBasicAuth("admin", "adminpass") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "admin", w1.Header().Get("X-Auth-Status")) + + // Test guest user + req2 := httptest.NewRequest("GET", "/", nil) + req2.SetBasicAuth("guest", "guestpass") + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 200, w2.Code) + assert.Equal(t, "guest", w2.Header().Get("X-Auth-Status")) +} + +func TestHTTPFlow_RouteConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("route processed")) + }) + + var rules Rules + err := parseRules(` +- name: backend-route + on: route backend + do: set resp_header X-Route "backend" +- name: frontend-route + on: route frontend + do: set resp_header X-Route "frontend" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test API route + req1 := httptest.NewRequest("GET", "/", nil) + req1 = routes.WithRouteContext(req1, mockRoute("backend")) + + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "backend", w1.Header().Get("X-Route")) + + // Test admin route + req2 := httptest.NewRequest("GET", "/", nil) + req2 = routes.WithRouteContext(req2, mockRoute("frontend")) + + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 200, w2.Code) + assert.Equal(t, "frontend", w2.Header().Get("X-Route")) +} + +func TestHTTPFlow_ResponseStatusConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(405) + w.Write([]byte("method not allowed")) + }) + + var rules Rules + err := parseRules(` +- name: method-not-allowed + on: status 405 + do: | + error 405 'error' +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, 405, w.Code) + assert.Equal(t, "error\n", w.Body.String()) +} + +func TestHTTPFlow_ResponseHeaderConditions(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Response-Header", "response header") + w.WriteHeader(200) + w.Write([]byte("processed")) + }) + + t.Run("any_value", func(t *testing.T) { + var rules Rules + err := parseRules(` +- on: resp_header X-Response-Header + do: | + error 405 "error" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, 405, w.Code) + assert.Equal(t, "error\n", w.Body.String()) + }) + t.Run("with_value", func(t *testing.T) { + var rules Rules + err := parseRules(` +- on: resp_header X-Response-Header "response header" + do: | + error 405 "error" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, 405, w.Code) + assert.Equal(t, "error\n", w.Body.String()) + }) + + t.Run("with_value_not_matched", func(t *testing.T) { + var rules Rules + err := parseRules(` +- on: resp_header X-Response-Header "not-matched" + do: | + error 405 "error" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "processed", w.Body.String()) + }) +} + +func TestHTTPFlow_ComplexRuleCombinations(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("complex processed")) + }) + + var rules Rules + err := parseRules(` +- name: admin-api + on: | + path glob(/api/admin/*) + header Authorization + method POST + do: | + set resp_header X-Access-Level "admin" + set resp_header X-API-Version "v1" +- name: user-api + on: | + path glob(/api/users/*) & method GET + do: | + set resp_header X-Access-Level "user" + set resp_header X-API-Version "v1" +- name: public-api + on: | + path glob(/api/public/*) & method GET + do: | + set resp_header X-Access-Level "public" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + // Test admin API (should match first rule) + req1 := httptest.NewRequest("POST", "/api/admin/users", nil) + req1.Header.Set("Authorization", "Bearer token") + w1 := httptest.NewRecorder() + handler.ServeHTTP(w1, req1) + + assert.Equal(t, 200, w1.Code) + assert.Equal(t, "admin", w1.Header().Get("X-Access-Level")) + assert.Equal(t, "v1", w1.Header()["X-API-Version"][0]) + + // Test user API (should match second rule) + req2 := httptest.NewRequest("GET", "/api/users/profile", nil) + w2 := httptest.NewRecorder() + handler.ServeHTTP(w2, req2) + + assert.Equal(t, 200, w2.Code) + assert.Equal(t, "user", w2.Header().Get("X-Access-Level")) + assert.Equal(t, "v1", w2.Header()["X-API-Version"][0]) + + // Test public API (should match third rule) + req3 := httptest.NewRequest("GET", "/api/public/info", nil) + w3 := httptest.NewRecorder() + handler.ServeHTTP(w3, req3) + + assert.Equal(t, 200, w3.Code) + assert.Equal(t, "public", w3.Header().Get("X-Access-Level")) + assert.Empty(t, w3.Header()["X-API-Version"]) +} + +func TestHTTPFlow_ResponseModifier(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("original response")) + }) + + var rules Rules + err := parseRules(` +- name: modify-response + do: | + set resp_header X-Modified "true" + set resp_body "Modified: {{ .Request.Method }} {{ .Request.URL.Path }}" +`, &rules) + require.NoError(t, err) + + handler := rules.BuildHandler(upstream) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "true", w.Header().Get("X-Modified")) + assert.Equal(t, "Modified: GET /test\n", w.Body.String()) +} diff --git a/internal/route/rules/io.go b/internal/route/rules/io.go new file mode 100644 index 00000000..8cfab7b2 --- /dev/null +++ b/internal/route/rules/io.go @@ -0,0 +1,36 @@ +package rules + +import ( + "io" + "os" + + "github.com/yusing/godoxy/internal/logging/accesslog" + gperr "github.com/yusing/goutils/errs" +) + +type noopWriteCloser struct { + io.Writer +} + +func (n noopWriteCloser) Close() error { + return nil +} + +var ( + stdout io.WriteCloser = noopWriteCloser{os.Stdout} + stderr io.WriteCloser = noopWriteCloser{os.Stderr} +) + +func openFile(path string) (io.WriteCloser, gperr.Error) { + switch path { + case "/dev/stdout": + return stdout, nil + case "/dev/stderr": + return stderr, nil + } + f, err := accesslog.NewFileIO(path) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + return f, nil +} diff --git a/internal/route/rules/matcher.go b/internal/route/rules/matcher.go new file mode 100644 index 00000000..7c27d908 --- /dev/null +++ b/internal/route/rules/matcher.go @@ -0,0 +1,120 @@ +package rules + +import ( + "regexp" + "strings" + + "github.com/gobwas/glob" + gperr "github.com/yusing/goutils/errs" +) + +type ( + Matcher func(string) bool + MatcherType string +) + +const ( + MatcherTypeString MatcherType = "string" + MatcherTypeGlob MatcherType = "glob" + MatcherTypeRegex MatcherType = "regex" +) + +func unquoteExpr(s string) (string, gperr.Error) { + if s == "" { + return "", nil + } + switch s[0] { + case '"', '\'', '`': + if s[0] != s[len(s)-1] { + return "", ErrUnterminatedQuotes + } + return s[1 : len(s)-1], nil + default: + return s, nil + } +} + +func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) { + idx := strings.IndexByte(s, '(') + if idx == -1 { + return MatcherTypeString, s, nil + } + idxEnd := strings.LastIndexByte(s, ')') + if idxEnd == -1 { + return "", "", ErrUnterminatedBrackets + } + + expr, err = unquoteExpr(s[idx+1 : idxEnd]) + if err != nil { + return "", "", err + } + matcherType = MatcherType(strings.ToLower(s[:idx])) + + switch matcherType { + case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString: + return + default: + return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType) + } +} + +func ParseMatcher(expr string) (Matcher, gperr.Error) { + negate := false + if strings.HasPrefix(expr, "!") { + negate = true + expr = expr[1:] + } + + t, expr, err := ExtractExpr(expr) + if err != nil { + return nil, err + } + + switch t { + case MatcherTypeString: + return StringMatcher(expr, negate) + case MatcherTypeGlob: + return GlobMatcher(expr, negate) + case MatcherTypeRegex: + return RegexMatcher(expr, negate) + } + // won't reach here + return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t) +} + +func StringMatcher(s string, negate bool) (Matcher, gperr.Error) { + if negate { + return func(s2 string) bool { + return s != s2 + }, nil + } + return func(s2 string) bool { + return s == s2 + }, nil +} + +func GlobMatcher(expr string, negate bool) (Matcher, gperr.Error) { + g, err := glob.Compile(expr) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + if negate { + return func(s string) bool { + return !g.Match(s) + }, nil + } + return g.Match, nil +} + +func RegexMatcher(expr string, negate bool) (Matcher, gperr.Error) { + re, err := regexp.Compile(expr) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + if negate { + return func(s string) bool { + return !re.MatchString(s) + }, nil + } + return re.MatchString, nil +} diff --git a/internal/route/rules/matcher_bench_test.go b/internal/route/rules/matcher_bench_test.go new file mode 100644 index 00000000..26484999 --- /dev/null +++ b/internal/route/rules/matcher_bench_test.go @@ -0,0 +1,35 @@ +package rules + +import "testing" + +func BenchmarkMatcher(b *testing.B) { + b.Run("StringMatcher", func(b *testing.B) { + matcher, err := StringMatcher("foo", false) + if err != nil { + b.Fatal(err) + } + for b.Loop() { + matcher("foo") + } + }) + + b.Run("GlobMatcher", func(b *testing.B) { + matcher, err := GlobMatcher("foo*bar?baz*[abc]*.txt", false) + if err != nil { + b.Fatal(err) + } + for b.Loop() { + matcher("foooooobarzbazcb.txt") + } + }) + + b.Run("RegexMatcher", func(b *testing.B) { + matcher, err := RegexMatcher(`^(foo\d+|bar(_baz)?)[a-z]{3,}\.txt$`, false) + if err != nil { + b.Fatal(err) + } + for b.Loop() { + matcher("foo123abcd.txt") + } + }) +} diff --git a/internal/route/rules/validate_test.go b/internal/route/rules/matcher_test.go similarity index 56% rename from internal/route/rules/validate_test.go rename to internal/route/rules/matcher_test.go index 75aadc32..0a5ed2d6 100644 --- a/internal/route/rules/validate_test.go +++ b/internal/route/rules/matcher_test.go @@ -49,6 +49,18 @@ func TestExtractExpr(t *testing.T) { wantT: MatcherTypeRegex, wantExpr: "^[A-Z]+$", }, + { + name: "regex with parentheses", + in: "regex(test(group))", + wantT: MatcherTypeRegex, + wantExpr: "test(group)", + }, + { + name: "regex complex", + in: `regex("^(_next/static|_next/image|favicon.ico).*$")`, + wantT: MatcherTypeRegex, + wantExpr: "^(_next/static|_next/image|favicon.ico).*$", + }, { name: "quoted expr", in: "glob(`'foo'`)", @@ -96,3 +108,62 @@ func TestExtractExprInvalid(t *testing.T) { }) } } + +func TestNegated(t *testing.T) { + tests := []struct { + name string + expr string + in string + want bool + }{ + { + name: "negated_string_match", + expr: "!string(`foo`)", + in: "foo", + want: false, + }, + { + name: "negated_string_no_match", + expr: "!string(`foo`)", + in: "bar", + want: true, + }, + { + name: "negated_glob_match", + expr: "!glob(`foo`)", + in: "foo", + want: false, + }, + { + name: "negated_glob_no_match", + expr: "!glob(`foo`)", + in: "bar", + want: true, + }, + { + name: "negated_regex_match", + expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)", + in: "favicon.ico", + want: false, + }, + { + name: "negated_regex_no_match", + expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)", + in: "bar", + want: true, + }, + { + name: "negated_regex_no_match2", + expr: "!regex(`^(_next/static|_next/image|favicon.ico).*$`)", + in: "/", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matcher, err := ParseMatcher(tt.expr) + expect.NoError(t, err) + expect.Equal(t, tt.want, matcher(tt.in)) + }) + } +} diff --git a/internal/route/rules/on.go b/internal/route/rules/on.go index 84a28105..5bee9436 100644 --- a/internal/route/rules/on.go +++ b/internal/route/rules/on.go @@ -8,16 +8,20 @@ import ( "github.com/yusing/godoxy/internal/route/routes" gperr "github.com/yusing/goutils/errs" - strutils "github.com/yusing/goutils/strings" ) type RuleOn struct { - raw string - checker Checker + raw string + checker Checker + isResponseChecker bool } -func (on *RuleOn) Check(cached Cache, r *http.Request) bool { - return on.checker.Check(cached, r) +func (on *RuleOn) IsResponseChecker() bool { + return on.isResponseChecker +} + +func (on *RuleOn) Check(w http.ResponseWriter, r *http.Request) bool { + return on.checker.Check(w, r) } const ( @@ -32,20 +36,27 @@ const ( OnRemote = "remote" OnBasicAuth = "basic_auth" OnRoute = "route" + + // on response + OnResponseHeader = "resp_header" + OnStatus = "status" ) var checkers = map[string]struct { - help Help - validate ValidateFunc - builder func(args any) CheckFunc + help Help + validate ValidateFunc + builder func(args any) CheckFunc + isResponseChecker bool }{ OnHeader: { help: Help{ command: OnHeader, - description: `Value supports string, glob pattern, or regex pattern, e.g.: - header username "user" - header username glob("user*") - header username regex("user.*")`, + description: makeLines( + "Value supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnHeader, "username", "user"), + helpExample(OnHeader, "username", helpFuncCall("glob", "user*")), + helpExample(OnHeader, "username", helpFuncCall("regex", "user.*")), + ), args: map[string]string{ "key": "the header key", "[value]": "the header value", @@ -55,22 +66,52 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return len(r.Header[k]) > 0 } } - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return slices.ContainsFunc(r.Header[k], matcher) } }, }, + OnResponseHeader: { + isResponseChecker: true, + help: Help{ + command: OnResponseHeader, + description: makeLines( + "Value supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnResponseHeader, "username", "user"), + helpExample(OnResponseHeader, "username", helpFuncCall("glob", "user*")), + helpExample(OnResponseHeader, "username", helpFuncCall("regex", "user.*")), + ), + args: map[string]string{ + "key": "the response header key", + "[value]": "the response header value", + }, + }, + validate: toKVOptionalVMatcher, + builder: func(args any) CheckFunc { + k, matcher := args.(*MapValueMatcher).Unpack() + if matcher == nil { + return func(w http.ResponseWriter, r *http.Request) bool { + return len(GetInitResponseModifier(w).Header()[k]) > 0 + } + } + return func(w http.ResponseWriter, r *http.Request) bool { + return slices.ContainsFunc(GetInitResponseModifier(w).Header()[k], matcher) + } + }, + }, OnQuery: { help: Help{ command: OnQuery, - description: `Value supports string, glob pattern, or regex pattern, e.g.: - query username "user" - query username glob("user*") - query username regex("user.*")`, + description: makeLines( + "Value supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnQuery, "username", "user"), + helpExample(OnQuery, "username", helpFuncCall("glob", "user*")), + helpExample(OnQuery, "username", helpFuncCall("regex", "user.*")), + ), args: map[string]string{ "key": "the query key", "[value]": "the query value", @@ -80,22 +121,24 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(cached Cache, r *http.Request) bool { - return len(cached.GetQueries(r)[k]) > 0 + return func(w http.ResponseWriter, r *http.Request) bool { + return len(GetSharedData(w).GetQueries(r)[k]) > 0 } } - return func(cached Cache, r *http.Request) bool { - return slices.ContainsFunc(cached.GetQueries(r)[k], matcher) + return func(w http.ResponseWriter, r *http.Request) bool { + return slices.ContainsFunc(GetSharedData(w).GetQueries(r)[k], matcher) } }, }, OnCookie: { help: Help{ command: OnCookie, - description: `Value supports string, glob pattern, or regex pattern, e.g.: - cookie username "user" - cookie username glob("user*") - cookie username regex("user.*")`, + description: makeLines( + "Value supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnCookie, "username", "user"), + helpExample(OnCookie, "username", helpFuncCall("glob", "user*")), + helpExample(OnCookie, "username", helpFuncCall("regex", "user.*")), + ), args: map[string]string{ "key": "the cookie key", "[value]": "the cookie value", @@ -105,8 +148,8 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(cached Cache, r *http.Request) bool { - cookies := cached.GetCookies(r) + return func(w http.ResponseWriter, r *http.Request) bool { + cookies := GetSharedData(w).GetCookies(r) for _, cookie := range cookies { if cookie.Name == k { return true @@ -115,8 +158,8 @@ var checkers = map[string]struct { return false } } - return func(cached Cache, r *http.Request) bool { - cookies := cached.GetCookies(r) + return func(w http.ResponseWriter, r *http.Request) bool { + cookies := GetSharedData(w).GetCookies(r) for _, cookie := range cookies { if cookie.Name == k { if matcher(cookie.Value) { @@ -131,10 +174,12 @@ var checkers = map[string]struct { OnForm: { help: Help{ command: OnForm, - description: `Value supports string, glob pattern, or regex pattern, e.g.: - form username "user" - form username glob("user*") - form username regex("user.*")`, + description: makeLines( + "Value supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnForm, "username", "user"), + helpExample(OnForm, "username", helpFuncCall("glob", "user*")), + helpExample(OnForm, "username", helpFuncCall("regex", "user.*")), + ), args: map[string]string{ "key": "the form key", "[value]": "the form value", @@ -144,11 +189,11 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return r.FormValue(k) != "" } } - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return matcher(r.FormValue(k)) } }, @@ -156,10 +201,12 @@ var checkers = map[string]struct { OnPostForm: { help: Help{ command: OnPostForm, - description: `Value supports string, glob pattern, or regex pattern, e.g.: - postform username "user" - postform username glob("user*") - postform username regex("user.*")`, + description: makeLines( + "Value supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnPostForm, "username", "user"), + helpExample(OnPostForm, "username", helpFuncCall("glob", "user*")), + helpExample(OnPostForm, "username", helpFuncCall("regex", "user.*")), + ), args: map[string]string{ "key": "the form key", "[value]": "the form value", @@ -169,11 +216,11 @@ var checkers = map[string]struct { builder: func(args any) CheckFunc { k, matcher := args.(*MapValueMatcher).Unpack() if matcher == nil { - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return r.PostFormValue(k) != "" } } - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return matcher(r.PostFormValue(k)) } }, @@ -188,7 +235,7 @@ var checkers = map[string]struct { validate: validateMethod, builder: func(args any) CheckFunc { method := args.(string) - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return r.Method == method } }, @@ -196,11 +243,13 @@ var checkers = map[string]struct { OnHost: { help: Help{ command: OnHost, - description: `Supports string, glob pattern, or regex pattern, e.g.: - host example.com - host glob(example*.com) - host regex(example\w+\.com) - host regex(example\.com$)`, + description: makeLines( + "Supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnHost, "example.com"), + helpExample(OnHost, helpFuncCall("glob", "example*.com")), + helpExample(OnHost, helpFuncCall("regex", `(example\w+\.com)`)), + helpExample(OnHost, helpFuncCall("regex", `example\.com$`)), + ), args: map[string]string{ "host": "the host name", }, @@ -208,7 +257,7 @@ var checkers = map[string]struct { validate: validateSingleMatcher, builder: func(args any) CheckFunc { matcher := args.(Matcher) - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { return matcher(r.Host) } }, @@ -216,11 +265,13 @@ var checkers = map[string]struct { OnPath: { help: Help{ command: OnPath, - description: `Supports string, glob pattern, or regex pattern, e.g.: - path /path/to - path glob(/path/to/*) - path regex(^/path/to/.*$) - path regex(/path/[A-Z]+/)`, + description: makeLines( + "Supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnPath, "/path/to"), + helpExample(OnPath, helpFuncCall("glob", "/path/to/*")), + helpExample(OnPath, helpFuncCall("regex", `^/path/to/.*$`)), + helpExample(OnPath, helpFuncCall("regex", `/path/[A-Z]+/`)), + ), args: map[string]string{ "path": "the request path", }, @@ -228,7 +279,7 @@ var checkers = map[string]struct { validate: validateURLPathMatcher, builder: func(args any) CheckFunc { matcher := args.(Matcher) - return func(cached Cache, r *http.Request) bool { + return func(w http.ResponseWriter, r *http.Request) bool { reqPath := r.URL.Path if len(reqPath) > 0 && reqPath[0] != '/' { reqPath = "/" + reqPath @@ -250,16 +301,16 @@ var checkers = map[string]struct { // for /32 (IPv4) or /128 (IPv6), just compare the IP if ones, bits := ipnet.Mask.Size(); ones == bits { wantIP := ipnet.IP - return func(cached Cache, r *http.Request) bool { - ip := cached.GetRemoteIP(r) + return func(w http.ResponseWriter, r *http.Request) bool { + ip := GetSharedData(w).GetRemoteIP(r) if ip == nil { return false } return ip.Equal(wantIP) } } - return func(cached Cache, r *http.Request) bool { - ip := cached.GetRemoteIP(r) + return func(w http.ResponseWriter, r *http.Request) bool { + ip := GetSharedData(w).GetRemoteIP(r) if ip == nil { return false } @@ -278,18 +329,20 @@ var checkers = map[string]struct { validate: validateUserBCryptPassword, builder: func(args any) CheckFunc { cred := args.(*HashedCrendentials) - return func(cached Cache, r *http.Request) bool { - return cred.Match(cached.GetBasicAuth(r)) + return func(w http.ResponseWriter, r *http.Request) bool { + return cred.Match(GetSharedData(w).GetBasicAuth(r)) } }, }, OnRoute: { help: Help{ command: OnRoute, - description: `Supports string, glob pattern, or regex pattern, e.g.: - route example - route glob(example*) - route regex(example\w+)`, + description: makeLines( + "Supports string, glob pattern, or regex pattern, e.g.:", + helpExample(OnRoute, "example"), + helpExample(OnRoute, helpFuncCall("glob", "example*")), + helpExample(OnRoute, helpFuncCall("regex", "example\\w+")), + ), args: map[string]string{ "route": "the route name", }, @@ -297,11 +350,43 @@ var checkers = map[string]struct { validate: validateSingleMatcher, builder: func(args any) CheckFunc { matcher := args.(Matcher) - return func(_ Cache, r *http.Request) bool { + return func(_ http.ResponseWriter, r *http.Request) bool { return matcher(routes.TryGetUpstreamName(r)) } }, }, + OnStatus: { + isResponseChecker: true, + help: Help{ + command: OnStatus, + description: makeLines( + "Supported formats are:", + helpExample(OnStatus, ""), + helpExample(OnStatus, "-"), + helpExample(OnStatus, "1xx"), + helpExample(OnStatus, "2xx"), + helpExample(OnStatus, "3xx"), + helpExample(OnStatus, "4xx"), + helpExample(OnStatus, "5xx"), + ), + args: map[string]string{ + "status": "the status code range", + }, + }, + validate: validateStatusRange, + builder: func(args any) CheckFunc { + beg, end := args.(*IntTuple).Unpack() + if beg == end { + return func(w http.ResponseWriter, _ *http.Request) bool { + return GetInitResponseModifier(w).StatusCode() == beg + } + } + return func(w http.ResponseWriter, _ *http.Request) bool { + statusCode := GetInitResponseModifier(w).StatusCode() + return statusCode >= beg && statusCode <= end + } + }, + }, } var ( @@ -367,6 +452,66 @@ func splitAnd(s string) []string { return a[:i] } +// splitPipe splits a string by "|" but respects quotes, brackets, and escaped characters. +// It's similar to the parser.go logic but specifically for pipe splitting. +func splitPipe(s string) []string { + if s == "" { + return []string{} + } + + var result []string + var current strings.Builder + escaped := false + quote := rune(0) + brackets := 0 + + for _, r := range s { + if escaped { + current.WriteRune(r) + escaped = false + continue + } + + switch r { + case '\\': + escaped = true + current.WriteRune(r) + case '"', '\'', '`': + if quote == 0 && brackets == 0 { + quote = r + } else if r == quote { + quote = 0 + } + current.WriteRune(r) + case '(': + brackets++ + current.WriteRune(r) + case ')': + if brackets > 0 { + brackets-- + } + current.WriteRune(r) + case '|': + if quote == 0 && brackets == 0 { + // Found a pipe outside quotes/brackets, split here + result = append(result, strings.TrimSpace(current.String())) + current.Reset() + } else { + current.WriteRune(r) + } + default: + current.WriteRune(r) + } + } + + // Add the last part + if current.Len() > 0 { + result = append(result, strings.TrimSpace(current.String())) + } + + return result +} + // Parse implements strutils.Parser. func (on *RuleOn) Parse(v string) error { on.raw = v @@ -375,19 +520,24 @@ func (on *RuleOn) Parse(v string) error { checkAnd := make(CheckMatchAll, 0, len(rules)) errs := gperr.NewBuilder("rule.on syntax errors") + isResponseChecker := false for i, rule := range rules { if rule == "" { continue } - parsed, err := parseOn(rule) + parsed, isResp, err := parseOn(rule) if err != nil { errs.Add(err.Subjectf("line %d", i+1)) continue } + if isResp { + isResponseChecker = true + } checkAnd = append(checkAnd, parsed) } on.checker = checkAnd + on.isResponseChecker = isResponseChecker return errs.Error() } @@ -399,40 +549,57 @@ func (on *RuleOn) MarshalText() ([]byte, error) { return []byte(on.String()), nil } -func parseOn(line string) (Checker, gperr.Error) { - ors := strutils.SplitRune(line, '|') +func parseOn(line string) (Checker, bool, gperr.Error) { + ors := splitPipe(line) if len(ors) > 1 { errs := gperr.NewBuilder("rule.on syntax errors") checkOr := make(CheckMatchSingle, len(ors)) + isResponseChecker := false for i, or := range ors { - curCheckers, err := parseOn(or) + curCheckers, isResp, err := parseOn(or) if err != nil { errs.Add(err) continue } + if isResp { + isResponseChecker = true + } checkOr[i] = curCheckers.(CheckFunc) } if err := errs.Error(); err != nil { - return nil, err + return nil, false, err } - return checkOr, nil + return checkOr, isResponseChecker, nil } subject, args, err := parse(line) if err != nil { - return nil, err + return nil, false, err + } + + negate := false + if strings.HasPrefix(subject, "!") { + negate = true + subject = subject[1:] } checker, ok := checkers[subject] if !ok { - return nil, ErrInvalidOnTarget.Subject(subject) + return nil, false, ErrInvalidOnTarget.Subject(subject) } validArgs, err := checker.validate(args) if err != nil { - return nil, err.Subject(subject).Withf("%s", checker.help.String()) + return nil, false, err.Subject(subject).With(checker.help.Error()) } - return checker.builder(validArgs), nil + checkFunc := checker.builder(validArgs) + if negate { + origCheckFunc := checkFunc + checkFunc = func(w http.ResponseWriter, r *http.Request) bool { + return !origCheckFunc(w, r) + } + } + return checkFunc, checker.isResponseChecker, nil } diff --git a/internal/route/rules/on_internal_test.go b/internal/route/rules/on_internal_test.go index 2d8ed89a..41093ae3 100644 --- a/internal/route/rules/on_internal_test.go +++ b/internal/route/rules/on_internal_test.go @@ -7,6 +7,86 @@ import ( expect "github.com/yusing/goutils/testing" ) +func TestSplitPipe(t *testing.T) { + tests := []struct { + name string + input string + want []string + }{ + { + name: "empty", + input: "", + want: []string{}, + }, + { + name: "single", + input: "rule", + want: []string{"rule"}, + }, + { + name: "simple_pipe", + input: "rule1 | rule2", + want: []string{"rule1", "rule2"}, + }, + { + name: "multiple_pipes", + input: "rule1 | rule2 | rule3", + want: []string{"rule1", "rule2", "rule3"}, + }, + { + name: "pipe_in_quotes", + input: `path regex("^(_next/static|_next/image|favicon.ico).*$")`, + want: []string{`path regex("^(_next/static|_next/image|favicon.ico).*$")`}, + }, + { + name: "pipe_in_single_quotes", + input: `path regex('^(_next/static|_next/image|favicon.ico).*$')`, + want: []string{`path regex('^(_next/static|_next/image|favicon.ico).*$')`}, + }, + { + name: "pipe_in_backticks", + input: "path regex(`^(_next/static|_next/image|favicon.ico).*$`)", + want: []string{"path regex(`^(_next/static|_next/image|favicon.ico).*$`)"}, + }, + { + name: "pipe_in_brackets", + input: "path regex(^(_next/static|_next/image|favicon.ico).*$)", + want: []string{"path regex(^(_next/static|_next/image|favicon.ico).*$)"}, + }, + { + name: "escaped_pipe", + input: `path regex("^(_next/static\|_next/image\|favicon.ico).*$")`, + want: []string{`path regex("^(_next/static\|_next/image\|favicon.ico).*$")`}, + }, + { + name: "mixed_quotes_and_pipes", + input: `rule1 | path regex("^(_next/static|_next/image|favicon.ico).*$") | rule3`, + want: []string{"rule1", `path regex("^(_next/static|_next/image|favicon.ico).*$")`, "rule3"}, + }, + { + name: "nested_brackets", + input: "path regex(^(foo|bar(baz|qux)).*$)", + want: []string{"path regex(^(foo|bar(baz|qux)).*$)"}, + }, + { + name: "spaces_around", + input: " rule1 | rule2 | rule3 ", + want: []string{"rule1", "rule2", "rule3"}, + }, + { + name: "empty_segments", + input: "rule1 || rule2 | | rule3", + want: []string{"rule1", "", "rule2", "", "rule3"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := splitPipe(tt.input) + expect.Equal(t, got, tt.want) + }) + } +} + func TestSplitAnd(t *testing.T) { tests := []struct { name string @@ -179,6 +259,27 @@ func TestParseOn(t *testing.T) { input: "route example1 example2", wantErr: ErrExpectOneArg, }, + // pipe splitting tests + { + name: "pipe_simple", + input: "method GET | method POST", + wantErr: nil, + }, + { + name: "pipe_in_quotes", + input: `path regex("^(_next/static|_next/image|favicon.ico).*$")`, + wantErr: nil, + }, + { + name: "pipe_in_brackets", + input: "path regex(^(_next/static|_next/image|favicon.ico).*$)", + wantErr: nil, + }, + { + name: "pipe_mixed", + input: `method GET | path regex("^(_next/static|_next/image|favicon.ico).*$") | header Authorization`, + wantErr: nil, + }, } for _, tt := range tests { diff --git a/internal/route/rules/on_test.go b/internal/route/rules/on_test.go index 38a7f4e4..1b47d19b 100644 --- a/internal/route/rules/on_test.go +++ b/internal/route/rules/on_test.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "fmt" "net/http" + "net/http/httptest" "net/url" "testing" @@ -47,6 +48,18 @@ func genCorrectnessTestCases(field string, genRequest func(k, v string) *http.Re input: genRequest("bar", "abcd"), want: false, }, + { + name: field + "_negated_match", + checker: "!" + field + " foo", + input: genRequest("foo", "bar"), + want: false, + }, + { + name: field + "_negated_no_match", + checker: "!" + field + " foo", + input: genRequest("bar", "foo"), + want: true, + }, } } @@ -64,6 +77,18 @@ func TestOnCorrectness(t *testing.T) { input: &http.Request{Method: http.MethodPost}, want: false, }, + { + name: "method_negated_match", + checker: "!method GET", + input: &http.Request{Method: http.MethodGet}, + want: false, + }, + { + name: "method_negated_no_match", + checker: "!method GET", + input: &http.Request{Method: http.MethodPost}, + want: true, + }, { name: "host_match", checker: "host example.com", @@ -80,6 +105,22 @@ func TestOnCorrectness(t *testing.T) { }, want: false, }, + { + name: "host_negated_match", + checker: "!host example.com", + input: &http.Request{ + Host: "example.com", + }, + want: false, + }, + { + name: "host_negated_no_match", + checker: "!host example.com", + input: &http.Request{ + Host: "example.org", + }, + want: true, + }, { name: "path_exact_match", checker: "path /example", @@ -88,6 +129,22 @@ func TestOnCorrectness(t *testing.T) { }, want: true, }, + { + name: "path_negated_match", + checker: "!path /example", + input: &http.Request{ + URL: &url.URL{Path: "/example"}, + }, + want: false, + }, + { + name: "path_negated_no_match", + checker: "!path /example", + input: &http.Request{ + URL: &url.URL{Path: "/example/foo"}, + }, + want: true, + }, { name: "remote_match", checker: "remote 192.168.1.0/24", @@ -96,6 +153,22 @@ func TestOnCorrectness(t *testing.T) { }, want: true, }, + { + name: "remote_negated_match", + checker: "!remote 192.168.1.0/24", + input: &http.Request{ + RemoteAddr: "192.168.1.5", + }, + want: false, + }, + { + name: "remote_negated_no_match", + checker: "!remote 192.168.1.0/24", + input: &http.Request{ + RemoteAddr: "192.168.2.5", + }, + want: true, + }, { name: "remote_no_match", checker: "remote 192.168.1.0/24", @@ -124,6 +197,26 @@ func TestOnCorrectness(t *testing.T) { }, want: false, }, + { + name: "basic_auth_negated_match", + checker: "!basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + input: &http.Request{ + Header: http.Header{ + "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:password"))}, // "user:password" + }, + }, + want: false, + }, + { + name: "basic_auth_negated_no_match", + checker: "!basic_auth user " + string(expect.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost))), + input: &http.Request{ + Header: http.Header{ + "Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte("user:incorrect"))}, // "user:wrong" + }, + }, + want: true, + }, { name: "route_match", checker: "route example", @@ -141,6 +234,23 @@ func TestOnCorrectness(t *testing.T) { }, want: false, }, + { + name: "route_negated_match", + checker: "!route example", + input: routes.WithRouteContext(&http.Request{}, expect.Must(route.NewFileServer(&route.Route{ + Alias: "example", + Root: "/", + }))), + want: false, + }, + { + name: "route_negated_no_match", + checker: "!route example", + input: &http.Request{ + Header: http.Header{}, + }, + want: true, + }, { name: "regex_match", checker: `host regex(example\w+\.com)`, @@ -157,6 +267,22 @@ func TestOnCorrectness(t *testing.T) { }, want: false, }, + { + name: "regex_negated_match", + checker: `!host regex(example\w+\.com)`, + input: &http.Request{ + Host: "example.org", + }, + want: true, + }, + { + name: "regex_negated_no_match", + checker: `!host regex(example\w+\.com)`, + input: &http.Request{ + Host: "exampleabc.com", + }, + want: false, + }, { name: "glob match", checker: `host glob(*.example.com)`, @@ -181,6 +307,22 @@ func TestOnCorrectness(t *testing.T) { }, want: false, }, + { + name: "glob negated_match", + checker: `!host glob(*.example.com)`, + input: &http.Request{ + Host: "example.com", + }, + want: true, + }, + { + name: "glob negated_no_match", + checker: `!host glob(*.example.com)`, + input: &http.Request{ + Host: "a.example.com", + }, + want: false, + }, } tests = append(tests, genCorrectnessTestCases("header", func(k, v string) *http.Request { @@ -219,10 +361,11 @@ func TestOnCorrectness(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() var on RuleOn err := on.Parse(tt.checker) expect.NoError(t, err) - got := on.Check(Cache{}, tt.input) + got := on.Check(w, tt.input) expect.Equal(t, tt.want, got, fmt.Sprintf("expect %s to %v", tt.checker, tt.want)) }) } diff --git a/internal/route/rules/parser.go b/internal/route/rules/parser.go index c841df49..87375ffa 100644 --- a/internal/route/rules/parser.go +++ b/internal/route/rules/parser.go @@ -3,9 +3,9 @@ package rules import ( "bytes" "fmt" - "os" "unicode" + "github.com/yusing/goutils/env" gperr "github.com/yusing/goutils/errs" ) @@ -33,7 +33,7 @@ func parse(v string) (subject string, args []string, err gperr.Error) { brackets := 0 var envVar bytes.Buffer - var missingEnvVars bytes.Buffer + var missingEnvVars []string inEnvVar := false expectingBrace := false @@ -70,6 +70,10 @@ func parse(v string) (subject string, args []string, err gperr.Error) { escaped = false continue } + if expectingBrace && r != '{' && r != '$' { // not escaped and not env var + buf.WriteRune('$') + expectingBrace = false + } switch r { case '\\': escaped = true @@ -90,9 +94,11 @@ func parse(v string) (subject string, args []string, err gperr.Error) { } case '}': if inEnvVar { - envValue, ok := os.LookupEnv(envVar.String()) + // NOTE: use env.LookupEnv instead of os.LookupEnv to support environment variable prefixes + // like ${API_ADDR} will lookup for GODOXY_API_ADDR, GOPROXY_API_ADDR and API_ADDR. + envValue, ok := env.LookupEnv(envVar.String()) if !ok { - fmt.Fprintf(&missingEnvVars, "%q, ", envVar.String()) + missingEnvVars = append(missingEnvVars, envVar.String()) } else { buf.WriteString(envValue) } @@ -140,15 +146,21 @@ func parse(v string) (subject string, args []string, err gperr.Error) { } } + if expectingBrace { + buf.WriteRune('$') + } + if quote != 0 { err = ErrUnterminatedQuotes } else if brackets != 0 { err = ErrUnterminatedBrackets + } else if inEnvVar { + err = ErrUnterminatedEnvVar } else { flush(false) } - if missingEnvVars.Len() > 0 { - err = gperr.Join(err, ErrEnvVarNotFound.Subject(missingEnvVars.String())) + if len(missingEnvVars) > 0 { + err = gperr.Join(err, ErrEnvVarNotFound.With(gperr.Multiline().AddStrings(missingEnvVars...))) } return subject, args, err } diff --git a/internal/route/rules/parser_test.go b/internal/route/rules/parser_test.go index cb8e9b94..1b0e8b9c 100644 --- a/internal/route/rules/parser_test.go +++ b/internal/route/rules/parser_test.go @@ -49,9 +49,9 @@ func TestParser(t *testing.T) { }, { name: "regex_escaped", - input: `foo regex(\b\B\s\S\w\W\d\D\$\.)`, + input: `foo regex(\b\B\s\S\w\W\d\D\$\.\(\)\{\}\|\?\"\')`, subject: "foo", - args: []string{`regex(\b\B\s\S\w\W\d\D\$\.)`}, + args: []string{`regex(\b\B\s\S\w\W\d\D\$\.\(\)\{\}\|\?"')`}, }, { name: "quote inside argument", @@ -71,6 +71,12 @@ func TestParser(t *testing.T) { subject: "foo", args: []string{"glob(\"`/**/to/path`\")"}, }, + { + name: "complex_regex", + input: `path !regex("^(_next/static|_next/image|favicon.ico).*$")`, + subject: "path", + args: []string{`!regex("^(_next/static|_next/image|favicon.ico).*$")`}, + }, { name: "chaos", input: `error 403 "Forbidden "foo" "bar""`, @@ -170,6 +176,53 @@ func TestParser(t *testing.T) { }) } }) + + t.Run("negated", func(t *testing.T) { + test := `!error 403 "Forbidden"` + subject, args, err := parse(test) + expect.NoError(t, err) + expect.Equal(t, subject, "!error") + expect.Equal(t, args, []string{"403", "Forbidden"}) + }) +} + +func TestFullParse(t *testing.T) { + input := ` +- name: login page + on: path /login + do: pass +- name: require auth + on: path !regex("^(_next/static|_next/image|favicon.ico).*$") + do: require_auth +- name: redirect to login + on: status 401 | status 403 + do: proxy /login +- name: proxy to backend + on: path glob("/api/v1/*") + do: proxy http://localhost:8999/ +- name: proxy to backend (old /auth) + on: path glob("/auth/*") + do: proxy http://localhost:8999/api/v1/` + + var rules Rules + err := parseRules(input, &rules) + expect.NoError(t, err) + expect.Equal(t, len(rules), 5) + expect.Equal(t, rules[0].Name, "login page") + expect.Equal(t, rules[0].On.String(), "path /login") + expect.Equal(t, rules[0].Do.String(), "pass") + expect.Equal(t, rules[1].Name, "require auth") + expect.Equal(t, rules[1].On.String(), `path !regex("^(_next/static|_next/image|favicon.ico).*$")`) + expect.Equal(t, rules[1].Do.String(), "require_auth") + expect.Equal(t, rules[2].Name, "redirect to login") + expect.Equal(t, rules[2].On.String(), "status 401 | status 403") + expect.Equal(t, rules[2].Do.String(), "proxy /login") + expect.Equal(t, rules[3].Name, "proxy to backend") + expect.Equal(t, rules[3].On.String(), `path glob("/api/v1/*")`) + expect.Equal(t, rules[3].Do.String(), "proxy http://localhost:8999/") + expect.Equal(t, rules[4].Name, "proxy to backend (old /auth)") + expect.Equal(t, rules[4].On.String(), `path glob("/auth/*")`) + expect.Equal(t, rules[4].Do.String(), "proxy http://localhost:8999/api/v1/") } func BenchmarkParser(b *testing.B) { diff --git a/internal/route/rules/presets/embed.go b/internal/route/rules/presets/embed.go new file mode 100644 index 00000000..f9f83b61 --- /dev/null +++ b/internal/route/rules/presets/embed.go @@ -0,0 +1,48 @@ +package rulepresets + +import ( + "embed" + "reflect" + "sync" + + "github.com/rs/zerolog/log" + "github.com/yusing/godoxy/internal/route/rules" + "github.com/yusing/godoxy/internal/serialization" +) + +//go:embed *.yml +var fs embed.FS + +var rulePresets = make(map[string]rules.Rules) + +var once sync.Once + +func GetRulePreset(name string) (rules.Rules, bool) { + once.Do(initPresets) + rules, ok := rulePresets[name] + return rules, ok +} + +// init all rule presetsl lazily +func initPresets() { + files, err := fs.ReadDir(".") + if err != nil { + log.Error().Err(err).Msg("failed to read rule presets") + return + } + for _, file := range files { + var rules rules.Rules + content, err := fs.ReadFile(file.Name()) + if err != nil { + log.Error().Str("name", file.Name()).Err(err).Msg("failed to read rule preset") + continue + } + _, err = serialization.ConvertString(string(content), reflect.ValueOf(&rules)) + if err != nil { + log.Error().Str("name", file.Name()).Err(err).Msg("failed to unmarshal rule preset") + continue + } + rulePresets[file.Name()] = rules + log.Debug().Str("name", file.Name()).Msg("loaded rule preset") + } +} diff --git a/internal/route/rules/presets/webui.yml b/internal/route/rules/presets/webui.yml new file mode 100644 index 00000000..96f7ea0b --- /dev/null +++ b/internal/route/rules/presets/webui.yml @@ -0,0 +1,17 @@ +- name: login page + on: path /login + do: pass +- name: protected + on: | + !path regex("(_next/static|_next/image|favicon.ico).*") + !path glob("/api/v1/auth/*") + !path /api/v1/version + do: require_auth +- name: proxy to backend + on: path glob("/api/v1/*") + do: proxy http://${API_ADDR}/ +- name: proxy to auth api + on: path glob("/auth/*") + do: | + rewrite /auth /api/v1/auth + proxy http://${API_ADDR}/ diff --git a/internal/route/rules/response_modifier.go b/internal/route/rules/response_modifier.go new file mode 100644 index 00000000..0e70d7fa --- /dev/null +++ b/internal/route/rules/response_modifier.go @@ -0,0 +1,173 @@ +package rules + +import ( + "bufio" + "bytes" + "errors" + "net" + "net/http" + "strconv" + + gperr "github.com/yusing/goutils/errs" + "github.com/yusing/goutils/synk" +) + +type ResponseModifier struct { + w http.ResponseWriter + b []byte // the bytes got from pool + buf *bytes.Buffer + statusCode int + shared Cache + + hijacked bool + + errs gperr.Builder +} + +type Response struct { + StatusCode int + Header http.Header +} + +var pool = synk.GetBytesPoolWithUniqueMemory() + +func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier { + for { + switch ww := w.(type) { + case *ResponseModifier: + return ww + case interface{ Unwrap() http.ResponseWriter }: + w = ww.Unwrap() + default: + return nil + } + } +} + +// GetInitResponseModifier returns the response modifier for the given response writer. +// If the response writer is already wrapped, it will return the wrapped response modifier. +// Otherwise, it will return a new response modifier. +func GetInitResponseModifier(w http.ResponseWriter) *ResponseModifier { + if rm := unwrapResponseModifier(w); rm != nil { + return rm + } + return NewResponseModifier(w) +} + +// GetSharedData returns the shared data for the given response writer. +// It will initialize the shared data if not initialized. +func GetSharedData(w http.ResponseWriter) Cache { + rm := GetInitResponseModifier(w) + if rm.shared == nil { + rm.shared = NewCache() + } + return rm.shared +} + +// NewResponseModifier returns a new response modifier for the given response writer. +// +// It should only be called once, at the very beginning of the request. +func NewResponseModifier(w http.ResponseWriter) *ResponseModifier { + b := pool.Get() + return &ResponseModifier{ + w: w, + buf: bytes.NewBuffer(b), + b: b, + } +} + +// func (rm *ResponseModifier) Unwrap() http.ResponseWriter { +// return rm.w +// } + +func (rm *ResponseModifier) WriteHeader(code int) { + rm.statusCode = code +} + +func (rm *ResponseModifier) ResetBody() { + rm.buf.Reset() +} + +func (rm *ResponseModifier) ContentLength() int { + return rm.buf.Len() +} + +func (rm *ResponseModifier) StatusCode() int { + if rm.statusCode == 0 { + return http.StatusOK + } + return rm.statusCode +} + +func (rm *ResponseModifier) Header() http.Header { + return rm.w.Header() +} + +func (rm *ResponseModifier) Response() Response { + return Response{StatusCode: rm.StatusCode(), Header: rm.Header()} +} + +func (rm *ResponseModifier) Write(b []byte) (int, error) { + return rm.buf.Write(b) +} + +// AppendError appends an error to the response modifier +// the error will be formatted as "rule error: " +// +// It will be aggregated and returned in FlushRelease. +func (rm *ResponseModifier) AppendError(rule Rule, err error) { + rm.errs.Addf("rule %q error: %w", rule.Name, err) +} + +func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rm.w.(http.Hijacker); ok { + rm.hijacked = true + return hijacker.Hijack() + } + return nil, nil, errors.New("hijack not supported") +} + +// FlushRelease flushes the response modifier and releases the resources +// it returns the number of bytes written and the aggregated error +// if there is any error (rule errors or write error), it will be returned +func (rm *ResponseModifier) FlushRelease() (int, error) { + n := 0 + if !rm.hijacked { + h := rm.w.Header() + // for k := range h { + // if strings.EqualFold(k, "content-length") { + // h.Del(k) + // } + // } + h.Set("Content-Length", strconv.Itoa(rm.buf.Len())) + rm.w.WriteHeader(rm.StatusCode()) + nn, werr := rm.w.Write(rm.buf.Bytes()) + n += nn + if werr != nil { + rm.errs.Addf("write error: %w", werr) + } + + // flush the response writer + if flusher, ok := rm.w.(http.Flusher); ok { + flusher.Flush() + } else if errFlusher, ok := rm.w.(interface{ Flush() error }); ok { + ferr := errFlusher.Flush() + if ferr != nil { + rm.errs.Addf("flush error: %w", ferr) + } + } + } + + // release the buffer and reset the pointers + pool.Put(rm.b) + rm.b = nil + rm.buf = nil + + // release the shared data + if rm.shared != nil { + rm.shared.Release() + rm.shared = nil + } + + return n, rm.errs.Error() +} diff --git a/internal/route/rules/rules.go b/internal/route/rules/rules.go index 8ebe6652..8556f805 100644 --- a/internal/route/rules/rules.go +++ b/internal/route/rules/rules.go @@ -1,9 +1,12 @@ package rules import ( + "errors" + "fmt" "net/http" "github.com/bytedance/sonic" + gperr "github.com/yusing/goutils/errs" ) type ( @@ -46,15 +49,16 @@ type ( } ) +func (rule *Rule) IsResponseRule() bool { + return rule.On.IsResponseChecker() || rule.Do.IsResponseHandler() +} + // BuildHandler returns a http.HandlerFunc that implements the rules. -// -// if a bypass rule matches, -// the request is passed to the upstream and no more rules are executed. -// -// if no rule matches, the default rule is executed -// if no rule matches and default rule is not set, -// the request is passed to the upstream. -func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { +func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc { + if len(rules) == 0 { + return up + } + defaultRule := Rule{ Name: "default", Do: Command{ @@ -63,55 +67,168 @@ func (rules Rules) BuildHandler(up http.Handler) http.HandlerFunc { }, } - nonDefaultRules := make(Rules, 0, len(rules)) - for _, rule := range rules { + var nonDefaultRules Rules + hasDefaultRule := false + for i, rule := range rules { if rule.Name == "default" { defaultRule = rule + hasDefaultRule = true } else { + // set name to index if name is empty + if rule.Name == "" { + rule.Name = fmt.Sprintf("rule[%d]", i) + } nonDefaultRules = append(nonDefaultRules, rule) } } if len(nonDefaultRules) == 0 { if defaultRule.Do.isBypass() { - return up.ServeHTTP + return up + } + if defaultRule.IsResponseRule() { + return func(w http.ResponseWriter, r *http.Request) { + rm := NewResponseModifier(w) + w = rm + up(w, r) + err := defaultRule.Do.exec.Handle(w, r) + if err != nil && !errors.Is(err, errTerminated) { + rm.AppendError(defaultRule, err) + } + } } return func(w http.ResponseWriter, r *http.Request) { - cache := NewCache() - defer cache.Release() - if defaultRule.Do.exec.Handle(cache, w, r) { - up.ServeHTTP(w, r) + rm := NewResponseModifier(w) + w = rm + err := defaultRule.Do.exec.Handle(w, r) + if err == nil { + up(w, r) + return + } + if !errors.Is(err, errTerminated) { + rm.AppendError(defaultRule, err) } } } - if len(nonDefaultRules) == 0 { - nonDefaultRules = rules + preRules := make(Rules, 0, len(nonDefaultRules)) + postRules := make(Rules, 0, len(nonDefaultRules)) + for _, rule := range nonDefaultRules { + if rule.IsResponseRule() { + postRules = append(postRules, rule) + } else { + preRules = append(preRules, rule) + } } + isDefaultRulePost := hasDefaultRule && defaultRule.IsResponseRule() + defaultTerminates := isTerminatingHandler(defaultRule.Do.exec) + return func(w http.ResponseWriter, r *http.Request) { - cache := NewCache() - defer cache.Release() + rm := NewResponseModifier(w) + defer func() { + if _, err := rm.FlushRelease(); err != nil { + gperr.LogError("error executing rules", err) + } + }() - for _, rule := range nonDefaultRules { - if rule.Check(cache, r) { - if rule.Do.isBypass() { - up.ServeHTTP(w, r) - return + w = rm + + shouldCallUpstream := true + preMatched := false + + if hasDefaultRule && !isDefaultRulePost && !defaultTerminates { + if defaultRule.Do.isBypass() { + // continue to upstream + } else { + err := defaultRule.Handle(w, r) + if err != nil { + if !errors.Is(err, errTerminated) { + rm.AppendError(defaultRule, err) + } + shouldCallUpstream = false } - if !rule.Handle(cache, w, r) { + } + } + + if shouldCallUpstream { + for _, rule := range preRules { + if rule.Check(w, r) { + preMatched = true + if rule.Do.isBypass() { + break // post rules should still execute + } + err := rule.Handle(w, r) + if err != nil { + if !errors.Is(err, errTerminated) { + rm.AppendError(rule, err) + } + shouldCallUpstream = false + break + } + } + } + } + + if hasDefaultRule && !isDefaultRulePost && defaultTerminates && shouldCallUpstream && !preMatched { + if defaultRule.Do.isBypass() { + // continue to upstream + } else { + err := defaultRule.Handle(w, r) + if err != nil { + if !errors.Is(err, errTerminated) { + rm.AppendError(defaultRule, err) + return + } + shouldCallUpstream = false + } + } + } + + if shouldCallUpstream { + up(w, r) + } + + // if no post rules, we are done here + if len(postRules) == 0 && !isDefaultRulePost { + return + } + + for _, rule := range postRules { + if rule.Check(w, r) { + err := rule.Handle(w, r) + if err != nil { + if !errors.Is(err, errTerminated) { + rm.AppendError(rule, err) + } return } } } - // bypass or proceed - if defaultRule.Do.isBypass() || defaultRule.Handle(cache, w, r) { - up.ServeHTTP(w, r) + if isDefaultRulePost { + err := defaultRule.Handle(w, r) + if err != nil && !errors.Is(err, errTerminated) { + rm.AppendError(defaultRule, err) + } } } } +func isTerminatingHandler(handler CommandHandler) bool { + switch h := handler.(type) { + case TerminatingCommand: + return true + case Commands: + if len(h) == 0 { + return false + } + return isTerminatingHandler(h[len(h)-1]) + default: + return false + } +} + func (rules Rules) MarshalJSON() ([]byte, error) { names := make([]string, len(rules)) for i, rule := range rules { @@ -124,11 +241,14 @@ func (rule *Rule) String() string { return rule.Name } -func (rule *Rule) Check(cached Cache, r *http.Request) bool { - return rule.On.checker.Check(cached, r) +func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool { + if rule.On.checker == nil { + return true + } + v := rule.On.checker.Check(w, r) + return v } -func (rule *Rule) Handle(cached Cache, w http.ResponseWriter, r *http.Request) (proceed bool) { - proceed = rule.Do.exec.Handle(cached, w, r) - return proceed +func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error { + return rule.Do.exec.Handle(w, r) } diff --git a/internal/route/rules/rules_bench_test.go b/internal/route/rules/rules_bench_test.go new file mode 100644 index 00000000..c45f8cd4 --- /dev/null +++ b/internal/route/rules/rules_bench_test.go @@ -0,0 +1,74 @@ +package rules + +import ( + "bytes" + "io" + "net/http" + "net/url" + "testing" +) + +func BenchmarkRules(b *testing.B) { + var rules Rules + err := parseRules(` +- name: admin-api + on: | + path glob(/api/admin/*) + header Authorization + method POST + do: | + set resp_header X-Access-Level "admin" + set resp_header X-API-Version "v1" +- name: user-api + on: | + path glob(/api/users/*) & method GET + do: | + set resp_header X-Access-Level "user" + set resp_header X-API-Version "v1" +- name: public-api + on: | + path glob(/api/public/*) & method GET + do: | + set resp_header X-Access-Level "public" +`, &rules) + if err != nil { + b.Fatal(err) + } + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + b.Run("BuildHandler", func(b *testing.B) { + for b.Loop() { + rules.BuildHandler(upstream) + } + }) + + b.Run("RunHandler", func(b *testing.B) { + var r = &http.Request{ + Body: io.NopCloser(bytes.NewReader([]byte(""))), + URL: &url.URL{Path: "/api/users/"}, + } + var w noopResponseWriter + handler := rules.BuildHandler(upstream) + b.ResetTimer() + for b.Loop() { + handler.ServeHTTP(w, r) + } + }) +} + +type noopResponseWriter struct { +} + +func (w noopResponseWriter) Header() http.Header { + return http.Header{} +} + +func (w noopResponseWriter) Write(b []byte) (int, error) { + return len(b), nil +} + +func (w noopResponseWriter) WriteHeader(int) { +} diff --git a/internal/route/rules/rules_test.go b/internal/route/rules/rules_test.go deleted file mode 100644 index 1bc93ff6..00000000 --- a/internal/route/rules/rules_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package rules - -import ( - "testing" - - "github.com/yusing/godoxy/internal/serialization" - expect "github.com/yusing/goutils/testing" -) - -func TestParseRule(t *testing.T) { - test := []map[string]any{ - { - "name": "test", - "on": "method POST", - "do": "error 403 Forbidden", - }, - { - "name": "auth", - "on": `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`, - "do": "bypass", - }, - { - "name": "default", - "do": "require_basic_auth any_realm", - }, - } - - var rules struct { - Rules Rules - } - err := serialization.MapUnmarshalValidate(serialization.SerializedObject{"rules": test}, &rules) - expect.NoError(t, err) - expect.Equal(t, len(rules.Rules), len(test)) - expect.Equal(t, rules.Rules[0].Name, "test") - expect.Equal(t, rules.Rules[0].On.String(), "method POST") - expect.Equal(t, rules.Rules[0].Do.String(), "error 403 Forbidden") - - expect.Equal(t, rules.Rules[1].Name, "auth") - expect.Equal(t, rules.Rules[1].On.String(), `basic_auth "username" "password" | basic_auth username2 "password2" | basic_auth "username3" "password3"`) - expect.Equal(t, rules.Rules[1].Do.String(), "bypass") - - expect.Equal(t, rules.Rules[2].Name, "default") - expect.Equal(t, rules.Rules[2].Do.String(), "require_basic_auth any_realm") -} - -// TODO: real tests. diff --git a/internal/route/rules/template.go b/internal/route/rules/template.go new file mode 100644 index 00000000..538fb5a3 --- /dev/null +++ b/internal/route/rules/template.go @@ -0,0 +1,43 @@ +package rules + +import ( + "bytes" + "io" + "net/http" +) + +type templateOrStr interface { + Execute(w io.Writer, data any) error +} + +type strTemplate string + +func (t strTemplate) Execute(w io.Writer, _ any) error { + n, err := w.Write([]byte(t)) + if err != nil { + return err + } + if n != len(t) { + return io.ErrShortWrite + } + return nil +} + +type keyValueTemplate = Tuple[string, templateOrStr] + +func executeRequestTemplateString(tmpl templateOrStr, r *http.Request) (string, error) { + var buf bytes.Buffer + err := tmpl.Execute(&buf, reqResponseTemplateData{Request: r}) + if err != nil { + return "", err + } + return buf.String(), nil +} + +func executeRequestTemplateTo(tmpl templateOrStr, o io.Writer, r *http.Request) error { + return tmpl.Execute(o, reqResponseTemplateData{Request: r}) +} + +func executeReqRespTemplateTo(tmpl templateOrStr, o io.Writer, w http.ResponseWriter, r *http.Request) error { + return tmpl.Execute(o, reqResponseTemplateData{Request: r, Response: GetInitResponseModifier(w).Response()}) +} diff --git a/internal/route/rules/validate.go b/internal/route/rules/validate.go index 5f2f3cc7..19fd3c88 100644 --- a/internal/route/rules/validate.go +++ b/internal/route/rules/validate.go @@ -6,10 +6,11 @@ import ( "os" "path" "path/filepath" - "regexp" + "strconv" "strings" + "text/template" - "github.com/gobwas/glob" + "github.com/rs/zerolog" nettypes "github.com/yusing/godoxy/internal/net/types" gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" @@ -21,6 +22,17 @@ type ( First T1 Second T2 } + Tuple3[T1, T2, T3 any] struct { + First T1 + Second T2 + Third T3 + } + Tuple4[T1, T2, T3, T4 any] struct { + First T1 + Second T2 + Third T3 + Fourth T4 + } StrTuple = Tuple[string, string] IntTuple = Tuple[int, int] MapValueMatcher = Tuple[string, Matcher] @@ -30,97 +42,24 @@ func (t *Tuple[T1, T2]) Unpack() (T1, T2) { return t.First, t.Second } +func (t *Tuple3[T1, T2, T3]) Unpack() (T1, T2, T3) { + return t.First, t.Second, t.Third +} + +func (t *Tuple4[T1, T2, T3, T4]) Unpack() (T1, T2, T3, T4) { + return t.First, t.Second, t.Third, t.Fourth +} + func (t *Tuple[T1, T2]) String() string { return fmt.Sprintf("%v:%v", t.First, t.Second) } -type ( - Matcher func(string) bool - MatcherType string -) - -const ( - MatcherTypeString MatcherType = "string" - MatcherTypeGlob MatcherType = "glob" - MatcherTypeRegex MatcherType = "regex" -) - -func unquoteExpr(s string) (string, gperr.Error) { - if s == "" { - return "", nil - } - switch s[0] { - case '"', '\'', '`': - if s[0] != s[len(s)-1] { - return "", ErrUnterminatedQuotes - } - return s[1 : len(s)-1], nil - default: - return s, nil - } +func (t *Tuple3[T1, T2, T3]) String() string { + return fmt.Sprintf("%v:%v:%v", t.First, t.Second, t.Third) } -func ExtractExpr(s string) (matcherType MatcherType, expr string, err gperr.Error) { - idx := strings.IndexByte(s, '(') - if idx == -1 { - return MatcherTypeString, s, nil - } - idxEnd := strings.LastIndexByte(s, ')') - if idxEnd == -1 { - return "", "", ErrUnterminatedBrackets - } - - expr, err = unquoteExpr(s[idx+1 : idxEnd]) - if err != nil { - return "", "", err - } - matcherType = MatcherType(strings.ToLower(s[:idx])) - - switch matcherType { - case MatcherTypeGlob, MatcherTypeRegex, MatcherTypeString: - return - default: - return "", "", ErrInvalidArguments.Withf("invalid matcher type: %s", matcherType) - } -} - -func ParseMatcher(expr string) (Matcher, gperr.Error) { - t, expr, err := ExtractExpr(expr) - if err != nil { - return nil, err - } - switch t { - case MatcherTypeString: - return StringMatcher(expr) - case MatcherTypeGlob: - return GlobMatcher(expr) - case MatcherTypeRegex: - return RegexMatcher(expr) - } - // won't reach here - return nil, ErrInvalidArguments.Withf("invalid matcher type: %s", t) -} - -func StringMatcher(s string) (Matcher, gperr.Error) { - return func(s2 string) bool { - return s == s2 - }, nil -} - -func GlobMatcher(expr string) (Matcher, gperr.Error) { - g, err := glob.Compile(expr) - if err != nil { - return nil, ErrInvalidArguments.With(err) - } - return g.Match, nil -} - -func RegexMatcher(expr string) (Matcher, gperr.Error) { - re, err := regexp.Compile(expr) - if err != nil { - return nil, ErrInvalidArguments.With(err) - } - return re.MatchString, nil +func (t *Tuple4[T1, T2, T3, T4]) String() string { + return fmt.Sprintf("%v:%v:%v:%v", t.First, t.Second, t.Third, t.Fourth) } // validateSingleMatcher returns Matcher with the matcher validated. @@ -131,14 +70,6 @@ func validateSingleMatcher(args []string) (any, gperr.Error) { return ParseMatcher(args[0]) } -// toStrTuple returns *StrTuple. -func toStrTuple(args []string) (any, gperr.Error) { - if len(args) != 2 { - return nil, ErrExpectTwoArgs - } - return &StrTuple{args[0], args[1]}, nil -} - // toKVOptionalVMatcher returns *MapValueMatcher that value is optional. func toKVOptionalVMatcher(args []string) (any, gperr.Error) { switch len(args) { @@ -155,6 +86,18 @@ func toKVOptionalVMatcher(args []string) (any, gperr.Error) { } } +func toKeyValueTemplate(args []string) (any, gperr.Error) { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + + tmpl, err := validateTemplate(args[1], false) + if err != nil { + return nil, err + } + return &keyValueTemplate{args[0], tmpl}, nil +} + // validateURL returns types.URL with the URL validated. func validateURL(args []string) (any, gperr.Error) { if len(args) != 1 { @@ -164,6 +107,12 @@ func validateURL(args []string) (any, gperr.Error) { if err != nil { return nil, ErrInvalidArguments.With(err) } + if u.Scheme == "" { + // expect relative URL, must starts with / + if !strings.HasPrefix(u.Path, "/") { + return nil, ErrInvalidArguments.Withf("relative URL must starts with /") + } + } return u, nil } @@ -250,6 +199,57 @@ func validateMethod(args []string) (any, gperr.Error) { return method, nil } +func validateStatusCode(status string) (int, error) { + statusCode, err := strconv.Atoi(status) + if err != nil { + return 0, err + } + if statusCode < 100 || statusCode > 599 { + return 0, fmt.Errorf("status code out of range: %s", status) + } + return statusCode, nil +} + +// validateStatusRange returns Tuple[int, int] with the status range validated. +// accepted formats are: +// - +// - - +// - 1xx +// - 2xx +// - 3xx +// - 4xx +// - 5xx +func validateStatusRange(args []string) (any, gperr.Error) { + if len(args) != 1 { + return nil, ErrExpectOneArg + } + + beg, end, ok := strings.Cut(args[0], "-") + if !ok { // + end = beg + } + + switch beg { + case "1xx": + return &IntTuple{100, 199}, nil + case "2xx": + return &IntTuple{200, 299}, nil + case "3xx": + return &IntTuple{300, 399}, nil + case "4xx": + return &IntTuple{400, 499}, nil + case "5xx": + return &IntTuple{500, 599}, nil + } + + begInt, begErr := validateStatusCode(beg) + endInt, endErr := validateStatusCode(end) + if begErr != nil || endErr != nil { + return nil, ErrInvalidArguments.With(gperr.Join(begErr, endErr)) + } + return &IntTuple{begInt, endInt}, nil +} + // validateUserBCryptPassword returns *HashedCrendential with the password validated. func validateUserBCryptPassword(args []string) (any, gperr.Error) { if len(args) != 2 { @@ -260,20 +260,77 @@ func validateUserBCryptPassword(args []string) (any, gperr.Error) { // validateModField returns CommandHandler with the field validated. func validateModField(mod FieldModifier, args []string) (CommandHandler, gperr.Error) { + if len(args) == 0 { + return nil, ErrExpectTwoOrThreeArgs + } setField, ok := modFields[args[0]] if !ok { - return nil, ErrInvalidSetTarget.Subject(args[0]) + return nil, ErrUnknownModField.Subject(args[0]) + } + if mod == ModFieldRemove { + if len(args) != 2 { + return nil, ErrExpectTwoArgs + } + // setField expect validateStrTuple + args = append(args, "") } validArgs, err := setField.validate(args[1:]) if err != nil { - return nil, err.Withf(setField.help.String()) + return nil, err.With(setField.help.Error()) } modder := setField.builder(validArgs) switch mod { case ModFieldAdd: - return modder.add, nil + add := modder.add + if add == nil { + return nil, ErrInvalidArguments.Withf("add is not supported for %s", mod) + } + return add, nil case ModFieldRemove: - return modder.remove, nil + remove := modder.remove + if remove == nil { + return nil, ErrInvalidArguments.Withf("remove is not supported for %s", mod) + } + return remove, nil } - return modder.set, nil + set := modder.set + if set == nil { + return nil, ErrInvalidArguments.Withf("set is not supported for %s", mod) + } + return set, nil } + +func isTemplate(tmplStr string) bool { + return strings.Contains(tmplStr, "{{") +} + +func validateTemplate(tmplStr string, newline bool) (templateOrStr, gperr.Error) { + if newline && !strings.HasSuffix(tmplStr, "\n") { + tmplStr += "\n" + } + + if !isTemplate(tmplStr) { + return strTemplate(tmplStr), nil + } + + tmpl, err := template.New("template").Parse(tmplStr) + if err != nil { + return nil, ErrInvalidArguments.With(err) + } + return tmpl, nil +} + +func validateLevel(level string) (zerolog.Level, gperr.Error) { + l, err := zerolog.ParseLevel(level) + if err != nil { + return zerolog.NoLevel, ErrInvalidArguments.With(err) + } + return l, nil +} + +// func validateNotifProvider(provider string) gperr.Error { +// if !notif.HasProvider(provider) { +// return ErrInvalidArguments.Subject(provider) +// } +// return nil +// }