diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 989353e0..525a6c6e 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -6,16 +6,20 @@ import ( "sync/atomic" "github.com/rs/zerolog/log" + entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" "github.com/yusing/godoxy/internal/logging/accesslog" "github.com/yusing/godoxy/internal/net/gphttp/middleware" "github.com/yusing/godoxy/internal/net/gphttp/middleware/errorpage" "github.com/yusing/godoxy/internal/route/routes" + "github.com/yusing/godoxy/internal/route/rules" "github.com/yusing/godoxy/internal/types" "github.com/yusing/goutils/task" ) type Entrypoint struct { middleware *middleware.Middleware + catchAllHandler http.Handler + notFoundHandler http.Handler accessLogger *accesslog.AccessLogger findRouteFunc func(host string) types.HTTPRoute } @@ -58,6 +62,22 @@ func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error { return nil } +func (ep *Entrypoint) SetCatchAllRules(rules rules.Rules) { + if len(rules) == 0 { + ep.catchAllHandler = nil + return + } + ep.catchAllHandler = rules.BuildHandler(http.HandlerFunc(ep.serveHTTP)) +} + +func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) { + if len(rules) == 0 { + ep.notFoundHandler = nil + return + } + ep.notFoundHandler = rules.BuildHandler(http.HandlerFunc(ep.serveNotFound)) +} + func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) (err error) { if cfg == nil { ep.accessLogger = nil @@ -72,7 +92,19 @@ func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.Request return err } +func (ep *Entrypoint) FindRoute(s string) types.HTTPRoute { + return ep.findRouteFunc(s) +} + func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if ep.catchAllHandler != nil { + ep.catchAllHandler.ServeHTTP(w, r) + return + } + ep.serveHTTP(w, r) +} + +func (ep *Entrypoint) serveHTTP(w http.ResponseWriter, r *http.Request) { if ep.accessLogger != nil { w = accesslog.NewResponseRecorder(w) defer ep.accessLogger.Log(r, w.(*accesslog.ResponseRecorder).Response()) @@ -87,8 +119,14 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else { route.ServeHTTP(w, r) } - return + case ep.notFoundHandler != nil: + ep.notFoundHandler.ServeHTTP(w, r) + default: + ep.serveNotFound(w, r) } +} + +func (ep *Entrypoint) serveNotFound(w http.ResponseWriter, r *http.Request) { // Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway? // On nginx, when route for domain does not exist, it returns StatusBadGateway. // Then scraper / scanners will know the subdomain is invalid. diff --git a/internal/entrypoint/entrypoint_benchmark_test.go b/internal/entrypoint/entrypoint_benchmark_test.go index 3739f9b2..f5796fce 100644 --- a/internal/entrypoint/entrypoint_benchmark_test.go +++ b/internal/entrypoint/entrypoint_benchmark_test.go @@ -1,4 +1,4 @@ -package entrypoint +package entrypoint_test import ( "io" @@ -10,6 +10,7 @@ import ( "strings" "testing" + . "github.com/yusing/godoxy/internal/entrypoint" "github.com/yusing/godoxy/internal/route" "github.com/yusing/godoxy/internal/route/routes" "github.com/yusing/godoxy/internal/types" diff --git a/internal/entrypoint/entrypoint_test.go b/internal/entrypoint/entrypoint_test.go index e5cd1bf2..d36f4879 100644 --- a/internal/entrypoint/entrypoint_test.go +++ b/internal/entrypoint/entrypoint_test.go @@ -1,8 +1,9 @@ -package entrypoint +package entrypoint_test import ( "testing" + . "github.com/yusing/godoxy/internal/entrypoint" "github.com/yusing/godoxy/internal/route" "github.com/yusing/godoxy/internal/route/routes" @@ -29,14 +30,14 @@ func run(t *testing.T, match []string, noMatch []string) { for _, test := range match { t.Run(test, func(t *testing.T) { - found := ep.findRouteFunc(test) + found := ep.FindRoute(test) expect.NotNil(t, found) }) } for _, test := range noMatch { t.Run(test, func(t *testing.T) { - found := ep.findRouteFunc(test) + found := ep.FindRoute(test) expect.Nil(t, found) }) } diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index 88678c09..1a0e1ed9 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -128,7 +128,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { } if len(r.Rules) > 0 { - r.handler = r.Rules.BuildHandler(r.Name(), r.handler) + r.handler = r.Rules.BuildHandler(r.handler) } if r.HealthMon != nil {