mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-17 05:59:42 +02:00
Introduce reusable `inbound_mtls_profiles` in root config and support `entrypoint.inbound_mtls_profile` to require client certificates for all HTTPS traffic on an entrypoint. Profiles can trust the system CA store, custom PEM CA files, or both, and are compiled into TLS client-auth pools during entrypoint initialization. Also add route-scoped `inbound_mtls_profile` support for HTTP-based routes when no global entrypoint profile is configured. Route-level mTLS selection is driven by TLS SNI, preserves existing behavior for open and unmatched hosts, and returns the intended 421 response when secure requests omit SNI or when Host and SNI resolve to different routes. Add validation for missing profile references and unsupported non-HTTP route usage, update config and route documentation/examples, expand inbound mTLS handshake and routing regression coverage, and bump `goutils` for HTTPS listener test support.
206 lines
4.9 KiB
Go
206 lines
4.9 KiB
Go
package entrypoint
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/rs/zerolog/log"
|
|
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
|
"github.com/yusing/godoxy/internal/logging/accesslog"
|
|
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
|
"github.com/yusing/godoxy/internal/route/rules"
|
|
"github.com/yusing/godoxy/internal/types"
|
|
"github.com/yusing/goutils/pool"
|
|
"github.com/yusing/goutils/task"
|
|
)
|
|
|
|
type HTTPRoutes interface {
|
|
Get(alias string) (types.HTTPRoute, bool)
|
|
}
|
|
|
|
type findRouteFunc func(HTTPRoutes, string) types.HTTPRoute
|
|
|
|
type Entrypoint struct {
|
|
task *task.Task
|
|
|
|
cfg *Config
|
|
|
|
middleware *middleware.Middleware
|
|
notFoundHandler http.Handler
|
|
accessLogger accesslog.AccessLogger
|
|
findRouteFunc findRouteFunc
|
|
shortLinkMatcher *ShortLinkMatcher
|
|
|
|
streamRoutes *pool.Pool[types.StreamRoute]
|
|
excludedRoutes *pool.Pool[types.Route]
|
|
|
|
// this only affects future http servers creation
|
|
httpPoolDisableLog atomic.Bool
|
|
|
|
servers *xsync.Map[string, *httpServer] // listen addr -> server
|
|
|
|
inboundMTLSProfiles map[string]*x509.CertPool
|
|
}
|
|
|
|
var _ entrypoint.Entrypoint = &Entrypoint{}
|
|
|
|
var emptyCfg Config
|
|
|
|
func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint {
|
|
tb.Helper()
|
|
|
|
testTask := task.GetTestTask(tb)
|
|
ep := NewEntrypoint(testTask, cfg)
|
|
entrypoint.SetCtx(testTask, ep)
|
|
return ep
|
|
}
|
|
|
|
func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint {
|
|
if cfg == nil {
|
|
cfg = &emptyCfg
|
|
}
|
|
|
|
ep := &Entrypoint{
|
|
task: parent.Subtask("entrypoint", false),
|
|
cfg: cfg,
|
|
findRouteFunc: findRouteAnyDomain,
|
|
shortLinkMatcher: newShortLinkMatcher(),
|
|
streamRoutes: pool.New[types.StreamRoute]("stream_routes", "stream_routes"),
|
|
excludedRoutes: pool.New[types.Route]("excluded_routes", "excluded_routes"),
|
|
servers: xsync.NewMap[string, *httpServer](),
|
|
inboundMTLSProfiles: make(map[string]*x509.CertPool),
|
|
}
|
|
return ep
|
|
}
|
|
|
|
func (ep *Entrypoint) Task() *task.Task {
|
|
return ep.task
|
|
}
|
|
|
|
func (ep *Entrypoint) SupportProxyProtocol() bool {
|
|
return ep.cfg.SupportProxyProtocol
|
|
}
|
|
|
|
func (ep *Entrypoint) DisablePoolsLog(v bool) {
|
|
ep.httpPoolDisableLog.Store(v)
|
|
// apply to all running http servers
|
|
for _, srv := range ep.servers.Range {
|
|
srv.routes.DisableLog(v)
|
|
}
|
|
// apply to other pools
|
|
ep.streamRoutes.DisableLog(v)
|
|
ep.excludedRoutes.DisableLog(v)
|
|
}
|
|
|
|
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
|
|
return ep.shortLinkMatcher
|
|
}
|
|
|
|
func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] {
|
|
return newHTTPPoolAdapter(ep)
|
|
}
|
|
|
|
func (ep *Entrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] {
|
|
return ep.streamRoutes
|
|
}
|
|
|
|
func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] {
|
|
return ep.excludedRoutes
|
|
}
|
|
|
|
func (ep *Entrypoint) GetServer(addr string) (HTTPServer, bool) {
|
|
return ep.servers.Load(addr)
|
|
}
|
|
|
|
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
|
|
if len(domains) == 0 {
|
|
ep.findRouteFunc = findRouteAnyDomain
|
|
} else {
|
|
for i, domain := range domains {
|
|
if !strings.HasPrefix(domain, ".") {
|
|
domains[i] = "." + domain
|
|
}
|
|
}
|
|
ep.findRouteFunc = findRouteByDomains(domains)
|
|
}
|
|
}
|
|
|
|
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
|
if len(mws) == 0 {
|
|
ep.middleware = nil
|
|
return nil
|
|
}
|
|
|
|
mid, err := middleware.BuildMiddlewareFromChainRaw("entrypoint", mws)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ep.middleware = mid
|
|
|
|
log.Debug().Msg("entrypoint middleware loaded")
|
|
return nil
|
|
}
|
|
|
|
func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) {
|
|
ep.notFoundHandler = rules.BuildHandler(serveNotFound)
|
|
}
|
|
|
|
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) error {
|
|
if cfg == nil {
|
|
ep.accessLogger = nil
|
|
return nil
|
|
}
|
|
|
|
accessLogger, err := accesslog.NewAccessLogger(parent, cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ep.accessLogger = accessLogger
|
|
log.Debug().Msg("entrypoint access logger created")
|
|
return nil
|
|
}
|
|
|
|
func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
before, _, ok := strings.Cut(host, ".")
|
|
if ok {
|
|
target := before
|
|
if r, ok := routes.Get(target); ok {
|
|
return r
|
|
}
|
|
}
|
|
if r, ok := routes.Get(host); ok {
|
|
return r
|
|
}
|
|
// try striping the trailing :port from the host
|
|
if before, _, ok := strings.Cut(host, ":"); ok {
|
|
if r, ok := routes.Get(before); ok {
|
|
return r
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func findRouteByDomains(domains []string) func(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
return func(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
host, _, _ = strings.Cut(host, ":") // strip the trailing :port
|
|
for _, domain := range domains {
|
|
if target, ok := strings.CutSuffix(host, domain); ok {
|
|
if r, ok := routes.Get(target); ok {
|
|
return r
|
|
}
|
|
}
|
|
}
|
|
|
|
// fallback to exact match
|
|
if r, ok := routes.Get(host); ok {
|
|
return r
|
|
}
|
|
return nil
|
|
}
|
|
}
|