mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-13 20:49:54 +02:00
Add root-level inbound_mtls_profiles combining optional system CAs with PEM CA files, and entrypoint.inbound_mtls_profile to require client certificates on every HTTPS connection. Route-level inbound_mtls_profile is allowed only without a global profile; per-handshake TLS picks ClientCAs from SNI, and requests fail with 421 when Host and SNI would select different mTLS routes. Compile pools at init (SetInboundMTLSProfiles from state.initEntrypoint) and reject unknown profile refs or mixed global-plus-route configuration. Extend config.example.yml and package READMEs; add entrypoint and config tests for TLS mutation, handshakes, and validation.
206 lines
4.9 KiB
Go
206 lines
4.9 KiB
Go
package entrypoint
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"net/http"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/rs/zerolog/log"
|
|
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
|
|
"github.com/yusing/godoxy/internal/logging/accesslog"
|
|
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
|
|
"github.com/yusing/godoxy/internal/route/rules"
|
|
"github.com/yusing/godoxy/internal/types"
|
|
"github.com/yusing/goutils/pool"
|
|
"github.com/yusing/goutils/task"
|
|
)
|
|
|
|
type HTTPRoutes interface {
|
|
Get(alias string) (types.HTTPRoute, bool)
|
|
}
|
|
|
|
// findRouteFunc maps a request host (which may still carry a ":port" suffix)
// to an HTTP route from routes, or returns nil when nothing matches.
type findRouteFunc func(routes HTTPRoutes, host string) types.HTTPRoute
|
|
|
|
type Entrypoint struct {
|
|
task *task.Task
|
|
|
|
cfg *Config
|
|
|
|
middleware *middleware.Middleware
|
|
notFoundHandler http.Handler
|
|
accessLogger accesslog.AccessLogger
|
|
findRouteFunc findRouteFunc
|
|
shortLinkMatcher *ShortLinkMatcher
|
|
|
|
streamRoutes *pool.Pool[types.StreamRoute]
|
|
excludedRoutes *pool.Pool[types.Route]
|
|
|
|
// this only affects future http servers creation
|
|
httpPoolDisableLog atomic.Bool
|
|
|
|
servers *xsync.Map[string, *httpServer] // listen addr -> server
|
|
|
|
inboundMTLSProfiles map[string]*x509.CertPool
|
|
}
|
|
|
|
// Compile-time guarantee that *Entrypoint implements entrypoint.Entrypoint.
var _ entrypoint.Entrypoint = (*Entrypoint)(nil)
|
|
|
|
// emptyCfg is the zero-value Config that NewEntrypoint substitutes for a nil
// config. It is a single shared value, so it must never be mutated.
var emptyCfg = Config{}
|
|
|
|
func NewTestEntrypoint(tb testing.TB, cfg *Config) *Entrypoint {
|
|
tb.Helper()
|
|
|
|
testTask := task.GetTestTask(tb)
|
|
ep := NewEntrypoint(testTask, cfg)
|
|
entrypoint.SetCtx(testTask, ep)
|
|
return ep
|
|
}
|
|
|
|
func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint {
|
|
if cfg == nil {
|
|
cfg = &emptyCfg
|
|
}
|
|
|
|
ep := &Entrypoint{
|
|
task: parent.Subtask("entrypoint", false),
|
|
cfg: cfg,
|
|
findRouteFunc: findRouteAnyDomain,
|
|
shortLinkMatcher: newShortLinkMatcher(),
|
|
streamRoutes: pool.New[types.StreamRoute]("stream_routes", "stream_routes"),
|
|
excludedRoutes: pool.New[types.Route]("excluded_routes", "excluded_routes"),
|
|
servers: xsync.NewMap[string, *httpServer](),
|
|
inboundMTLSProfiles: make(map[string]*x509.CertPool),
|
|
}
|
|
return ep
|
|
}
|
|
|
|
func (ep *Entrypoint) Task() *task.Task {
|
|
return ep.task
|
|
}
|
|
|
|
func (ep *Entrypoint) SupportProxyProtocol() bool {
|
|
return ep.cfg.SupportProxyProtocol
|
|
}
|
|
|
|
func (ep *Entrypoint) DisablePoolsLog(v bool) {
|
|
ep.httpPoolDisableLog.Store(v)
|
|
// apply to all running http servers
|
|
for _, srv := range ep.servers.Range {
|
|
srv.routes.DisableLog(v)
|
|
}
|
|
// apply to other pools
|
|
ep.streamRoutes.DisableLog(v)
|
|
ep.excludedRoutes.DisableLog(v)
|
|
}
|
|
|
|
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
|
|
return ep.shortLinkMatcher
|
|
}
|
|
|
|
func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] {
|
|
return newHTTPPoolAdapter(ep)
|
|
}
|
|
|
|
func (ep *Entrypoint) StreamRoutes() entrypoint.PoolLike[types.StreamRoute] {
|
|
return ep.streamRoutes
|
|
}
|
|
|
|
func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] {
|
|
return ep.excludedRoutes
|
|
}
|
|
|
|
func (ep *Entrypoint) GetServer(addr string) (HTTPServer, bool) {
|
|
return ep.servers.Load(addr)
|
|
}
|
|
|
|
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
|
|
if len(domains) == 0 {
|
|
ep.findRouteFunc = findRouteAnyDomain
|
|
} else {
|
|
for i, domain := range domains {
|
|
if !strings.HasPrefix(domain, ".") {
|
|
domains[i] = "." + domain
|
|
}
|
|
}
|
|
ep.findRouteFunc = findRouteByDomains(domains)
|
|
}
|
|
}
|
|
|
|
func (ep *Entrypoint) SetMiddlewares(mws []map[string]any) error {
|
|
if len(mws) == 0 {
|
|
ep.middleware = nil
|
|
return nil
|
|
}
|
|
|
|
mid, err := middleware.BuildMiddlewareFromChainRaw("entrypoint", mws)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ep.middleware = mid
|
|
|
|
log.Debug().Msg("entrypoint middleware loaded")
|
|
return nil
|
|
}
|
|
|
|
func (ep *Entrypoint) SetNotFoundRules(rules rules.Rules) {
|
|
ep.notFoundHandler = rules.BuildHandler(serveNotFound)
|
|
}
|
|
|
|
func (ep *Entrypoint) SetAccessLogger(parent task.Parent, cfg *accesslog.RequestLoggerConfig) error {
|
|
if cfg == nil {
|
|
ep.accessLogger = nil
|
|
return nil
|
|
}
|
|
|
|
accessLogger, err := accesslog.NewAccessLogger(parent, cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ep.accessLogger = accessLogger
|
|
log.Debug().Msg("entrypoint access logger created")
|
|
return nil
|
|
}
|
|
|
|
func findRouteAnyDomain(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
before, _, ok := strings.Cut(host, ".")
|
|
if ok {
|
|
target := before
|
|
if r, ok := routes.Get(target); ok {
|
|
return r
|
|
}
|
|
}
|
|
if r, ok := routes.Get(host); ok {
|
|
return r
|
|
}
|
|
// try striping the trailing :port from the host
|
|
if before, _, ok := strings.Cut(host, ":"); ok {
|
|
if r, ok := routes.Get(before); ok {
|
|
return r
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func findRouteByDomains(domains []string) func(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
return func(routes HTTPRoutes, host string) types.HTTPRoute {
|
|
host, _, _ = strings.Cut(host, ":") // strip the trailing :port
|
|
for _, domain := range domains {
|
|
if target, ok := strings.CutSuffix(host, domain); ok {
|
|
if r, ok := routes.Get(target); ok {
|
|
return r
|
|
}
|
|
}
|
|
}
|
|
|
|
// fallback to exact match
|
|
if r, ok := routes.Get(host); ok {
|
|
return r
|
|
}
|
|
return nil
|
|
}
|
|
}
|