package entrypoint import ( "net/http" "strings" "sync/atomic" "testing" "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" "github.com/yusing/godoxy/internal/logging/accesslog" "github.com/yusing/godoxy/internal/net/gphttp/middleware" "github.com/yusing/godoxy/internal/route/rules" "github.com/yusing/godoxy/internal/types" "github.com/yusing/goutils/pool" "github.com/yusing/goutils/task" ) type HTTPRoutes interface { Get(alias string) (types.HTTPRoute, bool) } type findRouteFunc func(HTTPRoutes, string) types.HTTPRoute type Entrypoint struct { task *task.Task cfg *Config middleware *middleware.Middleware notFoundHandler http.Handler accessLogger accesslog.AccessLogger findRouteFunc findRouteFunc shortLinkMatcher *ShortLinkMatcher streamRoutes *pool.Pool[types.StreamRoute] excludedRoutes *pool.Pool[types.Route] // this only affects future http servers creation httpPoolDisableLog atomic.Bool servers *xsync.Map[string, *httpServer] // listen addr -> server } var _ entrypoint.Entrypoint = &Entrypoint{} var emptyCfg Config func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint { tb.Helper() testTask := task.GetTestTask(tb) ep := NewEntrypoint(testTask, cfg) entrypoint.SetCtx(testTask, ep) return ep } func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint { if cfg == nil { cfg = &emptyCfg } ep := &Entrypoint{ task: parent.Subtask("entrypoint", false), cfg: cfg, findRouteFunc: findRouteAnyDomain, shortLinkMatcher: newShortLinkMatcher(), streamRoutes: pool.New[types.StreamRoute]("stream_routes", "stream_routes"), excludedRoutes: pool.New[types.Route]("excluded_routes", "excluded_routes"), servers: xsync.NewMap[string, *httpServer](), } return ep } func (ep *Entrypoint) Task() *task.Task { return ep.task } func (ep *Entrypoint) SupportProxyProtocol() bool { return ep.cfg.SupportProxyProtocol } func (ep *Entrypoint) DisablePoolsLog(v bool) { ep.httpPoolDisableLog.Store(v) // apply to all running http servers for _, srv := range ep.servers.Range { srv.routes.DisableLog(v) } // apply to other pools ep.streamRoutes.DisableLog(v) ep.excludedRoutes.DisableLog(v) } func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher { return ep.shortLinkMatcher } func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] { return newHTTPPoolAdapter(ep) } func (ep *Entrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] { return ep.streamRoutes } func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] { return ep.excludedRoutes } func (ep *Entrypoint) GetServer(addr string) (HTTPServer, bool) { return ep.servers.Load(addr) } func (ep *Entrypoint) SetFindRouteDomains(domains []string) { if len(domains) == 0 { ep.findRouteFunc = findRouteAnyDomain } else { for i, domain := range domains { if !strings.HasPrefix(domain, ".") { domains[i] = "." + domain } } ep.findRouteFunc = findRouteByDomains(domains) } } func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error { if len(mws) == 0 { ep.middleware = nil return nil } mid, err := middleware.BuildMiddlewareFromChainRaw("entrypoint", mws) if err != nil { return err } ep.middleware = mid log.Debug().Msg("entrypoint middleware loaded") return nil } func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) { ep.notFoundHandler = rules.BuildHandler(serveNotFound) } func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) error { if cfg == nil { ep.accessLogger = nil return nil } accessLogger, err := accesslog.NewAccessLogger(parent, cfg) if err != nil { return err } ep.accessLogger = accessLogger log.Debug().Msg("entrypoint access logger created") return nil } func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute { before, _, ok := strings.Cut(host, ".") if ok { target := before if r, ok := routes.Get(target); ok { return r } } if r, ok := routes.Get(host); ok { return r } // try striping the trailing :port from the host if before, _, ok := strings.Cut(host, ":"); ok { if r, ok := routes.Get(before); ok { return r } } return nil } func findRouteByDomains(domains []string) func(routes HTTPRoutes, host string) types.HTTPRoute { return func(routes HTTPRoutes, host string) types.HTTPRoute { host, _, _ = strings.Cut(host, ":") // strip the trailing :port for _, domain := range domains { if target, ok := strings.CutSuffix(host, domain); ok { if r, ok := routes.Get(target); ok { return r } } } // fallback to exact match if r, ok := routes.Get(host); ok { return r } return nil } }