Files
godoxy/internal/entrypoint/entrypoint.go
yusing 2a3823091d feat(entrypoint): add inbound mTLS profiles for HTTPS
Add root-level inbound_mtls_profiles combining optional system CAs with PEM
CA files, and entrypoint.inbound_mtls_profile to require client certificates
on every HTTPS connection. Route-level inbound_mtls_profile is allowed only
without a global profile; per-handshake TLS picks ClientCAs from SNI, and
requests fail with 421 when Host and SNI would select different mTLS routes.

Compile pools at init (SetInboundMTLSProfiles from state.initEntrypoint) and
reject unknown profile refs or mixed global-plus-route configuration.

Extend config.example.yml and package READMEs; add entrypoint and config
tests for TLS mutation, handshakes, and validation.
2026-04-13 15:14:57 +08:00

206 lines
4.9 KiB
Go

package entrypoint
import (
"crypto/x509"
"net/http"
"strings"
"sync/atomic"
"testing"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
"github.com/yusing/godoxy/internal/route/rules"
"github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/pool"
"github.com/yusing/goutils/task"
)
// HTTPRoutes is the minimal lookup surface the route finders in this file
// (findRouteAnyDomain, findRouteByDomains) need from a route collection.
type HTTPRoutes interface {
// Get returns the HTTP route stored under alias and reports whether one exists.
Get(alias string) (types.HTTPRoute, bool)
}
// findRouteFunc resolves a request host against a set of HTTP routes.
// The implementations in this file return nil when no route matches.
type findRouteFunc func(HTTPRoutes, string) types.HTTPRoute
// Entrypoint bundles the entrypoint configuration, the optional middleware and
// access logger, the host-to-route lookup strategy, the stream/excluded route
// pools and the HTTP servers keyed by listen address.
type Entrypoint struct {
// task is the "entrypoint" subtask created from the parent given to NewEntrypoint.
task *task.Task
// cfg is never nil after NewEntrypoint: a nil argument is replaced by an empty Config.
cfg *Config
// middleware is the chain built by SetMiddlewares; nil when none is configured.
middleware *middleware.Middleware
// notFoundHandler is installed by SetNotFoundRules.
notFoundHandler http.Handler
// accessLogger is installed by SetAccessLogger; nil when no logger config is given.
accessLogger accesslog.AccessLogger
// findRouteFunc maps a host to a route. It starts as findRouteAnyDomain and is
// replaced by SetFindRouteDomains.
findRouteFunc findRouteFunc
// shortLinkMatcher is created in NewEntrypoint and exposed via ShortLinkMatcher.
shortLinkMatcher *ShortLinkMatcher
// streamRoutes and excludedRoutes are the pools exposed via StreamRoutes and ExcludedRoutes.
streamRoutes *pool.Pool[types.StreamRoute]
excludedRoutes *pool.Pool[types.Route]
// The stored flag only affects HTTP servers created in the future;
// DisablePoolsLog additionally applies the value to servers that already exist.
httpPoolDisableLog atomic.Bool
servers *xsync.Map[string, *httpServer] // listen addr -> server
// inboundMTLSProfiles holds client-CA certificate pools — presumably keyed by
// inbound mTLS profile name; confirm against the code that populates it.
inboundMTLSProfiles map[string]*x509.CertPool
}
// Compile-time assertion that *Entrypoint implements entrypoint.Entrypoint.
var _ entrypoint.Entrypoint = (*Entrypoint)(nil)

