Files
godoxy-yusing/internal/config/state.go
Yuzerion 31eea0a885 feat(entrypoint): add inbound mTLS profiles for HTTPS (#220)
Introduce reusable `inbound_mtls_profiles` in root config and support
`entrypoint.inbound_mtls_profile` to require client certificates for all
HTTPS traffic on an entrypoint. Profiles can trust the system CA store,
custom PEM CA files, or both, and are compiled into TLS client-auth
pools during entrypoint initialization.

Also add route-scoped `inbound_mtls_profile` support for HTTP-based
routes when no global entrypoint profile is configured. Route-level mTLS
selection is driven by TLS SNI, preserves existing behavior for open and
unmatched hosts, and returns the intended 421 response when secure
requests omit SNI or when Host and SNI resolve to different routes.

Add validation for missing profile references and unsupported non-HTTP
route usage, update config and route documentation/examples, expand
inbound mTLS handshake and routing regression coverage, and bump
`goutils` for HTTPS listener test support.
2026-04-15 12:14:22 +08:00

555 lines
14 KiB
Go

package config
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"iter"
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"github.com/goccy/go-yaml"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
acl "github.com/yusing/godoxy/internal/acl/types"
"github.com/yusing/godoxy/internal/agentpool"
"github.com/yusing/godoxy/internal/api"
"github.com/yusing/godoxy/internal/autocert"
autocertctx "github.com/yusing/godoxy/internal/autocert/types"
"github.com/yusing/godoxy/internal/common"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/entrypoint"
entrypointctx "github.com/yusing/godoxy/internal/entrypoint/types"
homepage "github.com/yusing/godoxy/internal/homepage/types"
"github.com/yusing/godoxy/internal/logging"
"github.com/yusing/godoxy/internal/maxmind"
"github.com/yusing/godoxy/internal/metrics/systeminfo"
"github.com/yusing/godoxy/internal/metrics/uptime"
"github.com/yusing/godoxy/internal/notif"
route "github.com/yusing/godoxy/internal/route/provider"
"github.com/yusing/godoxy/internal/serialization"
"github.com/yusing/godoxy/internal/types"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/server"
"github.com/yusing/goutils/task"
)
// state holds everything produced by one config load: the parsed config
// (embedded), the registered route providers, the autocert provider, the
// entrypoint, and the root task passed to the subsystems started during Init.
type state struct {
	config.Config
	// providers maps a provider name to its provider; loadRouteProviders keys
	// entries by p.String().
	providers *xsync.Map[string, types.RouteProvider]
	// autocertProvider is assigned by initAutoCert only after every autocert
	// step succeeds; nil otherwise.
	autocertProvider *autocert.Provider
	// entrypoint is assigned by initEntrypoint.
	entrypoint *entrypoint.Entrypoint
	// task is the root "config" task created in NewState.
	task *task.Task
	// used for temporary logging
	// discarded on failed reload
	tmpLogBuf *bytes.Buffer
	tmpLog    zerolog.Logger
}
// CriticalError wraps an error that this package treats as critical while
// loading config: a config file that exists but cannot be read, config data
// that fails to unmarshal/validate, or an entrypoint setup failure.
// Callers can presumably detect it with errors.As to tell it apart from the
// non-critical errors aggregated during Init.
type CriticalError struct {
	err error
}

// Error returns the wrapped error's message verbatim.
func (ce CriticalError) Error() string {
	inner := ce.err
	return inner.Error()
}

// Unwrap returns the wrapped error so errors.Is / errors.As can see through it.
func (ce CriticalError) Unwrap() error {
	inner := ce.err
	return inner
}
// NewState allocates a fresh state that has not been loaded yet; Init or
// InitFromFile populate it. Log output produced while loading goes to an
// in-memory buffer (4 KiB initial capacity) through an info-level logger, and
// stays there until FlushTmpLog writes it to stdout.
func NewState() config.State {
	s := &state{
		providers: xsync.NewMap[string, types.RouteProvider](),
		task:      task.RootTask("config", false),
		tmpLogBuf: bytes.NewBuffer(make([]byte, 0, 4096)),
	}
	s.tmpLog = logging.NewLoggerWithFixedLevel(zerolog.InfoLevel, s.tmpLogBuf)
	return s
}
// stateMu serializes SetState calls so that the two stores it performs are
// never interleaved with another SetState. Readers (GetState, Value, HasState)
// do not take it.
var stateMu sync.RWMutex

