Files
godoxy-yusing/internal/config/state.go

498 lines
12 KiB
Go

package config
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"iter"
"os"
"strconv"
"strings"
"sync"
"github.com/goccy/go-yaml"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
acl "github.com/yusing/godoxy/internal/acl/types"
"github.com/yusing/godoxy/internal/agentpool"
"github.com/yusing/godoxy/internal/api"
"github.com/yusing/godoxy/internal/autocert"
autocertctx "github.com/yusing/godoxy/internal/autocert/types"
"github.com/yusing/godoxy/internal/common"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/entrypoint"
entrypointctx "github.com/yusing/godoxy/internal/entrypoint/types"
homepage "github.com/yusing/godoxy/internal/homepage/types"
"github.com/yusing/godoxy/internal/logging"
"github.com/yusing/godoxy/internal/maxmind"
"github.com/yusing/godoxy/internal/metrics/systeminfo"
"github.com/yusing/godoxy/internal/metrics/uptime"
"github.com/yusing/godoxy/internal/notif"
route "github.com/yusing/godoxy/internal/route/provider"
"github.com/yusing/godoxy/internal/serialization"
"github.com/yusing/godoxy/internal/types"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/server"
"github.com/yusing/goutils/task"
)
// state is the concrete config.State returned by NewState. It embeds the
// parsed config.Config and holds everything created while loading it: the
// route providers, the autocert provider, the entrypoint, and the task they
// are started under.
type state struct {
	config.Config
	// providers maps a provider key to its route provider; loadRouteProviders
	// uses the provider's String() as the key.
	providers *xsync.Map[string, types.RouteProvider]
	// autocertProvider is set by initAutoCert; nil until it succeeds.
	autocertProvider *autocert.Provider
	// entrypoint is set by initEntrypoint.
	entrypoint *entrypoint.Entrypoint
	// task is the root task created in NewState; subsystems are started under it.
	task *task.Task
	// used for temporary logging
	// discarded on failed reload
	// tmpLog writes into tmpLogBuf; FlushTmpLog copies the buffer to stdout.
	tmpLogBuf *bytes.Buffer
	tmpLog zerolog.Logger
}
// CriticalError wraps an error that Init/InitFromFile consider critical:
// failing to read or parse the config file, or failing to set up the
// entrypoint. NOTE(review): callers presumably treat it as fatal — confirm.
//
// The wrapped error stays reachable through errors.Is / errors.As.
type CriticalError struct {
	err error
}

// Error returns the wrapped error's message unchanged.
func (e CriticalError) Error() string {
	inner := e.err
	return inner.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e CriticalError) Unwrap() error {
	inner := e.err
	return inner
}
// NewState returns an empty config state with its own root task "config".
//
// Log output produced while loading goes to an in-memory buffer instead of
// straight to stdout, so it can be flushed (FlushTmpLog) or dropped depending
// on whether the load succeeds.
func NewState() config.State {
	logBuf := bytes.NewBuffer(make([]byte, 0, 4096))
	s := &state{
		providers: xsync.NewMap[string, types.RouteProvider](),
		task:      task.RootTask("config", false),
		tmpLogBuf: logBuf,
		tmpLog:    logging.NewLoggerWithFixedLevel(zerolog.InfoLevel, logBuf),
	}
	return s
}
// stateMu serializes SetState calls so two concurrent calls cannot interleave
// their stores to config.ActiveState and homepage.ActiveConfig. GetState,
// HasState and Value read config.ActiveState without taking this lock.
var stateMu sync.RWMutex
// GetState returns the state most recently stored by SetState.
func GetState() config.State {
	active := config.ActiveState.Load()
	return active
}

// SetState makes newState the active config state and publishes its homepage
// configuration. Concurrent calls are serialized by stateMu.
func SetState(newState config.State) {
	stateMu.Lock()
	defer stateMu.Unlock()

	homepageCfg := &newState.Value().Homepage
	config.ActiveState.Store(newState)
	homepage.ActiveConfig.Store(homepageCfg)
}

// HasState reports whether an active state has been stored.
func HasState() bool {
	loaded := config.ActiveState.Load()
	return loaded != nil
}

