refactor(config): parallelize route provider initialization

This commit is contained in:
yusing
2025-09-13 23:25:29 +08:00
parent 5e1da915dc
commit 60c13a797b

View File

@@ -9,6 +9,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
agentPkg "github.com/yusing/go-proxy/agent/pkg/agent" agentPkg "github.com/yusing/go-proxy/agent/pkg/agent"
@@ -25,7 +26,6 @@ import (
proxy "github.com/yusing/go-proxy/internal/route/provider" proxy "github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/serialization" "github.com/yusing/go-proxy/internal/serialization"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
F "github.com/yusing/go-proxy/internal/utils/functional"
"github.com/yusing/go-proxy/internal/utils/strutils/ansi" "github.com/yusing/go-proxy/internal/utils/strutils/ansi"
"github.com/yusing/go-proxy/internal/watcher" "github.com/yusing/go-proxy/internal/watcher"
"github.com/yusing/go-proxy/internal/watcher/events" "github.com/yusing/go-proxy/internal/watcher/events"
@@ -33,7 +33,7 @@ import (
type Config struct { type Config struct {
value *config.Config value *config.Config
providers F.Map[string, *proxy.Provider] providers *xsync.Map[string, *proxy.Provider]
autocertProvider *autocert.Provider autocertProvider *autocert.Provider
entrypoint *entrypoint.Entrypoint entrypoint *entrypoint.Entrypoint
@@ -59,7 +59,7 @@ var Validate = config.Validate
func newConfig() *Config { func newConfig() *Config {
return &Config{ return &Config{
value: config.DefaultConfig(), value: config.DefaultConfig(),
providers: F.NewMapOf[string, *proxy.Provider](), providers: xsync.NewMap[string, *proxy.Provider](),
entrypoint: entrypoint.NewEntrypoint(), entrypoint: entrypoint.NewEntrypoint(),
task: task.RootTask("config", false), task: task.RootTask("config", false),
} }
@@ -174,12 +174,19 @@ func (cfg *Config) StartAutoCert() {
} }
func (cfg *Config) StartProxyProviders() { func (cfg *Config) StartProxyProviders() {
errs := cfg.providers.CollectErrors( var wg sync.WaitGroup
func(_ string, p *proxy.Provider) error {
return p.Start(cfg.task)
})
if err := gperr.Join(errs...); err != nil { errs := gperr.NewBuilderWithConcurrency()
for _, p := range cfg.providers.Range {
wg.Go(func() {
if err := p.Start(cfg.task); err != nil {
errs.Add(err.Subject(p.String()))
}
})
}
wg.Wait()
if err := errs.Error(); err != nil {
gperr.LogError("route provider errors", err) gperr.LogError("route provider errors", err)
} }
} }
@@ -315,72 +322,87 @@ func (cfg *Config) initProxmox(proxmoxCfg []proxmox.Config) gperr.Error {
return errs.Error() return errs.Error()
} }
func (cfg *Config) errIfExists(p *proxy.Provider) gperr.Error {
if _, ok := cfg.providers.Load(p.String()); ok {
return gperr.Errorf("provider %s already exists", p.String())
}
return nil
}
func (cfg *Config) storeProvider(p *proxy.Provider) { func (cfg *Config) storeProvider(p *proxy.Provider) {
cfg.providers.Store(p.String(), p) cfg.providers.Store(p.String(), p)
} }
func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error { func (cfg *Config) loadRouteProviders(providers *config.Providers) gperr.Error {
errs := gperr.NewBuilder("route provider errors") errs := gperr.NewBuilderWithConcurrency("route provider errors")
results := gperr.NewBuilder("loaded route providers") results := gperr.NewBuilder("loaded route providers")
agentPkg.RemoveAllAgents() agentPkg.RemoveAllAgents()
numProviders := len(providers.Agents) + len(providers.Files) + len(providers.Docker)
providersCh := make(chan *proxy.Provider, numProviders)
// start providers concurrently
var providersConsumer sync.WaitGroup
providersConsumer.Go(func() {
for p := range providersCh {
if actual, loaded := cfg.providers.LoadOrStore(p.String(), p); loaded {
errs.Add(gperr.Errorf("provider %s already exists, first: %s, second: %s", p.String(), actual.GetType(), p.GetType()))
continue
}
cfg.storeProvider(p)
}
})
var providersProducer sync.WaitGroup
for _, agent := range providers.Agents { for _, agent := range providers.Agents {
if err := agent.Start(cfg.task.Context()); err != nil { providersProducer.Go(func() {
errs.Add(gperr.PrependSubject(agent.String(), err)) if err := agent.Start(cfg.task.Context()); err != nil {
continue errs.Add(gperr.PrependSubject(agent.String(), err))
} return
agentPkg.AddAgent(agent) }
p := proxy.NewAgentProvider(agent) agentPkg.AddAgent(agent)
if err := cfg.errIfExists(p); err != nil { p := proxy.NewAgentProvider(agent)
errs.Add(err.Subject(p.String())) providersCh <- p
continue })
}
cfg.storeProvider(p)
}
for _, filename := range providers.Files {
p, err := proxy.NewFileProvider(filename)
if err == nil {
err = cfg.errIfExists(p)
}
if err != nil {
errs.Add(gperr.PrependSubject(filename, err))
continue
}
cfg.storeProvider(p)
}
for name, dockerHost := range providers.Docker {
p := proxy.NewDockerProvider(name, dockerHost)
if err := cfg.errIfExists(p); err != nil {
errs.Add(err.Subject(p.String()))
continue
}
cfg.storeProvider(p)
}
if cfg.providers.Size() == 0 {
return nil
} }
for _, filename := range providers.Files {
providersProducer.Go(func() {
p, err := proxy.NewFileProvider(filename)
if err != nil {
errs.Add(gperr.PrependSubject(filename, err))
} else {
providersCh <- p
}
})
}
for name, dockerHost := range providers.Docker {
providersProducer.Go(func() {
providersCh <- proxy.NewDockerProvider(name, dockerHost)
})
}
providersProducer.Wait()
close(providersCh)
providersConsumer.Wait()
lenLongestName := 0 lenLongestName := 0
cfg.providers.RangeAll(func(k string, _ *proxy.Provider) { for k := range cfg.providers.Range {
if len(k) > lenLongestName { if len(k) > lenLongestName {
lenLongestName = len(k) lenLongestName = len(k)
} }
}) }
results.EnableConcurrency() results.EnableConcurrency()
cfg.providers.RangeAllParallel(func(_ string, p *proxy.Provider) {
if err := p.LoadRoutes(); err != nil { // load routes concurrently
errs.Add(err.Subject(p.String())) var providersLoader sync.WaitGroup
} for _, p := range cfg.providers.Range {
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes()) providersLoader.Go(func() {
}) if err := p.LoadRoutes(); err != nil {
errs.Add(err.Subject(p.String()))
}
results.Addf("%-"+strconv.Itoa(lenLongestName)+"s %d routes", p.String(), p.NumRoutes())
})
}
providersLoader.Wait()
log.Info().Msg(results.String()) log.Info().Msg(results.String())
return errs.Error() return errs.Error()
} }