// GetState returns the active config state, which is nil until SetState has
// been called (see HasState).
func GetState() config.State {
	active := config.ActiveState.Load()
	return active
}

// SetState publishes next as the active config state and exposes its homepage
// section as the active homepage config.
func SetState(next config.State) {
	stateMu.Lock()
	defer stateMu.Unlock()

	cfg := next.Value()
	config.ActiveState.Store(next)
	homepage.ActiveConfig.Store(&cfg.Homepage)
}

// HasState reports whether an active config state has been published.
func HasState() bool {
	active := config.ActiveState.Load()
	return active != nil
}

// Value returns the config of the active state. It does not check for a
// missing state, so presumably it panics when called before SetState; use
// HasState first if that is possible.
func Value() *config.Config {
	active := config.ActiveState.Load()
	return active.Value()
}
// InitFromFile reads the config file at filename and initializes the state
// from its contents via Init.
//
// A missing file is not an error: state.Config is seeded with
// config.DefaultConfig() and Init runs with no data. Any other read failure is
// returned as a CriticalError.
// NOTE(review): whether Init's UnmarshalValidate keeps those defaults when
// given empty data depends on serialization.UnmarshalValidate; confirm.
func (state *state) InitFromFile(filename string) error {
	data, err := os.ReadFile(filename)
	switch {
	case err == nil:
		// use data as read
	case errors.Is(err, fs.ErrNotExist):
		state.Config = config.DefaultConfig()
	default:
		return CriticalError{err}
	}
	return state.Init(data)
}
// Init unmarshals and validates YAML data into state.Config, then brings up
// the subsystems it describes.
//
// MaxMind, Proxmox and autocert are initialized concurrently. Notifications,
// ACL, the entrypoint and the route providers follow, sequentially and in
// that order; the entrypoint relies on the autocert provider set during the
// concurrent phase.
//
// An unmarshal/validation failure is returned immediately as a CriticalError.
// An entrypoint failure is wrapped in CriticalError and aggregated with all
// other failures into the returned error.
func (state *state) Init(data []byte) error {
	if err := serialization.UnmarshalValidate(data, &state.Config, yaml.Unmarshal); err != nil {
		return CriticalError{err}
	}

	parallel := gperr.NewGroup("config load error")
	for _, step := range []func() error{state.initMaxMind, state.initProxmox, state.initAutoCert} {
		parallel.Go(step)
	}
	errs := parallel.Wait()

	// these won't benefit from running on goroutines
	errs.Add(state.initNotification())
	errs.Add(state.initACL())
	if epErr := state.initEntrypoint(); epErr != nil {
		errs.Add(CriticalError{epErr})
	}
	errs.Add(state.loadRouteProviders())
	return errs.Error()
}
// Task returns the state's root task (the "config" task created in NewState).
func (state *state) Task() *task.Task { return state.task }

// Context returns the context of the state's root task.
func (state *state) Context() context.Context {
	root := state.task
	return root.Context()
}

// Value returns a pointer to the embedded config, not a copy, so callers see
// the live config held by this state.
func (state *state) Value() *config.Config { return &state.Config }
// Entrypoint returns the entrypoint built by initEntrypoint.
//
// Before initEntrypoint has run it returns a nil interface, not an interface
// wrapping a nil *entrypoint.Entrypoint. Returning the nil pointer directly
// would make `state.Entrypoint() == nil` false. This matches how
// AutoCertProvider handles its unset provider.
func (state *state) Entrypoint() entrypointctx.Entrypoint {
	if state.entrypoint == nil {
		return nil
	}
	return state.entrypoint
}