// Value returns the config of the active state.
// NOTE(review): this dereferences the active state unconditionally; check
// HasState first if no state may have been set yet.
func Value() *config.Config {
	current := config.ActiveState.Load()
	return current.Value()
}
// InitFromFile reads filename and initializes the state from its contents.
//
// A missing file is not an error: the state starts from config.DefaultConfig()
// and Init is then called with empty (nil) data. NOTE(review): whether the
// defaults survive depends on serialization.UnmarshalValidate leaving the
// target untouched on empty input — confirm. Any other read error is returned
// as a CriticalError.
func (state *state) InitFromFile(filename string) error {
	data, err := os.ReadFile(filename)
	switch {
	case err == nil:
		// file read successfully; fall through to Init
	case errors.Is(err, fs.ErrNotExist):
		state.Config = config.DefaultConfig()
	default:
		return CriticalError{err}
	}
	return state.Init(data)
}
// Init parses and validates data as YAML into the embedded config, then
// initializes every subsystem that the config describes.
//
// A parse/validation failure is returned immediately as a CriticalError.
// After that, errors are accumulated: MaxMind, Proxmox and autocert are set up
// concurrently, then notifications, ACL, the entrypoint (its failure is
// wrapped as CriticalError) and route providers run one after another.
func (state *state) Init(data []byte) error {
	if err := serialization.UnmarshalValidate(data, &state.Config, yaml.Unmarshal); err != nil {
		return CriticalError{err}
	}

	initGroup := gperr.NewGroup("config load error")
	initGroup.Go(state.initMaxMind)
	initGroup.Go(state.initProxmox)
	initGroup.Go(state.initAutoCert)
	loadErrs := initGroup.Wait()

	// Sequential on purpose: these gain nothing from goroutines, and
	// initEntrypoint reads state.autocertProvider, which initAutoCert sets.
	loadErrs.Add(state.initNotification())
	loadErrs.Add(state.initACL())
	if epErr := state.initEntrypoint(); epErr != nil {
		loadErrs.Add(CriticalError{epErr})
	}
	loadErrs.Add(state.loadRouteProviders())
	return loadErrs.Error()
}
// Task returns the root task created for this state in NewState.
func (state *state) Task() *task.Task {
	t := state.task
	return t
}

// Context returns the context of the state's root task.
func (state *state) Context() context.Context {
	return state.Task().Context()
}

// Value returns a pointer to the embedded config (not a copy).
func (state *state) Value() *config.Config {
	cfg := &state.Config
	return cfg
}
// Entrypoint returns the entrypoint created by initEntrypoint.
//
// If the entrypoint has not been created (for example, Init returned before
// reaching initEntrypoint), it returns a nil interface rather than an
// interface wrapping a nil *entrypoint.Entrypoint, so `ep == nil` checks by
// callers behave as expected. AutoCertProvider handles its nil case the same way.
func (state *state) Entrypoint() entrypointctx.Entrypoint {
	if state.entrypoint == nil {
		return nil
	}
	return state.entrypoint
}
// ShortLinkMatcher returns the entrypoint's short-link matcher.
// It reads state.entrypoint, which is set by initEntrypoint.
func (state *state) ShortLinkMatcher() config.ShortLinkMatcher {
	ep := state.entrypoint
	return ep.ShortLinkMatcher()
}
// AutoCertProvider returns the autocert provider.
//
// If the autocert provider is not configured, it returns a nil interface
// (never an interface holding a nil *autocert.Provider), so the result can
// be compared against nil safely.
func (state *state) AutoCertProvider() server.CertProvider {
	if p := state.autocertProvider; p != nil {
		return p
	}
	return nil
}
// LoadOrStoreProvider stores value under key unless a provider is already
// registered there. It returns the provider now stored under key and whether
// it was already present.
func (state *state) LoadOrStoreProvider(key string, value types.RouteProvider) (actual types.RouteProvider, loaded bool) {
	return state.providers.LoadOrStore(key, value)
}

// DeleteProvider removes the provider registered under key, if any.
func (state *state) DeleteProvider(key string) {
	registry := state.providers
	registry.Delete(key)
}
// IterProviders returns an iterator over all registered (key, provider) pairs.
// Iteration stops as soon as the consumer stops ranging.
func (state *state) IterProviders() iter.Seq2[string, types.RouteProvider] {
	// The map's Range method already has the iter.Seq2 shape
	// (func(yield func(K, V) bool)), so it can be returned as-is.
	return state.providers.Range
}
// StartProviders starts every registered provider concurrently under the
// state's task and returns the combined start errors (nil if all succeed).
func (state *state) StartProviders() error {
	startErrs := gperr.NewGroup("provider errors")
	for _, provider := range state.providers.Range {
		startErrs.Go(func() error {
			parent := state.Task()
			return provider.Start(parent)
		})
	}
	return startErrs.Wait().Error()
}