// emptyCfg is the zero-value Config that NewEntrypoint falls back to when it
// is called with a nil cfg.
var emptyCfg Config
// NewTestEntrypoint builds an Entrypoint for use in tests.
//
// The entrypoint is created under the test task obtained from tb and is then
// registered on that same task through entrypoint.SetCtx before being returned.
func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint {
	tb.Helper()

	root := task.GetTestTask(tb)
	created := NewEntrypoint(root, cfg)
	entrypoint.SetCtx(root, created)
	return created
}
// NewEntrypoint creates an Entrypoint running under an "entrypoint" subtask of
// parent.
//
// A nil cfg is replaced by the package-level empty Config. The new entrypoint
// starts with the any-domain route finder, a fresh short-link matcher, empty
// stream/excluded route pools, an empty server map and an empty inbound mTLS
// profile map.
func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint {
	effective := cfg
	if effective == nil {
		effective = &emptyCfg
	}

	return &Entrypoint{
		task:                parent.Subtask("entrypoint", false),
		cfg:                 effective,
		findRouteFunc:       findRouteAnyDomain,
		shortLinkMatcher:    newShortLinkMatcher(),
		streamRoutes:        pool.New[types.StreamRoute]("stream_routes", "stream_routes"),
		excludedRoutes:      pool.New[types.Route]("excluded_routes", "excluded_routes"),
		servers:             xsync.NewMap[string, *httpServer](),
		inboundMTLSProfiles: make(map[string]*x509.CertPool),
	}
}
// Task returns the task the entrypoint runs under (the subtask created in NewEntrypoint).
func (ep *Entrypoint) Task() *task.Task {
return ep.task
}
// SupportProxyProtocol reports the SupportProxyProtocol setting of the entrypoint config.
func (ep *Entrypoint) SupportProxyProtocol() bool {
return ep.cfg.SupportProxyProtocol
}
// DisablePoolsLog switches pool logging off (v == true) or back on (v == false).
//
// The value is remembered for HTTP servers created later and is also pushed to
// the route pool of every server that already exists, as well as to the stream
// and excluded route pools.
func (ep *Entrypoint) DisablePoolsLog(v bool) {
	// Remember the setting for servers that do not exist yet.
	ep.httpPoolDisableLog.Store(v)

	// Propagate to the route pools of the servers that are already registered.
	ep.servers.Range(func(_ string, running *httpServer) bool {
		running.routes.DisableLog(v)
		return true // keep iterating over every server
	})

	// Propagate to the pools owned directly by the entrypoint.
	ep.streamRoutes.DisableLog(v)
	ep.excludedRoutes.DisableLog(v)
}
// ShortLinkMatcher returns the short-link matcher created together with the entrypoint.
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
return ep.shortLinkMatcher
}
// HTTPRoutes returns a pool-like view of the HTTP routes. The view is an
// adapter built around this entrypoint by newHTTPPoolAdapter on every call.
func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] {
return newHTTPPoolAdapter(ep)
}
// StreamRoutes returns the entrypoint's stream route pool.
func (ep *Entrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] {
return ep.streamRoutes
}
// ExcludedRoutes returns the entrypoint's excluded route pool through its
// read-write pool interface.
func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] {
return ep.excludedRoutes
}
// GetServer returns the HTTP server registered for the listen address addr.
//
// When no server exists for addr it returns (nil, false). The miss is returned
// as a literal nil on purpose: forwarding the map's zero *httpServer directly
// would wrap a nil pointer in the HTTPServer interface, so a caller comparing
// the result against nil would wrongly see a non-nil value.
func (ep *Entrypoint) GetServer(addr string) (HTTPServer, bool) {
	srv, ok := ep.servers.Load(addr)
	if !ok {
		return nil, false
	}
	return srv, true
}
// SetFindRouteDomains selects how request hosts are resolved to routes.
//
// With no domains, any-domain matching is used. Otherwise every domain is
// normalised to start with "." (for example "example.com" becomes
// ".example.com") and matching is restricted to those domain suffixes.
//
// The normalised values are written to a private copy: the caller's slice is
// left untouched, and later changes to it do not affect the installed finder.
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
	if len(domains) == 0 {
		ep.findRouteFunc = findRouteAnyDomain
		return
	}

	normalized := make([]string, len(domains))
	for i, domain := range domains {
		if !strings.HasPrefix(domain, ".") {
			domain = "." + domain
		}
		normalized[i] = domain
	}
	ep.findRouteFunc = findRouteByDomains(normalized)
}
// SetMiddlewares replaces the entrypoint middleware chain.
//
// An empty mws removes the current chain. Otherwise the chain is built from
// the raw definitions under the name "entrypoint"; if building fails the error
// is returned and the previously installed chain is kept.
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
	if len(mws) == 0 {
		// Nothing configured: drop whatever chain was installed before.
		ep.middleware = nil
		return nil
	}

	chain, buildErr := middleware.BuildMiddlewareFromChainRaw("entrypoint", mws)
	if buildErr != nil {
		return buildErr
	}

	ep.middleware = chain
	log.Debug().Msg("entrypoint middleware loaded")
	return nil
}
// SetNotFoundRules installs the not-found handler, built from rules with
// serveNotFound passed to BuildHandler.
// NOTE(review): serveNotFound is presumably the handler used when no rule
// handles the request — confirm against rules.BuildHandler.
func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) {
ep.notFoundHandler = rules.BuildHandler(serveNotFound)
}
// SetAccessLogger replaces the entrypoint access logger.
//
// A nil cfg removes the current logger. Otherwise a new access logger is
// created under parent from cfg; if creation fails the error is returned and
// the previously installed logger is kept.
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) error {
	if cfg == nil {
		// No configuration: access logging is turned off.
		ep.accessLogger = nil
		return nil
	}

	created, createErr := accesslog.NewAccessLogger(parent, cfg)
	if createErr != nil {
		return createErr
	}

	ep.accessLogger = created
	log.Debug().Msg("entrypoint access logger created")
	return nil
}
// findRouteAnyDomain resolves host to a route without restricting the domain.
//
// Aliases are tried in this order and the first hit wins:
//  1. the text before the first "." in host (only if host contains a "."),
//  2. host exactly as given,
//  3. the text before the first ":" in host, i.e. host with a trailing :port
//     stripped (only if host contains a ":").
//
// It returns nil when none of the aliases is a known route.
func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute {
	// At most three aliases are ever tried, so a fixed-size array is enough.
	var aliases [3]string
	count := 0

	if dot := strings.IndexByte(host, '.'); dot >= 0 {
		aliases[count] = host[:dot]
		count++
	}

	aliases[count] = host
	count++

	// Try stripping the trailing :port from the host.
	if colon := strings.IndexByte(host, ':'); colon >= 0 {
		aliases[count] = host[:colon]
		count++
	}

	for _, alias := range aliases[:count] {
		if route, found := routes.Get(alias); found {
			return route
		}
	}
	return nil
}
// findRouteByDomains returns a route finder restricted to the given domain
// suffixes.
//
// The returned finder first drops everything from the first ":" in host (the
// trailing :port). For each domain, in order, that is a suffix of the
// remaining host, it looks up the part of the host in front of that suffix and
// returns the first route found. If no suffix yields a route, it falls back to
// an exact lookup of the port-less host and returns nil when that misses too.
func findRouteByDomains(domains []string) func(routes HTTPRoutes, host string) types.HTTPRoute {
	return func(routes HTTPRoutes, host string) types.HTTPRoute {
		// Strip the trailing :port, if any.
		hostname := host
		if colon := strings.IndexByte(host, ':'); colon >= 0 {
			hostname = host[:colon]
		}

		for _, suffix := range domains {
			if !strings.HasSuffix(hostname, suffix) {
				continue
			}
			alias := strings.TrimSuffix(hostname, suffix)
			if route, found := routes.Get(alias); found {
				return route
			}
		}

		// Fall back to an exact match on the whole hostname.
		if route, found := routes.Get(hostname); found {
			return route
		}
		return nil
	}
}