// ShortLinkMatcher returns the entrypoint's short-link matcher, or nil when
// the entrypoint has not been initialized yet (instead of dereferencing a nil
// entrypoint).
func (state *state) ShortLinkMatcher() config.ShortLinkMatcher {
	if state.entrypoint == nil {
		return nil
	}
	return state.entrypoint.ShortLinkMatcher()
}
// AutoCertProvider returns the autocert provider, or a nil interface when
// initAutoCert has not stored one (it stores it only after every step
// succeeds).
//
// The explicit nil check is intentional: returning a nil *autocert.Provider
// as server.CertProvider would produce a non-nil interface value.
func (state *state) AutoCertProvider() server.CertProvider {
	if p := state.autocertProvider; p != nil {
		return p
	}
	return nil
}
// LoadOrStoreProvider stores value under key unless a provider is already
// registered there. It returns the provider now associated with key and
// reports whether that provider was already present.
func (state *state) LoadOrStoreProvider(key string, value types.RouteProvider) (actual types.RouteProvider, loaded bool) {
	return state.providers.LoadOrStore(key, value)
}

// DeleteProvider removes the provider registered under key, if any.
func (state *state) DeleteProvider(key string) {
	state.providers.Delete(key)
}

// IterProviders returns an iterator over the registered (name, provider)
// pairs. xsync.Map.Range already has the iter.Seq2 shape and stops when the
// consumer's yield returns false, so it is returned as-is.
func (state *state) IterProviders() iter.Seq2[string, types.RouteProvider] {
	return state.providers.Range
}
// StartProviders starts every registered provider concurrently under the
// state's root task. It waits for all of them and returns the combined error
// of those that failed.
func (state *state) StartProviders() error {
	group := gperr.NewGroup("provider errors")
	for _, provider := range state.providers.Range {
		group.Go(func() error {
			parent := state.Task()
			return provider.Start(parent)
		})
	}
	return group.Wait().Error()
}

// NumProviders returns the number of registered providers.
func (state *state) NumProviders() int {
	count := state.providers.Size()
	return count
}
// FlushTmpLog copies the buffered load log to stdout and empties the buffer.
// Write errors are deliberately ignored; anything left unwritten is dropped by
// the Reset.
func (state *state) FlushTmpLog() {
	buf := state.tmpLogBuf
	_, _ = buf.WriteTo(os.Stdout)
	buf.Reset()
}
// StartAPIServers starts the API server on common.APIHTTPAddr and, when
// common.LocalAPIHTTPAddr is non-empty, the local API server on that address.
//
// Start failures are logged, not returned. The local server is not started
// when validateLocalAPIAddr rejects its address. A warning is logged when it
// is allowed to bind to a non-loopback address.
func (state *state) StartAPIServers() {
	// API Handler needs to start after auth is initialized.
	if _, err := server.StartServer(state.task.Subtask("api_server", false), server.Options{
		Name:     "api",
		HTTPAddr: common.APIHTTPAddr,
		Handler:  api.NewHandler(true),
	}); err != nil {
		log.Err(err).Msg("failed to start API server")
	}

	// Local API Handler is used for unauthenticated access.
	localAddr := common.LocalAPIHTTPAddr
	if localAddr == "" {
		return
	}
	allowNonLoopback := common.LocalAPIAllowNonLoopback
	if err := validateLocalAPIAddr(localAddr, allowNonLoopback); err != nil {
		log.Err(err).Str("addr", localAddr).Msg("refusing to start local API server")
		return
	}
	if allowNonLoopback && !isLoopbackLocalAPIHost(localAddr) {
		log.Warn().
			Str("addr", localAddr).
			Msg("local API server is allowed to bind to non-loopback addresses")
	}
	if _, err := server.StartServer(state.task.Subtask("local_api_server", false), server.Options{
		Name:     "local_api",
		HTTPAddr: localAddr,
		Handler:  api.NewHandler(false),
	}); err != nil {
		log.Err(err).Msg("failed to start local API server")
	}
}
// validateLocalAPIAddr checks whether addr (host:port) may be used for the
// local API server.
//
// Loopback addresses ("localhost" in any letter case, or a loopback IP
// literal) are always accepted. Any other addr must still split into host and
// port; the split error is returned as-is. If it splits, addr is accepted only
// when allowNonLoopback is true. Otherwise the error explains the rejection:
// an empty host (all interfaces), a host that is not an IP literal, or a
// non-loopback IP.
func validateLocalAPIAddr(addr string, allowNonLoopback bool) error {
	if isLoopbackLocalAPIHost(addr) {
		return nil
	}
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return err
	}
	if allowNonLoopback {
		return nil
	}
	if host == "" {
		return errors.New("local API address must bind to a loopback host, not all interfaces")
	}
	if _, err := netip.ParseAddr(host); err != nil {
		return fmt.Errorf("local API address must use a loopback host: %w", err)
	}
	// isLoopbackLocalAPIHost already accepted "localhost" and every loopback
	// IP literal, so an IP that parses here is necessarily non-loopback.
	return fmt.Errorf("local API address must bind to a loopback host, got %q", host)
}