// NumProviders returns how many providers are currently registered.
func (state *state) NumProviders() int {
	count := state.providers.Size()
	return count
}
// FlushTmpLog copies the buffered load-time log to stdout and empties the
// buffer. Writing is best-effort: a write error is ignored, and the buffer is
// reset either way so the same entries are never flushed twice.
func (state *state) FlushTmpLog() {
	buf := state.tmpLogBuf
	defer buf.Reset()
	_, _ = buf.WriteTo(os.Stdout)
}
// StartAPIServers starts the authenticated API server and, when
// common.LocalAPIHTTPAddr is set, the unauthenticated local API server.
// Start failures are logged, not returned.
func (state *state) StartAPIServers() {
	// The authenticated API handler must start only after auth is initialized.
	apiTask := state.task.Subtask("api_server", false)
	apiOpts := server.Options{
		Name:     "api",
		HTTPAddr: common.APIHTTPAddr,
		Handler:  api.NewHandler(true),
	}
	if _, err := server.StartServer(apiTask, apiOpts); err != nil {
		log.Err(err).Msg("failed to start API server")
	}

	// The local API handler serves unauthenticated access and is optional.
	if common.LocalAPIHTTPAddr == "" {
		return
	}
	localTask := state.task.Subtask("local_api_server", false)
	localOpts := server.Options{
		Name:     "local_api",
		HTTPAddr: common.LocalAPIHTTPAddr,
		Handler:  api.NewHandler(false),
	}
	if _, err := server.StartServer(localTask, localOpts); err != nil {
		log.Err(err).Msg("failed to start local API server")
	}
}
// StartMetrics starts the system-info poller and then the uptime poller,
// both under the state's root task.
func (state *state) StartMetrics() {
	parent := state.task
	systeminfo.Poller.Start(parent)
	uptime.Poller.Start(parent)
}
// initACL initializes the ACL.
//
// If state.ACL does not report itself valid, it is skipped without error.
// Otherwise it is started under the state's task and then attached to that
// task via acl.SetCtx. A start error is returned and SetCtx is not called.
func (state *state) initACL() error {
	if !state.ACL.Valid() {
		return nil
	}
	if err := state.ACL.Start(state.task); err != nil {
		return err
	}
	acl.SetCtx(state.task, state.ACL)
	return nil
}
// initEntrypoint creates the entrypoint from the config, configures route
// domains, not-found rules and the short-link default domain, attaches it to
// the state's task, and finally applies middlewares and the access logger.
// Middleware and access-log errors are collected and returned together.
func (state *state) initEntrypoint() error {
	// NewEntrypoint receives a pointer to this local copy, not to
	// state.Config.Entrypoint.
	epCfg := state.Config.Entrypoint
	domains := state.MatchDomains

	state.entrypoint = entrypoint.NewEntrypoint(state.task, &epCfg)
	ep := state.entrypoint
	ep.SetFindRouteDomains(domains)
	ep.SetNotFoundRules(epCfg.Rules.NotFound)

	// Short-link default suffix: start with the first match domain, then let
	// the autocert certificate's domain (when one can be read) override it.
	if len(domains) > 0 {
		ep.ShortLinkMatcher().SetDefaultDomainSuffix(domains[0])
	}
	if state.autocertProvider != nil {
		if certDomain := getAutoCertDefaultDomain(state.autocertProvider); certDomain != "" {
			ep.ShortLinkMatcher().SetDefaultDomainSuffix("." + certDomain)
		}
	}

	entrypointctx.SetCtx(state.task, ep)

	setupErrs := gperr.NewBuilder("entrypoint error")
	setupErrs.Add(ep.SetMiddlewares(epCfg.Middlewares))
	setupErrs.Add(ep.SetAccessLogger(state.task, epCfg.AccessLog))
	return setupErrs.Error()
}
// getAutoCertDefaultDomain loads the provider's certificate/key pair from disk
// and derives a lowercase domain from its leaf certificate: the subject
// CommonName, or the first DNS SAN when the CommonName is empty, with any
// surrounding whitespace and a leading "*." wildcard removed.
//
// It returns "" when p is nil or the certificate cannot be loaded or parsed.
func getAutoCertDefaultDomain(p *autocert.Provider) string {
	if p == nil {
		return ""
	}
	pair, err := tls.LoadX509KeyPair(p.GetCertPath(), p.GetKeyPath())
	if err != nil || len(pair.Certificate) == 0 {
		return ""
	}
	leaf, err := x509.ParseCertificate(pair.Certificate[0])
	if err != nil {
		return ""
	}

	name := leaf.Subject.CommonName
	if name == "" && len(leaf.DNSNames) != 0 {
		name = leaf.DNSNames[0]
	}
	name = strings.TrimPrefix(strings.TrimSpace(name), "*.")
	return strings.ToLower(name)
}
// initMaxMind sets up the MaxMind instance when the config provides a MaxMind
// section; it is a no-op otherwise.
func (state *state) initMaxMind() error {
	cfg := state.Providers.MaxMind
	if cfg == nil {
		return nil
	}
	return maxmind.SetInstance(state.task, cfg)
}
// initNotification starts a notification dispatcher under the state's task and
// registers every configured notification provider with it. Nothing is started
// when no providers are configured. It currently always returns nil.
func (state *state) initNotification() error {
	notifiers := state.Providers.Notification
	if len(notifiers) == 0 {
		return nil
	}

	dispatcher := notif.StartNotifDispatcher(state.task)
	for _, n := range notifiers {
		dispatcher.RegisterProvider(n)
	}
	return nil
}
// initAutoCert creates the autocert provider, obtains any missing certificates,
// schedules renewals under the state's task, prints certificate expiries, and
// finally stores the provider in state.autocertProvider and the task context.
// Any failure before that point is returned and the provider is not stored.
func (state *state) initAutoCert() error {
	cfg := state.AutoCert
	if cfg == nil {
		// No autocert section: use a zero config. Validate's error is ignored
		// on purpose. NOTE(review): presumably Validate fills in defaults here — confirm.
		cfg = new(autocert.Config)
		_ = cfg.Validate()
	}

	user, legoCfg, err := cfg.GetLegoConfig()
	if err != nil {
		return err
	}
	provider, err := autocert.NewProvider(cfg, user, legoCfg)
	if err != nil {
		return err
	}
	if err = provider.ObtainCertIfNotExistsAll(); err != nil {
		return err
	}

	provider.ScheduleRenewalAll(state.task)
	provider.PrintCertExpiriesAll()

	state.autocertProvider = provider
	autocertctx.SetCtx(state.task, provider)
	return nil
}
// initProxmox initializes every configured Proxmox entry concurrently using
// the state's task context. Each failure is prefixed with that entry's URL,
// and all failures are returned together.
func (state *state) initProxmox() error {
	clusters := state.Providers.Proxmox
	if len(clusters) == 0 {
		return nil
	}

	var initErrs gperr.Group
	for _, cluster := range clusters {
		initErrs.Go(func() error {
			err := cluster.Init(state.task.Context())
			if err == nil {
				return nil
			}
			return gperr.PrependSubject(err, cluster.URL)
		})
	}
	return initErrs.Wait().Error()
}
// loadRouteProviders builds the route providers described in state.Providers
// (agents, files, docker) and registers them in state.providers. It then loads
// all their routes concurrently and writes a summary, the per-provider route
// list and the active config to the temporary log.
//
// Failures are collected rather than returned eagerly. A failing agent, file
// provider, duplicate registration or route load is reported in the returned
// error while the remaining providers are still registered and loaded.
func (state *state) loadRouteProviders() error {
	providers := state.Providers
	errs := gperr.NewGroup("route provider errors")

	// Rebuild the agent pool from scratch; agents are added back below once
	// their Init succeeds.
	agentpool.RemoveAll()

	registerProvider := func(p types.RouteProvider) {
		if actual, loaded := state.providers.LoadOrStore(p.String(), p); loaded {
			errs.Addf("provider %s already exists, first: %s, second: %s", p.String(), actual.GetType(), p.GetType())
		}
	}

	agentErrs := gperr.NewGroup("agent init errors")
	for _, a := range providers.Agents {
		agentErrs.Go(func() error {
			if err := a.Init(state.task.Context()); err != nil {
				return gperr.PrependSubject(err, a.String())
			}
			agentpool.Add(a)
			return nil
		})
	}
	if err := agentErrs.Wait().Error(); err != nil {
		errs.Add(err)
	}

	// NOTE(review): every configured agent gets a provider, including agents
	// whose Init failed above — confirm this is intended.
	for _, a := range providers.Agents {
		registerProvider(route.NewAgentProvider(a))
	}
	for _, filename := range providers.Files {
		p, err := route.NewFileProvider(filename)
		if err != nil {
			// Record and move on, like every other failure here. Returning early
			// would discard the errors already collected and leave the docker
			// providers unregistered and no routes loaded at all.
			errs.Add(gperr.PrependSubject(err, filename))
			continue
		}
		registerProvider(p)
	}
	for name, dockerCfg := range providers.Docker {
		registerProvider(route.NewDockerProvider(name, dockerCfg))
	}

	// Width of the provider-name column in the summary output.
	lenLongestName := 0
	for k := range state.providers.Range {
		if len(k) > lenLongestName {
			lenLongestName = len(k)
		}
	}

	// load routes concurrently
	loadErrs := gperr.NewGroup("route load errors")
	results := gperr.NewBuilder("loaded route providers")
	resultsMu := sync.Mutex{} // results is appended from several goroutines
	for _, p := range state.providers.Range {
		loadErrs.Go(func() error {
			if err := p.LoadRoutes(); err != nil {
				return gperr.PrependSubject(err, p.String())
			}
			resultsMu.Lock()
			results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
			resultsMu.Unlock()
			return nil
		})
	}
	if err := loadErrs.Wait().Error(); err != nil {
		errs.Add(err)
	}

	state.tmpLog.Info().Msg(results.String())
	state.printRoutesByProvider(lenLongestName)
	state.printState()
	return errs.Wait().Error()
}
// printRoutesByProvider writes to the temporary log, for every registered
// provider, a header line with its route count. The header is followed by one
// line per non-excluded route, mapping the route's display name (plus its
// alias when different) to its target URL.
//
// lenLongestName is the length of the longest provider name. It fixes the
// width of the header column so all provider headers line up.
func (state *state) printRoutesByProvider(lenLongestName int) {
	var routeResults strings.Builder
	routeResults.Grow(4096) // more than enough
	routeResults.WriteString("routes by provider\n")

	// Header column width including room for the "> " prefix. It stays fixed
	// for every provider. The original code grew a single shared width while
	// scanning routes, so one provider's long route names shifted the layout
	// of every provider printed after it.
	headerWidth := lenLongestName + 2
	headerFormat := "> %-" + strconv.Itoa(headerWidth) + "s %d routes:\n"

	for _, p := range state.providers.Range {
		providerName := p.String()
		routeCount := p.NumRoutes()
		fmt.Fprintf(&routeResults, headerFormat, providerName, routeCount)
		if routeCount == 0 {
			continue
		}

		// Route column width for this provider only: at least the header
		// width, widened to fit its longest label (+3 for the " - " prefix).
		routeWidth := headerWidth
		for alias, r := range p.IterRoutes {
			if r.ShouldExclude() {
				continue
			}
			displayName := r.DisplayName()
			if displayName != alias {
				displayName = fmt.Sprintf("%s (%s)", displayName, alias)
			}
			routeWidth = max(routeWidth, len(displayName)+3)
		}

		routeFormat := " - %-" + strconv.Itoa(routeWidth-2) + "s -> %s\n"
		for alias, r := range p.IterRoutes {
			if r.ShouldExclude() {
				continue
			}
			displayName := r.DisplayName()
			if displayName != alias {
				displayName = fmt.Sprintf("%s (%s)", displayName, alias)
			}
			fmt.Fprintf(&routeResults, routeFormat, displayName, r.TargetURL().String())
		}
	}

	// The builder always contains at least the title line, so this always
	// logs, including for providers with no routes.
	state.tmpLog.Info().Msg(routeResults.String())
}
// printState writes the active configuration, marshaled as YAML, to the
// temporary log. If marshaling fails, the error is logged instead of an
// empty dump.
//
// NOTE(review): this dumps the whole config.Config. That may include secrets
// (e.g. API tokens in provider or autocert options) unless those fields are
// excluded from YAML marshaling — confirm before treating these logs as safe
// to share.
func (state *state) printState() {
	state.tmpLog.Info().Msg("active config:")
	yamlRepr, err := yaml.Marshal(state.Config)
	if err != nil {
		state.tmpLog.Error().Err(err).Msg("failed to marshal active config")
		return
	}
	state.tmpLog.Info().Msgf("%s", yamlRepr) // %s formats the []byte directly; no explicit string conversion
}