refactor(agents): enhance VerifyNewAgent

This commit is contained in:
yusing
2025-09-13 23:24:43 +08:00
parent 3288624cf2
commit 5e1da915dc

View File

@@ -1,18 +1,16 @@
package config package config
import ( import (
"slices"
"github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/route/provider" "github.com/yusing/go-proxy/internal/route/provider"
) )
func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) { func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) {
if slices.ContainsFunc(cfg.value.Providers.Agents, func(a *agent.AgentConfig) bool { for _, a := range cfg.value.Providers.Agents {
return a.Addr == host if a.Addr == host {
}) { return 0, gperr.New("agent already exists")
return 0, gperr.New("agent already exists") }
} }
var agentCfg agent.AgentConfig var agentCfg agent.AgentConfig
@@ -21,17 +19,20 @@ func (cfg *Config) VerifyNewAgent(host string, ca agent.PEMPair, client agent.PE
if err != nil { if err != nil {
return 0, gperr.Wrap(err, "failed to start agent") return 0, gperr.Wrap(err, "failed to start agent")
} }
agent.AddAgent(&agentCfg)
provider := provider.NewAgentProvider(&agentCfg) provider := provider.NewAgentProvider(&agentCfg)
if err := cfg.errIfExists(provider); err != nil { if _, loaded := cfg.providers.LoadOrStore(provider.String(), provider); loaded {
agent.RemoveAgent(&agentCfg) return 0, gperr.Errorf("provider %s already exists", provider.String())
return 0, err
} }
// agent must be added before loading routes
agent.AddAgent(&agentCfg)
err = provider.LoadRoutes() err = provider.LoadRoutes()
if err != nil { if err != nil {
cfg.providers.Delete(provider.String())
agent.RemoveAgent(&agentCfg) agent.RemoveAgent(&agentCfg)
return 0, gperr.Wrap(err, "failed to load routes") return 0, gperr.Wrap(err, "failed to load routes")
} }
return provider.NumRoutes(), nil return provider.NumRoutes(), nil
} }