// isLoopbackLocalAPIHost reports whether the host part of addr (host:port) is
// "localhost" (case-insensitive) or an IP literal for which
// netip.Addr.IsLoopback is true, including IPv4-mapped forms such as
// ::ffff:127.0.0.1. It returns false when addr cannot be split into host and
// port.
func isLoopbackLocalAPIHost(addr string) bool {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return false
	}
	if strings.EqualFold(host, "localhost") {
		return true
	}
	ip, err := netip.ParseAddr(host)
	return err == nil && ip.IsLoopback()
}
// StartMetrics starts the system-info poller and then the uptime poller, both
// under the state's root task.
func (state *state) StartMetrics() {
	parent := state.task
	systeminfo.Poller.Start(parent)
	uptime.Poller.Start(parent)
}
// initACL starts the configured ACL under the state's task and publishes it
// on the task context. It does nothing when state.ACL.Valid() reports false,
// and returns the Start error without publishing if starting fails.
func (state *state) initACL() error {
	if !state.ACL.Valid() {
		return nil
	}
	if err := state.ACL.Start(state.task); err != nil {
		return err
	}
	acl.SetCtx(state.task, state.ACL)
	return nil
}
// initEntrypoint builds the entrypoint from a copy of the config's entrypoint
// section and configures it:
//   - sets the route-lookup domains and not-found rules;
//   - sets the short-link default domain suffix;
//   - publishes it on the state's task context.
//
// It then applies middlewares, the access logger and the inbound mTLS
// profiles, collecting their errors.
//
// The short-link suffix is first set to MatchDomains[0] (when any are
// configured). If the autocert certificate yields a domain, the suffix is
// then overwritten with "." + that domain.
// NOTE(review): the autocert-derived domain therefore takes precedence over an
// explicitly configured match domain; confirm that priority is intended.
func (state *state) initEntrypoint() error {
	epCfg := state.Config.Entrypoint
	domains := state.MatchDomains

	ep := entrypoint.NewEntrypoint(state.task, &epCfg)
	state.entrypoint = ep
	ep.SetFindRouteDomains(domains)
	ep.SetNotFoundRules(epCfg.Rules.NotFound)

	if len(domains) > 0 {
		ep.ShortLinkMatcher().SetDefaultDomainSuffix(domains[0])
	}
	// getAutoCertDefaultDomain returns "" for a nil provider, so no separate
	// nil check is needed here.
	if certDomain := getAutoCertDefaultDomain(state.autocertProvider); certDomain != "" {
		ep.ShortLinkMatcher().SetDefaultDomainSuffix("." + certDomain)
	}

	entrypointctx.SetCtx(state.task, ep)

	errs := gperr.NewBuilder("entrypoint error")
	errs.Add(ep.SetMiddlewares(epCfg.Middlewares))
	errs.Add(ep.SetAccessLogger(state.task, epCfg.AccessLog))
	errs.Add(ep.SetInboundMTLSProfiles(state.Config.InboundMTLSProfiles))
	return errs.Error()
}
// getAutoCertDefaultDomain loads the certificate/key pair of autocert
// provider p and derives a bare domain from the first certificate in the
// chain. It uses the subject CommonName, or the first DNS SAN when the CN is
// empty. The name is trimmed of whitespace, stripped of a leading "*."
// wildcard label, and lower-cased.
//
// It returns "" when p is nil, the pair cannot be loaded, the chain is empty,
// or the certificate cannot be parsed.
func getAutoCertDefaultDomain(p *autocert.Provider) string {
	if p == nil {
		return ""
	}
	pair, err := tls.LoadX509KeyPair(p.GetCertPath(), p.GetKeyPath())
	if err != nil || len(pair.Certificate) == 0 {
		return ""
	}
	leaf, err := x509.ParseCertificate(pair.Certificate[0])
	if err != nil {
		return ""
	}
	name := leaf.Subject.CommonName
	if name == "" && len(leaf.DNSNames) > 0 {
		name = leaf.DNSNames[0]
	}
	name = strings.TrimPrefix(strings.TrimSpace(name), "*.")
	return strings.ToLower(name)
}
// initMaxMind installs the MaxMind instance under the state's task when a
// MaxMind provider is configured; otherwise it does nothing.
func (state *state) initMaxMind() error {
	cfg := state.Providers.MaxMind
	if cfg == nil {
		return nil
	}
	return maxmind.SetInstance(state.task, cfg)
}
// initNotification starts the notification dispatcher under the state's task
// and registers each configured notification provider with it. Nothing is
// started when no providers are configured. It currently always returns nil.
func (state *state) initNotification() error {
	configured := state.Providers.Notification
	if len(configured) == 0 {
		return nil
	}
	dispatcher := notif.StartNotifDispatcher(state.task)
	for _, provider := range configured {
		dispatcher.RegisterProvider(provider)
	}
	return nil
}
// initAutoCert sets up the autocert provider:
//   - builds it from the config's autocert section (or from a fresh
//     autocert.Config when that section is absent);
//   - obtains any missing certificates;
//   - schedules renewals under the state's task and prints expiry dates;
//   - publishes the provider on the state and the task context.
//
// The first failing step's error is returned, and the provider is then not
// stored.
func (state *state) initAutoCert() error {
	cfg := state.AutoCert
	if cfg == nil {
		cfg = new(autocert.Config)
		// NOTE(review): Validate on the empty config is presumably called to
		// fill in defaults; its error is deliberately ignored.
		_ = cfg.Validate()
	}

	user, legoCfg, err := cfg.GetLegoConfig()
	if err != nil {
		return err
	}
	provider, err := autocert.NewProvider(cfg, user, legoCfg)
	if err != nil {
		return err
	}
	if err := provider.ObtainCertIfNotExistsAll(); err != nil {
		return err
	}

	provider.ScheduleRenewalAll(state.task)
	provider.PrintCertExpiriesAll()

	state.autocertProvider = provider
	autocertctx.SetCtx(state.task, provider)
	return nil
}
// initProxmox initializes every configured Proxmox entry concurrently with
// the state's task context. It returns the combined errors, each prefixed with
// the failing entry's URL.
func (state *state) initProxmox() error {
	entries := state.Providers.Proxmox
	if len(entries) == 0 {
		return nil
	}

	var group gperr.Group
	for _, entry := range entries {
		group.Go(func() error {
			err := entry.Init(state.task.Context())
			if err == nil {
				return nil
			}
			return gperr.PrependSubject(err, entry.URL)
		})
	}
	return group.Wait().Error()
}
// loadRouteProviders registers the agent, file and Docker route providers
// described by the config and loads their routes.
//
// The steps are:
//   - Clear the agent pool, initialize the configured agents concurrently, and
//     add those that initialize successfully to the pool.
//   - Register each provider in state.providers under its String() name. A
//     duplicate name is reported and the later provider is not stored.
//   - Load routes for all registered providers concurrently.
//   - Write a per-provider summary, the route listing and the active config to
//     the temporary log.
//
// Failures are accumulated instead of aborting the load, so one broken
// provider does not keep the others from loading; the combined error is
// returned.
func (state *state) loadRouteProviders() error {
	providers := state.Providers
	errs := gperr.NewGroup("route provider errors")

	agentpool.RemoveAll()

	registerProvider := func(p types.RouteProvider) {
		if actual, loaded := state.providers.LoadOrStore(p.String(), p); loaded {
			errs.Addf("provider %s already exists, first: %s, second: %s", p.String(), actual.GetType(), p.GetType())
		}
	}

	agentErrs := gperr.NewGroup("agent init errors")
	for _, a := range providers.Agents {
		agentErrs.Go(func() error {
			if err := a.Init(state.task.Context()); err != nil {
				return gperr.PrependSubject(err, a.String())
			}
			agentpool.Add(a)
			return nil
		})
	}
	if err := agentErrs.Wait().Error(); err != nil {
		errs.Add(err)
	}

	for _, a := range providers.Agents {
		registerProvider(route.NewAgentProvider(a))
	}
	for _, filename := range providers.Files {
		p, err := route.NewFileProvider(filename)
		if err != nil {
			// Record the failure (tagged with the file name) and keep going,
			// as with agent and route-load errors, so one bad file does not
			// discard the errors collected so far or prevent the remaining
			// file and Docker providers from loading.
			errs.Add(gperr.PrependSubject(err, filename))
			continue
		}
		registerProvider(p)
	}
	for name, dockerCfg := range providers.Docker {
		registerProvider(route.NewDockerProvider(name, dockerCfg))
	}

	lenLongestName := 0
	for k := range state.providers.Range {
		if len(k) > lenLongestName {
			lenLongestName = len(k)
		}
	}

	// load routes concurrently
	loadErrs := gperr.NewGroup("route load errors")
	results := gperr.NewBuilder("loaded route providers")
	var resultsMu sync.Mutex
	for _, p := range state.providers.Range {
		loadErrs.Go(func() error {
			if err := p.LoadRoutes(); err != nil {
				return gperr.PrependSubject(err, p.String())
			}
			resultsMu.Lock()
			results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
			resultsMu.Unlock()
			return nil
		})
	}
	if err := loadErrs.Wait().Error(); err != nil {
		errs.Add(err)
	}

	state.tmpLog.Info().Msg(results.String())
	state.printRoutesByProvider(lenLongestName)
	state.printState()
	return errs.Wait().Error()
}
// printRoutesByProvider writes a "routes by provider" listing to the
// temporary log. For each provider it prints a header line with its route
// count. Then, for each route not excluded by ShouldExclude, it prints the
// route's display name (with the alias in parentheses when they differ) and
// its target URL.
//
// lenLongestName is the length of the longest provider name. The name column
// is widened to fit the longest route label across all providers. It is
// computed in a separate first pass so every provider's lines share one
// width. Widening while printing made alignment depend on map iteration
// order.
func (state *state) printRoutesByProvider(lenLongestName int) {
	// label formats a route's display name, appending the alias when it differs.
	label := func(displayName, alias string) string {
		if displayName != alias {
			return fmt.Sprintf("%s (%s)", displayName, alias)
		}
		return displayName
	}

	width := lenLongestName + 2 // > + space
	for _, p := range state.providers.Range {
		for alias, r := range p.IterRoutes {
			if r.ShouldExclude() {
				continue
			}
			width = max(width, len(label(r.DisplayName(), alias))+3) // 3 spaces + "-"
		}
	}
	headerFormat := "> %-" + strconv.Itoa(width) + "s %d routes:\n"
	routeFormat := " - %-" + strconv.Itoa(width-2) + "s -> %s\n"

	var out strings.Builder
	out.Grow(4096)
	out.WriteString("routes by provider\n")
	for _, p := range state.providers.Range {
		routeCount := p.NumRoutes()
		fmt.Fprintf(&out, headerFormat, p.String(), routeCount)
		if routeCount == 0 {
			continue
		}
		for alias, r := range p.IterRoutes {
			if r.ShouldExclude() {
				continue
			}
			fmt.Fprintf(&out, routeFormat, label(r.DisplayName(), alias), r.TargetURL().String())
		}
	}
	// Always printed, even when every provider is empty; the builder always
	// holds at least the title line.
	state.tmpLog.Info().Msg(out.String())
}
// printState writes the active config, marshalled as YAML, to the temporary
// log. If marshalling fails, the error is logged instead of silently logging
// an empty config.
func (state *state) printState() {
	state.tmpLog.Info().Msg("active config:")
	yamlRepr, err := yaml.Marshal(state.Config)
	if err != nil {
		state.tmpLog.Error().Err(err).Msg("failed to marshal active config")
		return
	}
	state.tmpLog.Info().Msgf("%s", yamlRepr) // prevent copying when casting to string
}