mirror of
https://github.com/yusing/godoxy.git
synced 2026-02-19 00:47:41 +01:00
203 lines
4.8 KiB
Go
203 lines
4.8 KiB
Go
package entrypoint
|
|
|
|
import (
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/rs/zerolog/log"
|
|
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
|
"github.com/yusing/godoxy/internal/logging/accesslog"
|
|
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
|
"github.com/yusing/godoxy/internal/route/rules"
|
|
"github.com/yusing/godoxy/internal/types"
|
|
"github.com/yusing/goutils/pool"
|
|
"github.com/yusing/goutils/task"
|
|
)
|
|
|
|
type HTTPRoutes interface {
|
|
Get(alias string) (types.HTTPRoute, bool)
|
|
}
|
|
|
|
type findRouteFunc func(HTTPRoutes, string) types.HTTPRoute
|
|
|
|
type Entrypoint struct {
|
|
task *task.Task
|
|
|
|
cfg *Config
|
|
|
|
middleware *middleware.Middleware
|
|
notFoundHandler http.Handler
|
|
accessLogger accesslog.AccessLogger
|
|
findRouteFunc findRouteFunc
|
|
shortLinkMatcher *ShortLinkMatcher
|
|
|
|
streamRoutes *pool.Pool[types.StreamRoute]
|
|
excludedRoutes *pool.Pool[types.Route]
|
|
|
|
// this only affects future http servers creation
|
|
httpPoolDisableLog atomic.Bool
|
|
|
|
servers *xsync.Map[string, *httpServer] // listen addr -> server
|
|
}
|
|
|
|
var _ entrypoint.Entrypoint = &Entrypoint{}
|
|
|
|
var emptyCfg Config
|
|
|
|
func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint {
|
|
tb.Helper()
|
|
|
|
testTask := task.GetTestTask(tb)
|
|
ep := NewEntrypoint(testTask, cfg)
|
|
entrypoint.SetCtx(testTask, ep)
|
|
return ep
|
|
}
|
|
|
|
func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint {
|
|
if cfg == nil {
|
|
cfg = &emptyCfg
|
|
}
|
|
|
|
ep := &Entrypoint{
|
|
task: parent.Subtask("entrypoint", false),
|
|
cfg: cfg,
|
|
findRouteFunc: findRouteAnyDomain,
|
|
shortLinkMatcher: newShortLinkMatcher(),
|
|
streamRoutes: pool.New[types.StreamRoute]("stream_routes", "stream_routes"),
|
|
excludedRoutes: pool.New[types.Route]("excluded_routes", "excluded_routes"),
|
|
servers: xsync.NewMap[string, *httpServer](),
|
|
}
|
|
return ep
|
|
}
|
|
|
|
func (ep *Entrypoint) Task() *task.Task {
|
|
return ep.task
|
|
}
|
|
|
|
func (ep *Entrypoint) SupportProxyProtocol() bool {
|
|
return ep.cfg.SupportProxyProtocol
|
|
}
|
|
|
|
func (ep *Entrypoint) DisablePoolsLog(v bool) {
|
|
ep.httpPoolDisableLog.Store(v)
|
|
// apply to all running http servers
|
|
for _, srv := range ep.servers.Range {
|
|
srv.routes.DisableLog(v)
|
|
}
|
|
// apply to other pools
|
|
ep.streamRoutes.DisableLog(v)
|
|
ep.excludedRoutes.DisableLog(v)
|
|
}
|
|
|
|
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
|
|
return ep.shortLinkMatcher
|
|
}
|
|
|
|
func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] {
|
|
return newHTTPPoolAdapter(ep)
|
|
}
|
|
|
|
func (ep *Entrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] {
|
|
return ep.streamRoutes
|
|
}
|
|
|
|
func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] {
|
|
return ep.excludedRoutes
|
|
}
|
|
|
|
func (ep *Entrypoint) GetServer(addr string) (HTTPServer, bool) {
|
|
return ep.servers.Load(addr)
|
|
}
|
|
|
|
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
|
|
if len(domains) == 0 {
|
|
ep.findRouteFunc = findRouteAnyDomain
|
|
} else {
|
|
for i, domain := range domains {
|
|
if !strings.HasPrefix(domain, ".") {
|
|
domains[i] = "." + domain
|
|
}
|
|
}
|
|
ep.findRouteFunc = findRouteByDomains(domains)
|
|
}
|
|
}
|
|
|
|
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
|
if len(mws) == 0 {
|
|
ep.middleware = nil
|
|
return nil
|
|
}
|
|
|
|
mid, err := middleware.BuildMiddlewareFromChainRaw("entrypoint", mws)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ep.middleware = mid
|
|
|
|
log.Debug().Msg("entrypoint middleware loaded")
|
|
return nil
|
|
}
|
|
|
|
func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) {
|
|
ep.notFoundHandler = rules.BuildHandler(serveNotFound)
|
|
}
|
|
|
|
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) error {
|
|
if cfg == nil {
|
|
ep.accessLogger = nil
|
|
return nil
|
|
}
|
|
|
|
accessLogger, err := accesslog.NewAccessLogger(parent, cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ep.accessLogger = accessLogger
|
|
log.Debug().Msg("entrypoint access logger created")
|
|
return nil
|
|
}
|
|
|
|
func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
//nolint:modernize
|
|
idx := strings.IndexByte(host, '.')
|
|
if idx != -1 {
|
|
target := host[:idx]
|
|
if r, ok := routes.Get(target); ok {
|
|
return r
|
|
}
|
|
}
|
|
if r, ok := routes.Get(host); ok {
|
|
return r
|
|
}
|
|
// try striping the trailing :port from the host
|
|
if before, _, ok := strings.Cut(host, ":"); ok {
|
|
if r, ok := routes.Get(before); ok {
|
|
return r
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func findRouteByDomains(domains []string) func(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
return func(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
host, _, _ = strings.Cut(host, ":") // strip the trailing :port
|
|
for _, domain := range domains {
|
|
if target, ok := strings.CutSuffix(host, domain); ok {
|
|
if r, ok := routes.Get(target); ok {
|
|
return r
|
|
}
|
|
}
|
|
}
|
|
|
|
// fallback to exact match
|
|
if r, ok := routes.Get(host); ok {
|
|
return r
|
|
}
|
|
return nil
|
|
}
|
|
}
|