Files
godoxy-yusing/internal/entrypoint/entrypoint.go
Yuzerion 31eea0a885 feat(entrypoint): add inbound mTLS profiles for HTTPS (#220)
Introduce reusable `inbound_mtls_profiles` in root config and support
`entrypoint.inbound_mtls_profile` to require client certificates for all
HTTPS traffic on an entrypoint. Profiles can trust the system CA store,
custom PEM CA files, or both, and are compiled into TLS client-auth
pools during entrypoint initialization.

Also add route-scoped `inbound_mtls_profile` support for HTTP-based
routes when no global entrypoint profile is configured. Route-level mTLS
selection is driven by TLS SNI, preserves existing behavior for open and
unmatched hosts, and returns the intended 421 response when secure
requests omit SNI or when Host and SNI resolve to different routes.

Add validation for missing profile references and unsupported non-HTTP
route usage, update config and route documentation/examples, expand
inbound mTLS handshake and routing regression coverage, and bump
`goutils` for HTTPS listener test support.
2026-04-15 12:14:22 +08:00

206 lines
4.9 KiB
Go

package entrypoint
import (
"crypto/x509"
"net/http"
"strings"
"sync/atomic"
"testing"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
"github.com/yusing/godoxy/internal/route/rules"
"github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/pool"
"github.com/yusing/goutils/task"
)
// HTTPRoutes is the minimal lookup view of an HTTP route collection that the
// route-finding functions in this file depend on.
type HTTPRoutes interface {
// Get returns the route registered under alias and whether it was found.
Get(alias string) (types.HTTPRoute, bool)
}
// findRouteFunc resolves a request host string to a route from the given
// collection. The implementations in this file (findRouteAnyDomain and the
// closure returned by findRouteByDomains) return nil when nothing matches.
type findRouteFunc func(HTTPRoutes, string) types.HTTPRoute
// Entrypoint holds the entrypoint-wide state: its task, configuration,
// optional middleware / not-found handler / access logger, the host-to-route
// matching strategy, the non-HTTP route pools and the running HTTP servers.
type Entrypoint struct {
// task is the "entrypoint" subtask created in NewEntrypoint; exposed via Task.
task *task.Task
// cfg is never nil after NewEntrypoint (a nil cfg is replaced by &emptyCfg).
cfg *Config
// middleware is set by SetMiddlewares; nil when no middlewares are configured.
middleware *middleware.Middleware
// notFoundHandler is set by SetNotFoundRules.
notFoundHandler http.Handler
// accessLogger is set by SetAccessLogger; nil when access logging is off.
accessLogger accesslog.AccessLogger
// findRouteFunc maps a request host to a route; defaults to
// findRouteAnyDomain and is replaced by SetFindRouteDomains.
findRouteFunc findRouteFunc
shortLinkMatcher *ShortLinkMatcher
streamRoutes *pool.Pool[types.StreamRoute]
excludedRoutes *pool.Pool[types.Route]
// httpPoolDisableLog stores the last value passed to DisablePoolsLog.
// On its own this only affects future http servers creation; servers that
// are already running are updated directly by DisablePoolsLog.
httpPoolDisableLog atomic.Bool
servers *xsync.Map[string, *httpServer] // listen addr -> server
// inboundMTLSProfiles holds CA pools for inbound mTLS client verification.
// NOTE(review): presumably keyed by profile name — it is only allocated
// (empty) in this file; confirm against where it is populated.
inboundMTLSProfiles map[string]*x509.CertPool
}
// Compile-time assertion that *Entrypoint implements entrypoint.Entrypoint.
var _ entrypoint.Entrypoint = (*Entrypoint)(nil)

// emptyCfg is the shared zero-value Config that NewEntrypoint falls back to
// when it is called with a nil cfg.
var emptyCfg Config
// NewTestEntrypoint builds an Entrypoint for use in tests.
//
// The entrypoint is created under the test task obtained from
// task.GetTestTask(tb) and is then passed, together with that task, to
// entrypoint.SetCtx before being returned. A nil cfg is accepted (see
// NewEntrypoint).
func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint {
	tb.Helper()

	root := task.GetTestTask(tb)
	created := NewEntrypoint(root, cfg)
	entrypoint.SetCtx(root, created)
	return created
}
// NewEntrypoint creates an Entrypoint whose task is an "entrypoint" subtask
// of parent.
//
// When cfg is nil the package-level empty config is used instead, so the
// returned entrypoint always has a non-nil configuration. The entrypoint
// starts with the any-domain route matcher, a fresh short-link matcher,
// empty stream/excluded route pools, no HTTP servers and an empty set of
// inbound mTLS profiles.
func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint {
	effectiveCfg := cfg
	if effectiveCfg == nil {
		effectiveCfg = &emptyCfg
	}
	return &Entrypoint{
		task:                parent.Subtask("entrypoint", false),
		cfg:                 effectiveCfg,
		findRouteFunc:       findRouteAnyDomain,
		shortLinkMatcher:    newShortLinkMatcher(),
		streamRoutes:        pool.New[types.StreamRoute]("stream_routes", "stream_routes"),
		excludedRoutes:      pool.New[types.Route]("excluded_routes", "excluded_routes"),
		servers:             xsync.NewMap[string, *httpServer](),
		inboundMTLSProfiles: make(map[string]*x509.CertPool),
	}
}
// Task returns the entrypoint's own task (the "entrypoint" subtask created
// in NewEntrypoint).
func (ep *Entrypoint) Task() *task.Task {
return ep.task
}
// SupportProxyProtocol reports the SupportProxyProtocol setting from the
// entrypoint's config.
func (ep *Entrypoint) SupportProxyProtocol() bool {
return ep.cfg.SupportProxyProtocol
}
// DisablePoolsLog turns pool logging off (v == true) or back on (v == false)
// for every route pool owned by the entrypoint.
//
// The value is stored first so that it is available to HTTP servers created
// later, then pushed to the route pool of each currently registered server
// and finally to the stream and excluded route pools.
func (ep *Entrypoint) DisablePoolsLog(v bool) {
	ep.httpPoolDisableLog.Store(v)

	// Running HTTP servers: visit every entry, never stop early.
	ep.servers.Range(func(_ string, running *httpServer) bool {
		running.routes.DisableLog(v)
		return true
	})

	// Pools that are not tied to an HTTP server.
	ep.streamRoutes.DisableLog(v)
	ep.excludedRoutes.DisableLog(v)
}
// ShortLinkMatcher returns the short-link matcher created together with the
// entrypoint in NewEntrypoint.
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
return ep.shortLinkMatcher
}
// HTTPRoutes returns a pool-like view of the entrypoint's HTTP routes.
// A new adapter is built by newHTTPPoolAdapter on every call.
func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] {
return newHTTPPoolAdapter(ep)
}
// StreamRoutes returns the entrypoint's stream route pool.
func (ep *Entrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] {
return ep.streamRoutes
}
// ExcludedRoutes returns the entrypoint's pool of excluded routes.
func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] {
return ep.excludedRoutes
}
// GetServer returns the HTTP server registered for the listen address addr
// and whether one exists.
//
// On a miss it returns a literal nil rather than forwarding the result of
// the map lookup: the map stores *httpServer values, and converting a nil
// *httpServer to the HTTPServer return type would (assuming HTTPServer is an
// interface) produce a non-nil interface value wrapping a nil pointer, so a
// caller testing `srv != nil` instead of the boolean would be misled.
func (ep *Entrypoint) GetServer(addr string) (HTTPServer, bool) {
	srv, ok := ep.servers.Load(addr)
	if !ok {
		return nil, false
	}
	return srv, true
}
// SetFindRouteDomains selects how request hosts are matched to routes.
//
// With no domains, any-domain matching (findRouteAnyDomain) is used.
// Otherwise every domain is normalized to start with "." and matching is
// restricted to those domain suffixes (findRouteByDomains).
//
// The normalized domains are written to a private copy: the caller's slice
// is neither modified nor retained, so later changes to it cannot alter the
// installed matcher.
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
	if len(domains) == 0 {
		ep.findRouteFunc = findRouteAnyDomain
		return
	}

	normalized := make([]string, len(domains))
	for i, domain := range domains {
		if !strings.HasPrefix(domain, ".") {
			domain = "." + domain
		}
		normalized[i] = domain
	}
	ep.findRouteFunc = findRouteByDomains(normalized)
}
// SetMiddlewares installs the entrypoint-level middleware chain.
//
// An empty mws clears any previously installed middleware. Otherwise the
// chain is built under the name "entrypoint"; if building fails the error is
// returned and the current middleware is left untouched.
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
	if len(mws) == 0 {
		// Nothing configured: remove the middleware entirely.
		ep.middleware = nil
		return nil
	}

	chain, buildErr := middleware.BuildMiddlewareFromChainRaw("entrypoint", mws)
	if buildErr != nil {
		return buildErr
	}

	ep.middleware = chain
	log.Debug().Msg("entrypoint middleware loaded")
	return nil
}
// SetNotFoundRules replaces the entrypoint's not-found handler with one
// built from rules, passing serveNotFound to the rules' BuildHandler.
func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) {
	handler := rules.BuildHandler(serveNotFound)
	ep.notFoundHandler = handler
}
// SetAccessLogger configures request access logging for the entrypoint.
//
// A nil cfg removes the current access logger. Otherwise a new logger is
// created from cfg under parent; if creation fails the error is returned and
// the current logger is left untouched.
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) error {
	if cfg == nil {
		// No configuration means access logging is turned off.
		ep.accessLogger = nil
		return nil
	}

	logger, createErr := accesslog.NewAccessLogger(parent, cfg)
	if createErr != nil {
		return createErr
	}

	ep.accessLogger = logger
	log.Debug().Msg("entrypoint access logger created")
	return nil
}
// findRouteAnyDomain resolves host to a route without restricting the domain.
//
// Lookups are attempted in this order and the first hit wins:
//  1. the text before the first "." in host (only when host contains a "."),
//  2. host exactly as given,
//  3. the text before the first ":" in host (only when host contains a ":").
//
// It returns nil when none of the lookups match.
func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute {
	if dot := strings.IndexByte(host, '.'); dot >= 0 {
		if route, found := routes.Get(host[:dot]); found {
			return route
		}
	}

	if route, found := routes.Get(host); found {
		return route
	}

	// Try stripping the trailing :port from the host.
	if colon := strings.IndexByte(host, ':'); colon >= 0 {
		if route, found := routes.Get(host[:colon]); found {
			return route
		}
	}

	return nil
}
// findRouteByDomains returns a matcher that resolves hosts against a fixed
// list of domain suffixes.
//
// The returned function first drops everything from the first ":" in host
// (the port). It then walks domains in order: whenever host ends with a
// domain, the remaining prefix is looked up as a route alias and the first
// hit is returned. If no suffix yields a route, the port-less host itself is
// looked up; nil is returned when that fails too.
//
// The domains slice is captured by the returned closure, not copied.
func findRouteByDomains(domains []string) func(routes HTTPRoutes, host string) types.HTTPRoute {
	return func(routes HTTPRoutes, host string) types.HTTPRoute {
		// Strip the trailing :port, if any.
		if colon := strings.IndexByte(host, ':'); colon >= 0 {
			host = host[:colon]
		}

		for _, suffix := range domains {
			alias, matched := strings.CutSuffix(host, suffix)
			if !matched {
				continue
			}
			if route, found := routes.Get(alias); found {
				return route
			}
		}

		// Fall back to an exact match on the port-less host.
		route, found := routes.Get(host)
		if !found {
			return nil
		}
		return route
	}
}