package config

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io/fs"
	"iter"
	"os"
	"strconv"
	"strings"
	"sync"

	"github.com/goccy/go-yaml"
	"github.com/puzpuzpuz/xsync/v4"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	acl "github.com/yusing/godoxy/internal/acl/types"
	"github.com/yusing/godoxy/internal/agentpool"
	"github.com/yusing/godoxy/internal/api"
	"github.com/yusing/godoxy/internal/autocert"
	autocertctx "github.com/yusing/godoxy/internal/autocert/types"
	"github.com/yusing/godoxy/internal/common"
	config "github.com/yusing/godoxy/internal/config/types"
	"github.com/yusing/godoxy/internal/entrypoint"
	entrypointctx "github.com/yusing/godoxy/internal/entrypoint/types"
	homepage "github.com/yusing/godoxy/internal/homepage/types"
	"github.com/yusing/godoxy/internal/logging"
	"github.com/yusing/godoxy/internal/maxmind"
	"github.com/yusing/godoxy/internal/metrics/systeminfo"
	"github.com/yusing/godoxy/internal/metrics/uptime"
	"github.com/yusing/godoxy/internal/notif"
	route "github.com/yusing/godoxy/internal/route/provider"
	"github.com/yusing/godoxy/internal/serialization"
	"github.com/yusing/godoxy/internal/types"
	gperr "github.com/yusing/goutils/errs"
	"github.com/yusing/goutils/server"
	"github.com/yusing/goutils/task"
)

// state is the concrete config.State returned by NewState. It embeds the
// parsed configuration and holds the objects built from it.
type state struct {
	config.Config

	// providers maps a provider's String() name to the route provider.
	providers        *xsync.Map[string, types.RouteProvider]
	autocertProvider *autocert.Provider
	entrypoint       *entrypoint.Entrypoint

	// task is the task under which everything started by this state runs.
	task *task.Task

	// used for temporary logging
	// discarded on failed reload
	tmpLogBuf *bytes.Buffer
	tmpLog    zerolog.Logger
}

// CriticalError wraps an error raised while loading the configuration.
//
// In this file it wraps file read failures (other than "not exist"),
// unmarshal/validation failures and entrypoint initialization failures.
type CriticalError struct {
	err error
}

// Error returns the message of the wrapped error.
func (e CriticalError) Error() string {
	return e.err.Error()
}

// Unwrap returns the wrapped error so errors.Is / errors.As can see through it.
func (e CriticalError) Unwrap() error {
	return e.err
}

// NewState creates an empty state with its own root task named "config" and
// a temporary info-level logger that writes into an in-memory buffer
// (see FlushTmpLog).
func NewState() config.State {
	tmpLogBuf := bytes.NewBuffer(make([]byte, 0, 4096))
	return &state{
		providers: xsync.NewMap[string, types.RouteProvider](),
		task:      task.RootTask("config", false),
		tmpLogBuf: tmpLogBuf,
		tmpLog:    logging.NewLoggerWithFixedLevel(zerolog.InfoLevel, tmpLogBuf),
	}
}

// stateMu serializes SetState so the active state and the active homepage
// config are swapped together.
var stateMu sync.RWMutex

// GetState returns the currently active state as stored in config.ActiveState.
func GetState() config.State {
	return config.ActiveState.Load()
}

// SetState makes state the active one and publishes its homepage config.
func SetState(state config.State) {
	stateMu.Lock()
	defer stateMu.Unlock()

	cfg := state.Value()
	config.ActiveState.Store(state)
	homepage.ActiveConfig.Store(&cfg.Homepage)
}

// HasState reports whether an active state has been set.
func HasState() bool {
	return config.ActiveState.Load() != nil
}

// Value returns the configuration of the active state.
//
// It calls Value on whatever config.ActiveState.Load returns without a nil
// check; use HasState first if a state may not have been set yet.
func Value() *config.Config {
	return config.ActiveState.Load().Value()
}

// InitFromFile reads filename and initializes the state from its content.
//
// A missing file is not an error: the config is set to config.DefaultConfig()
// and Init runs with empty data. Any other read error is returned as a
// CriticalError.
func (state *state) InitFromFile(filename string) error {
	data, err := os.ReadFile(filename)
	if err != nil {
		if !errors.Is(err, fs.ErrNotExist) {
			return CriticalError{err}
		}
		state.Config = config.DefaultConfig()
	}
	return state.Init(data)
}

// Init unmarshals and validates data into the state's config, then
// initializes every subsystem derived from it.
//
// An unmarshal/validation failure aborts immediately with a CriticalError.
// After that, all initialization steps are attempted and their errors are
// collected; an entrypoint failure is collected as a CriticalError.
func (state *state) Init(data []byte) error {
	err := serialization.UnmarshalValidate(data, &state.Config, yaml.Unmarshal)
	if err != nil {
		return CriticalError{err}
	}

	g := gperr.NewGroup("config load error")
	g.Go(state.initMaxMind)
	g.Go(state.initProxmox)
	g.Go(state.initAutoCert)

	errs := g.Wait()
	// these won't benefit from running on goroutines
	errs.Add(state.initNotification())
	errs.Add(state.initACL())
	// initEntrypoint reads state.autocertProvider, so it must run after the
	// group above has finished.
	if err := state.initEntrypoint(); err != nil {
		errs.Add(CriticalError{err})
	}
	errs.Add(state.loadRouteProviders())
	return errs.Error()
}

