refactor: improved config initialization flow, add agent config

This commit is contained in:
yusing
2025-03-28 07:47:28 +08:00
parent fb8ce6c878
commit 84e8dc0e06
3 changed files with 180 additions and 42 deletions

View File

@@ -2,6 +2,7 @@ package config
import (
"context"
"errors"
"os"
"strconv"
"strings"
@@ -11,11 +12,11 @@ import (
"github.com/yusing/go-proxy/internal/api"
"github.com/yusing/go-proxy/internal/autocert"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config/types"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/entrypoint"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/http/server"
"github.com/yusing/go-proxy/internal/net/gphttp/server"
"github.com/yusing/go-proxy/internal/notif"
proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/task"
@@ -26,7 +27,7 @@ import (
)
type Config struct {
value *types.Config
value *config.Config
providers F.Map[string, *proxy.Provider]
autocertProvider *autocert.Provider
entrypoint *entrypoint.Entrypoint
@@ -35,7 +36,6 @@ type Config struct {
}
var (
instance *Config
cfgWatcher watcher.Watcher
reloadMu sync.Mutex
)
@@ -49,15 +49,11 @@ Make sure you rename it back before next time you start.`
You may run "ls-config" to show or dump the current config.`
)
var Validate = types.Validate
func GetInstance() *Config {
return instance
}
var Validate = config.Validate
func newConfig() *Config {
return &Config{
value: types.DefaultConfig(),
value: config.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](),
entrypoint: entrypoint.NewEntrypoint(),
task: task.RootTask("config", false),
@@ -65,16 +61,17 @@ func newConfig() *Config {
}
func Load() (*Config, gperr.Error) {
if instance != nil {
return instance, nil
if config.HasInstance() {
panic(errors.New("config already loaded"))
}
instance = newConfig()
cfg := newConfig()
config.SetInstance(cfg)
cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
return instance, instance.load()
return cfg, cfg.load()
}
func MatchDomains() []string {
return instance.value.MatchDomains
return config.GetInstance().Value().MatchDomains
}
func WatchChanges() {
@@ -122,22 +119,25 @@ func Reload() gperr.Error {
// cancel all current subtasks -> wait
// -> replace config -> start new subtasks
instance.task.Finish("config changed")
instance = newCfg
instance.Start(StartAllServers)
config.GetInstance().(*Config).Task().Finish("config changed")
newCfg.Start(StartAllServers)
config.SetInstance(newCfg)
return nil
}
func (cfg *Config) Value() *types.Config {
return instance.value
func (cfg *Config) Value() *config.Config {
return cfg.value
}
func (cfg *Config) Reload() gperr.Error {
return Reload()
}
// AutoCertProvider returns the autocert provider.
//
// If the autocert provider is not configured, it returns nil.
func (cfg *Config) AutoCertProvider() *autocert.Provider {
return instance.autocertProvider
return cfg.autocertProvider
}
func (cfg *Config) Task() *task.Task {
@@ -217,7 +217,7 @@ func (cfg *Config) load() gperr.Error {
gperr.LogFatal(errMsg, err)
}
model := types.DefaultConfig()
model := config.DefaultConfig()
if err := utils.DeserializeYAML(data, model); err != nil {
gperr.LogFatal(errMsg, err)
}
@@ -260,31 +260,65 @@ func (cfg *Config) initAutoCert(autocertCfg *autocert.AutocertConfig) (err gperr
return
}
func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
func (cfg *Config) errIfExists(p *proxy.Provider) gperr.Error {
if _, ok := cfg.providers.Load(p.String()); ok {
return gperr.Errorf("provider %s already exists", p.String())
}
return nil
}
lenLongestName := 0
func (cfg *Config) storeProvider(p *proxy.Provider) {
cfg.providers.Store(p.String(), p)
}
func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
errs := gperr.NewBuilder("route provider errors")
results := gperr.NewBuilder("loaded route providers")
removeAllAgents()
for _, agent := range providers.Agents {
if err := agent.Start(cfg.task); err != nil {
errs.Add(err.Subject(agent.String()))
continue
}
addAgent(agent)
p := proxy.NewAgentProvider(agent)
if err := cfg.errIfExists(p); err != nil {
errs.Add(err.Subject(p.String()))
continue
}
cfg.storeProvider(p)
}
for _, filename := range providers.Files {
p, err := proxy.NewFileProvider(filename)
if err == nil {
err = cfg.errIfExists(p)
}
if err != nil {
errs.Add(E.PrependSubject(filename, err))
errs.Add(gperr.PrependSubject(filename, err))
continue
}
cfg.providers.Store(p.String(), p)
if len(p.String()) > lenLongestName {
lenLongestName = len(p.String())
}
cfg.storeProvider(p)
}
for name, dockerHost := range providers.Docker {
p, err := proxy.NewDockerProvider(name, dockerHost)
if err != nil {
errs.Add(E.PrependSubject(name, err))
p := proxy.NewDockerProvider(name, dockerHost)
if err := cfg.errIfExists(p); err != nil {
errs.Add(err.Subject(p.String()))
continue
}
cfg.providers.Store(p.String(), p)
if len(p.String()) > lenLongestName {
lenLongestName = len(p.String())
}
cfg.storeProvider(p)
}
if cfg.providers.Size() == 0 {
return nil
}
lenLongestName := 0
cfg.providers.RangeAll(func(k string, _ *proxy.Provider) {
if len(k) > lenLongestName {
lenLongestName = len(k)
}
})
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
if err := p.LoadRoutes(); err != nil {
errs.Add(err.Subject(p.String()))