// Task returns the task owned by this state.
func (state *state) Task() *task.Task {
	return state.task
}

// Context returns the context of the state's task.
func (state *state) Context() context.Context {
	return state.task.Context()
}

// Value returns a pointer to the state's embedded configuration.
func (state *state) Value() *config.Config {
	return &state.Config
}

// Entrypoint returns the entrypoint built by Init.
//
// It returns an untyped nil when the entrypoint has not been created yet, so
// callers can compare the result against nil (same guard as AutoCertProvider).
func (state *state) Entrypoint() entrypointctx.Entrypoint {
	if state.entrypoint == nil {
		return nil
	}
	return state.entrypoint
}

// ShortLinkMatcher returns the short link matcher of the state's entrypoint.
func (state *state) ShortLinkMatcher() config.ShortLinkMatcher {
	return state.entrypoint.ShortLinkMatcher()
}

// AutoCertProvider returns the autocert provider.
//
// If the autocert provider is not configured, it returns nil.
func (state *state) AutoCertProvider() server.CertProvider { if state.autocertProvider == nil { return nil } return state.autocertProvider } func (state *state) LoadOrStoreProvider(key string, value types.RouteProvider) (actual types.RouteProvider, loaded bool) { actual, loaded = state.providers.LoadOrStore(key, value) return } func (state *state) DeleteProvider(key string) { state.providers.Delete(key) } func (state *state) IterProviders() iter.Seq2[string, types.RouteProvider] { return func(yield func(string, types.RouteProvider) bool) { for k, v := range state.providers.Range { if !yield(k, v) { return } } } } func (state *state) StartProviders() error { errs := gperr.NewGroup("provider errors") for _, p := range state.providers.Range { errs.Go(func() error { return p.Start(state.Task()) }) } return errs.Wait().Error() } func (state *state) NumProviders() int { return state.providers.Size() } func (state *state) FlushTmpLog() { _, _ = state.tmpLogBuf.WriteTo(os.Stdout) state.tmpLogBuf.Reset() } func (state *state) StartAPIServers() { // API Handler needs to start after auth is initialized. _, err := server.StartServer(state.task.Subtask("api_server", false), server.Options{ Name: "api", HTTPAddr: common.APIHTTPAddr, Handler: api.NewHandler(true), }) if err != nil { log.Err(err).Msg("failed to start API server") } // Local API Handler is used for unauthenticated access. if common.LocalAPIHTTPAddr != "" { _, err := server.StartServer(state.task.Subtask("local_api_server", false), server.Options{ Name: "local_api", HTTPAddr: common.LocalAPIHTTPAddr, Handler: api.NewHandler(false), }) if err != nil { log.Err(err).Msg("failed to start local API server") } } } func (state *state) StartMetrics() { systeminfo.Poller.Start(state.task) uptime.Poller.Start(state.task) } // initACL initializes the ACL. 
func (state *state) initACL() error { if !state.ACL.Valid() { return nil } err := state.ACL.Start(state.task) if err != nil { return err } acl.SetCtx(state.task, state.ACL) return nil } func (state *state) initEntrypoint() error { epCfg := state.Config.Entrypoint matchDomains := state.MatchDomains state.entrypoint = entrypoint.NewEntrypoint(state.task, &epCfg) state.entrypoint.SetFindRouteDomains(matchDomains) state.entrypoint.SetNotFoundRules(epCfg.Rules.NotFound) if len(matchDomains) > 0 { state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix(matchDomains[0]) } if state.autocertProvider != nil { if domain := getAutoCertDefaultDomain(state.autocertProvider); domain != "" { state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix("." + domain) } } entrypointctx.SetCtx(state.task, state.entrypoint) errs := gperr.NewBuilder("entrypoint error") errs.Add(state.entrypoint.SetMiddlewares(epCfg.Middlewares)) errs.Add(state.entrypoint.SetAccessLogger(state.task, epCfg.AccessLog)) return errs.Error() } func getAutoCertDefaultDomain(p *autocert.Provider) string { if p == nil { return "" } cert, err := tls.LoadX509KeyPair(p.GetCertPath(), p.GetKeyPath()) if err != nil || len(cert.Certificate) == 0 { return "" } x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { return "" } domain := x509Cert.Subject.CommonName if domain == "" && len(x509Cert.DNSNames) > 0 { domain = x509Cert.DNSNames[0] } domain = strings.TrimSpace(domain) if after, ok := strings.CutPrefix(domain, "*."); ok { domain = after } return strings.ToLower(domain) } func (state *state) initMaxMind() error { maxmindCfg := state.Providers.MaxMind if maxmindCfg != nil { return maxmind.SetInstance(state.task, maxmindCfg) } return nil } func (state *state) initNotification() error { notifCfg := state.Providers.Notification if len(notifCfg) == 0 { return nil } dispatcher := notif.StartNotifDispatcher(state.task) for _, notifier := range notifCfg { dispatcher.RegisterProvider(notifier) } 
return nil } func (state *state) initAutoCert() error { autocertCfg := state.AutoCert if autocertCfg == nil { autocertCfg = new(autocert.Config) _ = autocertCfg.Validate() } user, legoCfg, err := autocertCfg.GetLegoConfig() if err != nil { return err } p, err := autocert.NewProvider(autocertCfg, user, legoCfg) if err != nil { return err } if err := p.ObtainCertIfNotExistsAll(); err != nil { return err } p.ScheduleRenewalAll(state.task) p.PrintCertExpiriesAll() state.autocertProvider = p autocertctx.SetCtx(state.task, p) return nil } func (state *state) initProxmox() error { proxmoxCfg := state.Providers.Proxmox if len(proxmoxCfg) == 0 { return nil } var errs gperr.Group for _, cfg := range proxmoxCfg { errs.Go(func() error { if err := cfg.Init(state.task.Context()); err != nil { return gperr.PrependSubject(err, cfg.URL) } return nil }) } return errs.Wait().Error() } func (state *state) loadRouteProviders() error { providers := state.Providers errs := gperr.NewGroup("route provider errors") agentpool.RemoveAll() registerProvider := func(p types.RouteProvider) { if actual, loaded := state.providers.LoadOrStore(p.String(), p); loaded { errs.Addf("provider %s already exists, first: %s, second: %s", p.String(), actual.GetType(), p.GetType()) } } agentErrs := gperr.NewGroup("agent init errors") for _, a := range providers.Agents { agentErrs.Go(func() error { if err := a.Init(state.task.Context()); err != nil { return gperr.PrependSubject(err, a.String()) } agentpool.Add(a) return nil }) } if err := agentErrs.Wait().Error(); err != nil { errs.Add(err) } for _, a := range providers.Agents { registerProvider(route.NewAgentProvider(a)) } for _, filename := range providers.Files { p, err := route.NewFileProvider(filename) if err != nil { errs.Add(gperr.PrependSubject(err, filename)) return err } registerProvider(p) } for name, dockerCfg := range providers.Docker { registerProvider(route.NewDockerProvider(name, dockerCfg)) } lenLongestName := 0 for k := range 
state.providers.Range { if len(k) > lenLongestName { lenLongestName = len(k) } } // load routes concurrently loadErrs := gperr.NewGroup("route load errors") results := gperr.NewBuilder("loaded route providers") resultsMu := sync.Mutex{} for _, p := range state.providers.Range { loadErrs.Go(func() error { if err := p.LoadRoutes(); err != nil { return gperr.PrependSubject(err, p.String()) } resultsMu.Lock() results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes()) resultsMu.Unlock() return nil }) } if err := loadErrs.Wait().Error(); err != nil { errs.Add(err) } state.tmpLog.Info().Msg(results.String()) state.printRoutesByProvider(lenLongestName) state.printState() return errs.Wait().Error() } func (state *state) printRoutesByProvider(lenLongestName int) { var routeResults strings.Builder routeResults.Grow(4096) // more than enough routeResults.WriteString("routes by provider\n") lenLongestName += 2 // > + space for _, p := range state.providers.Range { providerName := p.String() routeCount := p.NumRoutes() // Print provider header fmt.Fprintf(&routeResults, "> %-"+strconv.Itoa(lenLongestName)+"s %d routes:\n", providerName, routeCount) if routeCount == 0 { continue } // calculate longest name for alias, r := range p.IterRoutes { if r.ShouldExclude() { continue } displayName := r.DisplayName() if displayName != alias { displayName = fmt.Sprintf("%s (%s)", displayName, alias) } if len(displayName)+3 > lenLongestName { // 3 spaces + "-" lenLongestName = len(displayName) + 3 } } for alias, r := range p.IterRoutes { if r.ShouldExclude() { continue } displayName := r.DisplayName() if displayName != alias { displayName = fmt.Sprintf("%s (%s)", displayName, alias) } fmt.Fprintf(&routeResults, " - %-"+strconv.Itoa(lenLongestName-2)+"s -> %s\n", displayName, r.TargetURL().String()) } } // Always print the routes since we want to show even empty providers routeStr := routeResults.String() if routeStr != "" { state.tmpLog.Info().Msg(routeStr) } } 
func (state *state) printState() { state.tmpLog.Info().Msg("active config:") yamlRepr, _ := yaml.Marshal(state.Config) state.tmpLog.Info().Msgf("%s", yamlRepr) // prevent copying when casting to string }