diff --git a/internal/api/v1/checkhealth.go b/internal/api/v1/checkhealth.go index abd659e3..94534f71 100644 --- a/internal/api/v1/checkhealth.go +++ b/internal/api/v1/checkhealth.go @@ -1,13 +1,11 @@ package v1 import ( - "fmt" "net/http" - "strings" . "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/config" - R "github.com/yusing/go-proxy/internal/route" + "github.com/yusing/go-proxy/internal/watcher/health" ) func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { @@ -17,26 +15,14 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { return } - var ok bool - route := cfg.FindRoute(target) - - switch { - case route == nil: + isHealthy, ok := health.IsHealthy(target) + if !ok { HandleErr(w, r, ErrNotFound("target", target), http.StatusNotFound) return - case route.Type() == R.RouteTypeReverseProxy: - ok = IsSiteHealthy(route.URL().String()) - case route.Type() == R.RouteTypeStream: - entry := route.Entry() - ok = IsStreamHealthy( - strings.Split(entry.Scheme, ":")[1], // target scheme - fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]), - ) } - - if ok { + if isHealthy { w.WriteHeader(http.StatusOK) } else { - w.WriteHeader(http.StatusRequestTimeout) + w.WriteHeader(http.StatusServiceUnavailable) } } diff --git a/internal/api/v1/health_check.go b/internal/api/v1/health_check.go deleted file mode 100644 index 20825d65..00000000 --- a/internal/api/v1/health_check.go +++ /dev/null @@ -1,34 +0,0 @@ -package v1 - -import ( - "net" - "net/http" - - U "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/common" -) - -func IsSiteHealthy(url string) bool { - // try HEAD first - // if HEAD is not allowed, try GET - resp, err := U.Head(url) - if resp != nil { - resp.Body.Close() - } - if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed { - _, err = U.Get(url) - } - if resp != nil { - resp.Body.Close() - } - return err == nil -} - -func IsStreamHealthy(scheme, address string) bool { - conn, err := net.DialTimeout(scheme, address, common.DialTimeout) - if err != nil { - return false - } - conn.Close() - return true -} diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index be95da03..ca7380df 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -233,7 +233,7 @@ func (p *Provider) certState() CertState { sort.Strings(certDomains) if !reflect.DeepEqual(certDomains, wantedDomains) { - logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) + logger.Infof("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) return CertStateMismatch } diff --git a/internal/common/constants.go b/internal/common/constants.go index abba4f04..4e863cf7 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -19,20 +19,14 @@ const ( ConfigPath = ConfigBasePath + "/" + ConfigFileName MiddlewareComposeBasePath = ConfigBasePath + "/middlewares" -) -const ( SchemaBasePath = "schema" ConfigSchemaPath = SchemaBasePath + "/config.schema.json" FileProviderSchemaPath = SchemaBasePath + "/providers.schema.json" -) -const ( ComposeFileName = "compose.yml" ComposeExampleFileName = "compose.example.yml" -) -const ( ErrorPagesBasePath = "error_pages" ) @@ -46,6 +40,9 @@ var RequiredDirectories = []string{ const DockerHostFromEnv = "$DOCKER_HOST" const ( + HealthCheckIntervalDefault = 5 * time.Second + HealthCheckTimeoutDefault = 5 * time.Second + IdleTimeoutDefault = "0" WakeTimeoutDefault = "30s" StopTimeoutDefault = "10s" diff --git a/internal/config/config.go b/internal/config/config.go index 9f40b94b..242fe251 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -103,7 +103,7 @@ func (cfg *Config) WatchChanges() { case <-cfg.watcherCtx.Done(): return case <-cfg.reloadReq: - if err := cfg.Reload(); err.HasError() { + if err := cfg.Reload(); err != nil { cfg.l.Error(err) } } @@ -130,9 +130,9 @@ func (cfg *Config) WatchChanges() { }() } -func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) { +func (cfg *Config) forEachRoute(do func(alias string, r *R.Route, p *PR.Provider)) { cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) { - p.RangeRoutes(func(a string, r R.Route) { + p.RangeRoutes(func(a string, r *R.Route) { do(a, r, p) }) }) @@ -146,20 +146,20 @@ func (cfg *Config) load() (res E.NestedError) { defer cfg.l.Debug("loaded config") data, err := E.Check(os.ReadFile(common.ConfigPath)) - if err.HasError() { + if err != nil { b.Add(E.FailWith("read config", err)) logrus.Fatal(b.Build()) } if !common.NoSchemaValidation { - if err = Validate(data); err.HasError() { + if err = Validate(data); err != nil { b.Add(E.FailWith("schema validation", err)) logrus.Fatal(b.Build()) } } model := types.DefaultConfig() - if err := E.From(yaml.Unmarshal(data, model)); err.HasError() { + if err := E.From(yaml.Unmarshal(data, model)); err != nil { b.Add(E.FailWith("parse config", err)) logrus.Fatal(b.Build()) } @@ -182,7 +182,7 @@ func (cfg *Config) initAutoCert(autocertCfg *types.AutoCertConfig) (err E.Nested defer cfg.l.Debug("initialized autocert") cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() - if err.HasError() { + if err != nil { err = E.FailWith("autocert provider", err) } return @@ -220,12 +220,12 @@ func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.Neste errors := E.NewBuilder("errors in %s these providers", action) cfg.proxyProviders.RangeAllParallel(func(name string, p *PR.Provider) { - if err := do(p); err.HasError() { + if err := do(p); err != nil { errors.Add(err.Subject(p)) } }) - if err := errors.Build(); err.HasError() { + if err := errors.Build(); err != nil { cfg.l.Error(err) } } diff --git a/internal/config/query.go b/internal/config/query.go index 17741501..2f275c2b 100644 --- a/internal/config/query.go +++ b/internal/config/query.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/yusing/go-proxy/internal/common" - H "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/homepage" PR "github.com/yusing/go-proxy/internal/proxy/provider" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/types" @@ -15,8 +15,8 @@ import ( func (cfg *Config) DumpEntries() map[string]*types.RawEntry { entries := make(map[string]*types.RawEntry) - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { - entries[alias] = r.Entry() + cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { + entries[alias] = r.Entry }) return entries } @@ -29,7 +29,7 @@ func (cfg *Config) DumpProviders() map[string]*PR.Provider { return entries } -func (cfg *Config) HomepageConfig() H.HomePageConfig { +func (cfg *Config) HomepageConfig() homepage.Config { var proto, port string domains := cfg.value.MatchDomains cert, _ := cfg.autocertProvider.GetCert(nil) @@ -41,16 +41,16 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig { port = common.ProxyHTTPPort } - hpCfg := H.NewHomePageConfig() - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + hpCfg := homepage.NewHomePageConfig() + cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { if !r.Started() { return } - entry := r.Entry() + entry := r.Entry if entry.Homepage == nil { - entry.Homepage = &H.HomePageItem{ - Show: r.Entry().IsExplicit || !p.IsExplicitOnly(), + entry.Homepage = &homepage.Item{ + Show: r.Entry.IsExplicit || !p.IsExplicitOnly(), } } @@ -60,7 +60,7 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig { item.Show = true } - if !item.Show || r.Type() != R.RouteTypeReverseProxy { + if !item.Show || r.Type != R.RouteTypeReverseProxy { return } @@ -99,19 +99,19 @@ func (cfg *Config) HomepageConfig() H.HomePageConfig { func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { routes := make(map[string]U.SerializedObject) - cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + cfg.forEachRoute(func(alias string, r *R.Route, p *PR.Provider) { if !r.Started() { return } obj, err := U.Serialize(r) - if err.HasError() { + if err != nil { cfg.l.Error(err) return } obj["provider"] = p.GetName() - obj["type"] = string(r.Type()) + obj["type"] = string(r.Type) obj["started"] = r.Started() - obj["raw"] = r.Entry() + obj["raw"] = r.Entry routes[alias] = obj }) return routes @@ -138,9 +138,9 @@ func (cfg *Config) Statistics() map[string]any { } } -func (cfg *Config) FindRoute(alias string) R.Route { +func (cfg *Config) FindRoute(alias string) *R.Route { return F.MapFind(cfg.proxyProviders, - func(p *PR.Provider) (R.Route, bool) { + func(p *PR.Provider) (*R.Route, bool) { if route, ok := p.GetRoute(alias); ok { return route, true } diff --git a/internal/docker/label.go b/internal/docker/label.go index 547186e8..ff454d01 100644 --- a/internal/docker/label.go +++ b/internal/docker/label.go @@ -105,7 +105,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { default: l.Attribute = parts[2] nestedLabel, err := ParseLabel(strings.Join(parts[3:], "."), value) - if err.HasError() { + if err != nil { return nil, err } l.Value = nestedLabel diff --git a/internal/error/builder.go b/internal/error/builder.go index 1fc8a63c..e8e849d9 100644 --- a/internal/error/builder.go +++ b/internal/error/builder.go @@ -40,6 +40,8 @@ func (b Builder) Addf(format string, args ...any) Builder { } func (b Builder) AddRangeE(errs ...error) Builder { + b.Lock() + defer b.Unlock() for _, err := range errs { b.AddE(err) } diff --git a/internal/error/error_test.go b/internal/error/error_test.go index 22eb1c29..864d3130 100644 --- a/internal/error/error_test.go +++ b/internal/error/error_test.go @@ -42,7 +42,6 @@ func TestErrorNestedIs(t *testing.T) { func TestIsNil(t *testing.T) { var err NestedError ExpectTrue(t, err.Is(nil)) - ExpectFalse(t, err.HasError()) ExpectTrue(t, err == nil) ExpectTrue(t, err.NoError()) diff --git a/internal/homepage/homepage.go b/internal/homepage/homepage.go index 1c9a7b5d..73c6deeb 100644 --- a/internal/homepage/homepage.go +++ b/internal/homepage/homepage.go @@ -1,24 +1,24 @@ package homepage type ( - HomePageConfig map[string]HomePageCategory - HomePageCategory []*HomePageItem + Config map[string]Category + Category []*Item - HomePageItem struct { - Show bool `yaml:"show" json:"show"` - Name string `yaml:"name" json:"name"` - Icon string `yaml:"icon" json:"icon"` - URL string `yaml:"url" json:"url"` // alias + domain - Category string `yaml:"category" json:"category"` - Description string `yaml:"description" json:"description"` - WidgetConfig map[string]any `yaml:",flow" json:"widget_config"` + Item struct { + Show bool `json:"show" yaml:"show"` + Name string `json:"name" yaml:"name"` + Icon string `json:"icon" yaml:"icon"` + URL string `json:"url" yaml:"url"` // alias + domain + Category string `json:"category" yaml:"category"` + Description string `json:"description" yaml:"description"` + WidgetConfig map[string]any `json:"widget_config" yaml:",flow"` - SourceType string `yaml:"-" json:"source_type"` - AltURL string `yaml:"-" json:"alt_url"` // original proxy target + SourceType string `json:"source_type" yaml:"-"` + AltURL string `json:"alt_url" yaml:"-"` // original proxy target } ) -func (item *HomePageItem) IsEmpty() bool { +func (item *Item) IsEmpty() bool { return item == nil || (item.Name == "" && item.Icon == "" && item.URL == "" && @@ -27,17 +27,17 @@ func (item *HomePageItem) IsEmpty() bool { len(item.WidgetConfig) == 0) } -func NewHomePageConfig() HomePageConfig { - return HomePageConfig(make(map[string]HomePageCategory)) +func NewHomePageConfig() Config { + return Config(make(map[string]Category)) } -func (c *HomePageConfig) Clear() { - *c = make(HomePageConfig) +func (c *Config) Clear() { + *c = make(Config) } -func (c HomePageConfig) Add(item *HomePageItem) { +func (c Config) Add(item *Item) { if c[item.Category] == nil { - c[item.Category] = make(HomePageCategory, 0) + c[item.Category] = make(Category, 0) } c[item.Category] = append(c[item.Category], item) } diff --git a/internal/net/http/loadbalancer/ip_hash.go b/internal/net/http/loadbalancer/ip_hash.go index 8223516b..a48caf76 100644 --- a/internal/net/http/loadbalancer/ip_hash.go +++ b/internal/net/http/loadbalancer/ip_hash.go @@ -4,15 +4,40 @@ import ( "hash/fnv" "net" "net/http" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/middleware" ) -type ipHash struct{ *LoadBalancer } +type ipHash struct { + *LoadBalancer + realIP *middleware.Middleware +} -func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} } +func (lb *LoadBalancer) newIPHash() impl { + impl := &ipHash{LoadBalancer: lb} + if len(lb.Options) == 0 { + return impl + } + var err E.NestedError + impl.realIP, err = middleware.NewRealIP(lb.Options) + if err != nil { + logger.Errorf("loadbalancer %s invalid real_ip options: %s, ignoring", lb.Link, err) + } + return impl +} func (ipHash) OnAddServer(srv *Server) {} func (ipHash) OnRemoveServer(srv *Server) {} func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) { + if impl.realIP != nil { + impl.realIP.ModifyRequest(impl.serveHTTP, rw, r) + } else { + impl.serveHTTP(rw, r) + } +} + +func (impl ipHash) serveHTTP(rw http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { http.Error(rw, "Internal error", http.StatusInternalServerError) @@ -20,7 +45,7 @@ func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) return } idx := hashIP(ip) % uint32(len(impl.pool)) - if !impl.pool[idx].available.Load() { + if !impl.pool[idx].IsHealthy() { http.Error(rw, "Service unavailable", http.StatusServiceUnavailable) } impl.pool[idx].handler.ServeHTTP(rw, r) diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index c2984365..c6b82a36 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -1,13 +1,12 @@ package loadbalancer import ( - "context" "net/http" "sync" - "time" "github.com/go-acme/lego/v4/log" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/middleware" ) // TODO: stats of each server. @@ -19,20 +18,17 @@ type ( OnRemoveServer(srv *Server) } Config struct { - Link string - Mode Mode - Weight weightType + Link string `json:"link" yaml:"link"` + Mode Mode `json:"mode" yaml:"mode"` + Weight weightType `json:"weight" yaml:"weight"` + Options middleware.OptionsRaw `json:"options,omitempty" yaml:"options,omitempty"` } LoadBalancer struct { impl Config pool servers - poolMu sync.RWMutex - - ctx context.Context - cancel context.CancelFunc - done chan struct{} + poolMu sync.Mutex sumWeight weightType } @@ -73,8 +69,8 @@ func (lb *LoadBalancer) AddServer(srv *Server) { } func (lb *LoadBalancer) RemoveServer(srv *Server) { - lb.poolMu.RLock() - defer lb.poolMu.RUnlock() + lb.poolMu.Lock() + defer lb.poolMu.Unlock() lb.impl.OnRemoveServer(srv) @@ -85,7 +81,7 @@ func (lb *LoadBalancer) RemoveServer(srv *Server) { } } if lb.IsEmpty() { - lb.Stop() + lb.pool = nil return } @@ -171,54 +167,12 @@ func (lb *LoadBalancer) Start() { if lb.sumWeight != 0 { log.Warnf("weighted mode not supported yet") } - - lb.done = make(chan struct{}, 1) - lb.ctx, lb.cancel = context.WithCancel(context.Background()) - - updateAll := func() { - lb.poolMu.Lock() - defer lb.poolMu.Unlock() - - var wg sync.WaitGroup - wg.Add(len(lb.pool)) - for _, s := range lb.pool { - go func(s *Server) { - defer wg.Done() - s.checkUpdateAvail(lb.ctx) - }(s) - } - wg.Wait() - } - logger.Debugf("loadbalancer %s started", lb.Link) - - go func() { - defer lb.cancel() - defer close(lb.done) - - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - updateAll() - for { - select { - case <-lb.ctx.Done(): - return - case <-ticker.C: - updateAll() - } - } - }() } func (lb *LoadBalancer) Stop() { - if lb.cancel == nil { - return - } - - lb.cancel() - - <-lb.done + lb.poolMu.Lock() + defer lb.poolMu.Unlock() lb.pool = nil logger.Debugf("loadbalancer %s stopped", lb.Link) @@ -228,9 +182,9 @@ func (lb *LoadBalancer) availServers() servers { lb.poolMu.Lock() defer lb.poolMu.Unlock() - avail := servers{} + avail := make(servers, 0, len(lb.pool)) for _, s := range lb.pool { - if s.available.Load() { + if s.IsHealthy() { avail = append(avail, s) } } diff --git a/internal/net/http/loadbalancer/server.go b/internal/net/http/loadbalancer/server.go index 52fff167..8376a529 100644 --- a/internal/net/http/loadbalancer/server.go +++ b/internal/net/http/loadbalancer/server.go @@ -1,67 +1,42 @@ package loadbalancer import ( - "context" "net/http" - "sync/atomic" - "time" "github.com/yusing/go-proxy/internal/net/types" + U "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/watcher/health" ) type ( Server struct { - Name string - URL types.URL - Weight weightType - handler http.Handler + _ U.NoCopy - pinger *http.Client - available atomic.Bool + Name string + URL types.URL + Weight weightType + + handler http.Handler + healthMon health.HealthMonitor } servers []*Server ) -func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server { +func NewServer(name string, url types.URL, weight weightType, handler http.Handler, healthMon health.HealthMonitor) *Server { srv := &Server{ - Name: name, - URL: url, - Weight: weight, - handler: handler, - pinger: &http.Client{Timeout: 3 * time.Second}, + Name: name, + URL: url, + Weight: weight, + handler: handler, + healthMon: healthMon, } - srv.available.Store(true) return srv } -func (srv *Server) checkUpdateAvail(ctx context.Context) { - req, err := http.NewRequestWithContext( - ctx, - http.MethodHead, - srv.URL.String(), - nil, - ) - if err != nil { - logger.Error("failed to create request: ", err) - srv.available.Store(false) - } - - resp, err := srv.pinger.Do(req) - if err == nil && resp.StatusCode != http.StatusServiceUnavailable { - if !srv.available.Swap(true) { - logger.Infof("server %s is up", srv.Name) - } - } else if err != nil { - if srv.available.Swap(false) { - logger.Warnf("server %s is down: %s", srv.Name, err) - } - } else { - if srv.available.Swap(false) { - logger.Warnf("server %s is down: status %s", srv.Name, resp.Status) - } - } -} - func (srv *Server) String() string { return srv.Name } + +func (srv *Server) IsHealthy() bool { + return srv.healthMon.IsHealthy() +} diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index beaf1b18..502d18a2 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -30,6 +30,8 @@ type ( Options any Middleware struct { + _ U.NoCopy + name string before BeforeFunc // runs before ReverseProxy.ServeHTTP @@ -77,30 +79,37 @@ func (m *Middleware) MarshalJSON() ([]byte, error) { func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) { if len(optsRaw) != 0 && m.withOptions != nil { - if mWithOpt, err := m.withOptions(optsRaw); err != nil { - return nil, err - } else { - return mWithOpt, nil - } + return m.withOptions(optsRaw) } // WithOptionsClone is called only once // set withOptions and labelParser will not be used after that return &Middleware{ - m.name, - m.before, - m.modifyResponse, - nil, - m.impl, - m.parent, - m.children, - false, + name: m.name, + before: m.before, + modifyResponse: m.modifyResponse, + impl: m.impl, + parent: m.parent, + children: m.children, }, nil } -// TODO: check conflict or duplicates -func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) { - middlewares := make([]*Middleware, 0, len(middlewaresMap)) +func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) { + if m.before != nil { + m.before(next, w, r) + } +} + +func (m *Middleware) ModifyResponse(resp *Response) error { + if m.modifyResponse != nil { + return m.modifyResponse(resp) + } + return nil +} + +// TODO: check conflict or duplicates. +func createMiddlewares(middlewaresMap map[string]OptionsRaw) (middlewares []*Middleware, res E.NestedError) { + middlewares = make([]*Middleware, 0, len(middlewaresMap)) invalidM := E.NewBuilder("invalid middlewares") invalidOpts := E.NewBuilder("invalid options") @@ -124,10 +133,15 @@ func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[strin middlewares = append(middlewares, m) } - if invalidM.HasError() { + return +} + +func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.NestedError) { + var middlewares []*Middleware + middlewares, err = createMiddlewares(middlewaresMap) + if err != nil { return } - patchReverseProxy(rpName, rp, middlewares) return } diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index ba80b816..4f21c777 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -114,7 +114,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N } else { proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect } - rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), &rr) + rp := gphttp.NewReverseProxy("test", types.NewURL(proxyURL), &rr) mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr diff --git a/internal/net/http/reverse_proxy_mod.go b/internal/net/http/reverse_proxy_mod.go index ca242fa1..b3afc79f 100644 --- a/internal/net/http/reverse_proxy_mod.go +++ b/internal/net/http/reverse_proxy_mod.go @@ -86,7 +86,8 @@ type ReverseProxy struct { ServeHTTP http.HandlerFunc - TargetURL types.URL + TargetName string + TargetURL types.URL } func singleJoiningSlash(a, b string) string { @@ -144,11 +145,11 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // } // -func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy { +func NewReverseProxy(name string, target types.URL, transport http.RoundTripper) *ReverseProxy { if transport == nil { panic("nil transport") } - rp := &ReverseProxy{Transport: transport, TargetURL: target} + rp := &ReverseProxy{Transport: transport, TargetName: name, TargetURL: target} rp.ServeHTTP = rp.serveHTTP return rp } @@ -194,9 +195,9 @@ func (p *ReverseProxy) errorHandler(rw http.ResponseWriter, r *http.Request, err switch { case errors.Is(err, context.Canceled), errors.Is(err, io.EOF): - logger.Debugf("http proxy to %s error: %s", r.URL.String(), err) + logger.Debugf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) default: - logger.Errorf("http proxy to %s error: %s", r.URL.String(), err) + logger.Errorf("http proxy to %s(%s) error: %s", p.TargetName, r.URL.String(), err) } if writeHeader { rw.WriteHeader(http.StatusBadGateway) diff --git a/internal/net/types/cidr.go b/internal/net/types/cidr.go index e3ab23fd..13177573 100644 --- a/internal/net/types/cidr.go +++ b/internal/net/types/cidr.go @@ -9,20 +9,21 @@ import ( type CIDR net.IPNet -func (*CIDR) ConvertFrom(val any) (any, E.NestedError) { - cidr, ok := val.(string) +func (cidr *CIDR) ConvertFrom(val any) E.NestedError { + cidrStr, ok := val.(string) if !ok { - return nil, E.TypeMismatch[string](val) + return E.TypeMismatch[string](val) } - if !strings.Contains(cidr, "/") { - cidr += "/32" // single IP + if !strings.Contains(cidrStr, "/") { + cidrStr += "/32" // single IP } - _, ipnet, err := net.ParseCIDR(cidr) + _, ipnet, err := net.ParseCIDR(cidrStr) if err != nil { - return nil, E.Invalid("CIDR", cidr) + return E.Invalid("CIDR", cidr) } - return (*CIDR)(ipnet), nil + *cidr = CIDR(*ipnet) + return nil } func (cidr *CIDR) Contains(ip net.IP) bool { diff --git a/internal/net/types/url.go b/internal/net/types/url.go index 065b0ba3..e0d09cba 100644 --- a/internal/net/types/url.go +++ b/internal/net/types/url.go @@ -1,10 +1,22 @@ package types -import "net/url" +import ( + urlPkg "net/url" +) -type URL struct{ *url.URL } +type URL struct { + *urlPkg.URL +} -func NewURL(url *url.URL) URL { +func ParseURL(url string) (URL, error) { + u, err := urlPkg.Parse(url) + if err != nil { + return URL{}, err + } + return URL{URL: u}, nil +} + +func NewURL(url *urlPkg.URL) URL { return URL{url} } @@ -19,6 +31,10 @@ func (u URL) MarshalText() (text []byte, err error) { return []byte(u.String()), nil } -func (u URL) Equals(other URL) bool { +func (u URL) Equals(other *URL) bool { return u.URL == other.URL || u.String() == other.String() } + +func (u URL) JoinPath(path string) URL { + return URL{u.URL.JoinPath(path)} +} diff --git a/internal/proxy/entry.go b/internal/proxy/entry.go index 803fdfe5..d3ce683d 100644 --- a/internal/proxy/entry.go +++ b/internal/proxy/entry.go @@ -11,17 +11,19 @@ import ( net "github.com/yusing/go-proxy/internal/net/types" T "github.com/yusing/go-proxy/internal/proxy/fields" "github.com/yusing/go-proxy/internal/types" + "github.com/yusing/go-proxy/internal/watcher/health" ) type ( ReverseProxyEntry struct { // real model after validation - Alias T.Alias `json:"alias"` - Scheme T.Scheme `json:"scheme"` - URL net.URL `json:"url"` - NoTLSVerify bool `json:"no_tls_verify"` - PathPatterns T.PathPatterns `json:"path_patterns"` - LoadBalance loadbalancer.Config `json:"load_balance"` - Middlewares D.NestedLabelMap `json:"middlewares"` + Alias T.Alias `json:"alias"` + Scheme T.Scheme `json:"scheme"` + URL net.URL `json:"url"` + NoTLSVerify bool `json:"no_tls_verify"` + PathPatterns T.PathPatterns `json:"path_patterns"` + HealthCheck health.HealthCheckConfig `json:"healthcheck"` + LoadBalance loadbalancer.Config `json:"load_balance"` + Middlewares D.NestedLabelMap `json:"middlewares"` /* Docker only */ IdleTimeout time.Duration `json:"idle_timeout"` @@ -35,10 +37,11 @@ type ( ContainerRunning bool `json:"container_running"` } StreamEntry struct { - Alias T.Alias `json:"alias"` - Scheme T.StreamScheme `json:"scheme"` - Host T.Host `json:"host"` - Port T.StreamPort `json:"port"` + Alias T.Alias `json:"alias"` + Scheme T.StreamScheme `json:"scheme"` + Host T.Host `json:"host"` + Port T.StreamPort `json:"port"` + Healthcheck health.HealthCheckConfig `json:"healthcheck"` } ) @@ -58,7 +61,7 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) { m.FillMissingFields() scheme, err := T.NewScheme(m.Scheme) - if err.HasError() { + if err != nil { return nil, err } @@ -69,7 +72,7 @@ func ValidateEntry(m *types.RawEntry) (any, E.NestedError) { } else { entry = validateRPEntry(m, scheme, e) } - if err := e.Build(); err.HasError() { + if err := e.Build(); err != nil { return nil, err } return entry, nil @@ -107,7 +110,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn stopSignal, err := T.ValidateSignal(m.StopSignal) b.Add(err) - if err.HasError() { + if err != nil { return nil } @@ -117,6 +120,7 @@ func validateRPEntry(m *types.RawEntry, s T.Scheme, b E.Builder) *ReverseProxyEn URL: net.NewURL(url), NoTLSVerify: m.NoTLSVerify, PathPatterns: pathPatterns, + HealthCheck: m.HealthCheck, LoadBalance: m.LoadBalance, Middlewares: m.Middlewares, IdleTimeout: idleTimeout, @@ -146,9 +150,10 @@ func validateStreamEntry(m *types.RawEntry, b E.Builder) *StreamEntry { } return &StreamEntry{ - Alias: T.NewAlias(m.Alias), - Scheme: *scheme, - Host: host, - Port: port, + Alias: T.NewAlias(m.Alias), + Scheme: *scheme, + Host: host, + Port: port, + Healthcheck: m.HealthCheck, } } diff --git a/internal/proxy/fields/path_mode.go b/internal/proxy/fields/path_mode.go deleted file mode 100644 index 4f8f4dab..00000000 --- a/internal/proxy/fields/path_mode.go +++ /dev/null @@ -1,24 +0,0 @@ -package fields - -import ( - E "github.com/yusing/go-proxy/internal/error" -) - -type PathMode string - -func NewPathMode(pm string) (PathMode, E.NestedError) { - switch pm { - case "", "forward": - return PathMode(pm), nil - default: - return "", E.Invalid("path mode", pm) - } -} - -func (p PathMode) IsRemove() bool { - return p == "" -} - -func (p PathMode) IsForward() bool { - return p == "forward" -} diff --git a/internal/proxy/fields/path_pattern.go b/internal/proxy/fields/path_pattern.go index 5f9b8394..0a42ce5d 100644 --- a/internal/proxy/fields/path_pattern.go +++ b/internal/proxy/fields/path_pattern.go @@ -13,7 +13,7 @@ type ( var pathPattern = regexp.MustCompile(`^(/[-\w./]*({\$\})?|((GET|POST|DELETE|PUT|HEAD|OPTION) /[-\w./]*({\$\})?))$`) -func NewPathPattern(s string) (PathPattern, E.NestedError) { +func ValidatePathPattern(s string) (PathPattern, E.NestedError) { if len(s) == 0 { return "", E.Invalid("path", "must not be empty") } @@ -29,7 +29,7 @@ func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) { } pp := make(PathPatterns, len(s)) for i, v := range s { - pattern, err := NewPathPattern(v) + pattern, err := ValidatePathPattern(v) if err != nil { return nil, err } diff --git a/internal/proxy/fields/path_pattern_test.go b/internal/proxy/fields/path_pattern_test.go index 0d2444ee..d19cb972 100644 --- a/internal/proxy/fields/path_pattern_test.go +++ b/internal/proxy/fields/path_pattern_test.go @@ -37,11 +37,11 @@ var invalidPatterns = []string{ func TestPathPatternRegex(t *testing.T) { for _, pattern := range validPatterns { - _, err := NewPathPattern(pattern) + _, err := ValidatePathPattern(pattern) U.ExpectNoError(t, err.Error()) } for _, pattern := range invalidPatterns { - _, err := NewPathPattern(pattern) + _, err := ValidatePathPattern(pattern) U.ExpectError2(t, pattern, E.ErrInvalid, err.Error()) } } diff --git a/internal/proxy/provider/docker.go b/internal/proxy/provider/docker.go index 601dc44f..f493c06a 100755 --- a/internal/proxy/provider/docker.go +++ b/internal/proxy/provider/docker.go @@ -46,7 +46,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { entries := types.NewProxyEntries() info, err := D.GetClientInfo(p.dockerHost, true) - if err.HasError() { + if err != nil { return routes, E.FailWith("connect to docker", err) } @@ -59,7 +59,7 @@ func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { } newEntries, err := p.entriesFromContainerLabels(container) - if err.HasError() { + if err != nil { errors.Add(err) } // although err is not nil @@ -98,9 +98,9 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul b := E.NewBuilder("event %s error", event) defer b.To(&res.err) - routes.RangeAll(func(k string, v R.Route) { - if v.Entry().ContainerID == event.ActorID || - v.Entry().ContainerName == event.ActorName { + routes.RangeAll(func(k string, v *R.Route) { + if v.Entry.ContainerID == event.ActorID || + v.Entry.ContainerName == event.ActorName { b.Add(v.Stop()) routes.Delete(k) res.nRemoved++ @@ -115,7 +115,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul b.Add(E.FailWith("rescan routes", err)) return } - routesNew.Range(func(k string, v R.Route) bool { + routesNew.Range(func(k string, v *R.Route) bool { if !routesOld.Has(k) { routesOld.Store(k, v) b.Add(v.Start()) @@ -124,7 +124,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul } return true }) - routesOld.Range(func(k string, v R.Route) bool { + routesOld.Range(func(k string, v *R.Route) bool { if !routesNew.Has(k) { b.Add(v.Stop()) routesOld.Delete(k) @@ -137,13 +137,13 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul } client, err := D.ConnectClient(p.dockerHost) - if err.HasError() { + if err != nil { b.Add(E.FailWith("connect to docker", err)) return } defer client.Close() cont, err := client.Inspect(event.ActorID) - if err.HasError() { + if err != nil { b.Add(E.FailWith("inspect container", err)) return } @@ -159,7 +159,7 @@ func (p *DockerProvider) OnEvent(event W.Event, routes R.Routes) (res EventResul if routes.Has(alias) { b.Add(E.Duplicated("alias", alias)) } else { - if route, err := R.NewRoute(entry); err.HasError() { + if route, err := R.NewRoute(entry); err != nil { b.Add(err) } else { routes.Store(alias, route) @@ -221,7 +221,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt } lbl, err := D.ParseLabel(key, val) - if err.HasError() { + if err != nil { b.Add(err.Subject(key)) } if lbl.Namespace != D.NSProxy { @@ -230,7 +230,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt if lbl.Target == D.WildcardAlias { // apply label for all aliases entries.RangeAll(func(a string, e *types.RawEntry) { - if err = D.ApplyLabel(e, lbl); err.HasError() { + if err = D.ApplyLabel(e, lbl); err != nil { b.Add(err) } }) @@ -249,7 +249,7 @@ func (p *DockerProvider) applyLabel(container *D.Container, entries types.RawEnt b.Add(E.NotExist("alias", lbl.Target)) return } - if err = D.ApplyLabel(config, lbl); err.HasError() { + if err = D.ApplyLabel(config, lbl); err != nil { b.Add(err) } } diff --git a/internal/proxy/provider/docker_test.go b/internal/proxy/provider/docker_test.go index 6b107a6f..90a424fb 100644 --- a/internal/proxy/provider/docker_test.go +++ b/internal/proxy/provider/docker_test.go @@ -15,8 +15,10 @@ import ( . "github.com/yusing/go-proxy/internal/utils/testing" ) -var dummyNames = []string{"/a"} -var p DockerProvider +var ( + dummyNames = []string{"/a"} + p DockerProvider +) func TestApplyLabelWildcard(t *testing.T) { pathPatterns := ` diff --git a/internal/proxy/provider/file.go b/internal/proxy/provider/file.go index e3df8bde..5a67ebc4 100644 --- a/internal/proxy/provider/file.go +++ b/internal/proxy/provider/file.go @@ -47,19 +47,21 @@ func (p FileProvider) OnEvent(event W.Event, routes R.Routes) (res EventResult) defer b.To(&res.err) newRoutes, err := p.LoadRoutesImpl() - if err.HasError() { + if err != nil { b.Add(err) return } - routes.RangeAllParallel(func(_ string, v R.Route) { + res.nRemoved = newRoutes.Size() + routes.RangeAllParallel(func(_ string, v *R.Route) { b.Add(v.Stop()) }) routes.Clear() - newRoutes.RangeAllParallel(func(_ string, v R.Route) { + newRoutes.RangeAllParallel(func(_ string, v *R.Route) { b.Add(v.Start()) }) + res.nAdded = newRoutes.Size() routes.MergeFrom(newRoutes) return @@ -74,12 +76,12 @@ func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) { entries := types.NewProxyEntries() data, err := E.Check(os.ReadFile(p.path)) - if err.HasError() { + if err != nil { b.Add(E.FailWith("read file", err)) return } - if err = entries.UnmarshalFromYAML(data); err.HasError() { + if err = entries.UnmarshalFromYAML(data); err != nil { b.Add(err) return } diff --git a/internal/proxy/provider/provider.go b/internal/proxy/provider/provider.go index 4585f369..62e2f76d 100644 --- a/internal/proxy/provider/provider.go +++ b/internal/proxy/provider/provider.go @@ -111,7 +111,7 @@ func (p *Provider) StartAllRoutes() (res E.NestedError) { // start watcher no matter load success or not go p.watchEvents() - p.routes.RangeAllParallel(func(alias string, r R.Route) { + p.routes.RangeAllParallel(func(alias string, r *R.Route) { errors.Add(r.Start().Subject(r)) }) return @@ -126,17 +126,17 @@ func (p *Provider) StopAllRoutes() (res E.NestedError) { errors := E.NewBuilder("errors stopping routes") defer errors.To(&res) - p.routes.RangeAllParallel(func(alias string, r R.Route) { + p.routes.RangeAllParallel(func(alias string, r *R.Route) { errors.Add(r.Stop().Subject(r)) }) return } -func (p *Provider) RangeRoutes(do func(string, R.Route)) { +func (p *Provider) RangeRoutes(do func(string, *R.Route)) { p.routes.RangeAll(do) } -func (p *Provider) GetRoute(alias string) (R.Route, bool) { +func (p *Provider) GetRoute(alias string) (*R.Route, bool) { return p.routes.Load(alias) } @@ -156,11 +156,11 @@ func (p *Provider) LoadRoutes() E.NestedError { func (p *Provider) Statistics() ProviderStats { numRPs := 0 numStreams := 0 - p.routes.RangeAll(func(_ string, r R.Route) { + p.routes.RangeAll(func(_ string, r *R.Route) { if !r.Started() { return } - switch r.Type() { + switch r.Type { case R.RouteTypeReverseProxy: numRPs++ case R.RouteTypeStream: @@ -187,9 +187,17 @@ func (p *Provider) watchEvents() { res := p.OnEvent(event, p.routes) l.Infof("%s event %q", event.Type, event) if res.nAdded > 0 || res.nRemoved > 0 { - l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved) + n := res.nAdded - res.nRemoved + switch { + case n == 0: + l.Infof("%d route(s) reloaded", res.nAdded) + case n > 0: + l.Infof("%d route(s) added", n) + default: + l.Infof("%d route(s) removed", -n) + } } - if res.err.HasError() { + if res.err != nil { l.Error(res.err) } case err := <-errs: diff --git a/internal/route/http.go b/internal/route/http.go index 76719e4d..e36a4d51 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -1,6 +1,7 @@ package route import ( + "context" "errors" "fmt" "net/http" @@ -14,9 +15,11 @@ import ( gphttp "github.com/yusing/go-proxy/internal/net/http" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" "github.com/yusing/go-proxy/internal/net/http/middleware" + url "github.com/yusing/go-proxy/internal/net/types" P "github.com/yusing/go-proxy/internal/proxy" PT "github.com/yusing/go-proxy/internal/proxy/fields" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/watcher/health" ) type ( @@ -24,9 +27,10 @@ type ( *P.ReverseProxyEntry LoadBalancer *loadbalancer.LoadBalancer `json:"load_balancer"` - server *loadbalancer.Server - handler http.Handler - rp *gphttp.ReverseProxy + healthMon health.HealthMonitor + server *loadbalancer.Server + handler http.Handler + rp *gphttp.ReverseProxy } SubdomainKey = PT.Alias @@ -65,7 +69,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { trans = gphttp.DefaultTransport.Clone() } - rp := gphttp.NewReverseProxy(entry.URL, trans) + rp := gphttp.NewReverseProxy(string(entry.Alias), entry.URL, trans) if len(entry.Middlewares) > 0 { err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares) @@ -81,6 +85,18 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { ReverseProxyEntry: entry, rp: rp, } + if entry.LoadBalance.Link != "" && entry.HealthCheck.Disabled { + logrus.Warnf("%s.healthCheck.disabled cannot be false when loadbalancer is enabled", entry.Alias) + entry.HealthCheck.Disabled = true + } + if !entry.HealthCheck.Disabled { + r.healthMon = health.NewHTTPHealthMonitor( + context.Background(), + string(entry.Alias), + entry.URL, + entry.HealthCheck, + ) + } return r, nil } @@ -88,6 +104,10 @@ func (r *HTTPRoute) String() string { return string(r.Alias) } +func (r *HTTPRoute) URL() url.URL { + return r.ReverseProxyEntry.URL +} + func (r *HTTPRoute) Start() E.NestedError { if r.handler != nil { return nil @@ -118,24 +138,13 @@ func (r *HTTPRoute) Start() E.NestedError { if r.LoadBalance.Link == "" { httpRoutes.Store(string(r.Alias), r) - return nil + } else { + r.addToLoadBalancer() } - var lb *loadbalancer.LoadBalancer - linked, ok := httpRoutes.Load(r.LoadBalance.Link) - if ok { - lb = linked.LoadBalancer - } else { - lb = loadbalancer.New(r.LoadBalance) - lb.Start() - linked = &HTTPRoute{ - LoadBalancer: lb, - handler: lb, - } - httpRoutes.Store(r.LoadBalance.Link, linked) + if r.healthMon != nil { + r.healthMon.Start() } - r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler) - lb.AddServer(r.server) return nil } @@ -164,6 +173,10 @@ func (r *HTTPRoute) Stop() (_ E.NestedError) { httpRoutes.Delete(string(r.Alias)) } + if r.healthMon != nil { + r.healthMon.Stop() + } + r.handler = nil return @@ -173,8 +186,30 @@ func (r *HTTPRoute) Started() bool { return r.handler != nil } +func (r *HTTPRoute) addToLoadBalancer() { + var lb *loadbalancer.LoadBalancer + linked, ok := httpRoutes.Load(r.LoadBalance.Link) + if ok { + lb = linked.LoadBalancer + } else { + lb = loadbalancer.New(r.LoadBalance) + lb.Start() + linked = &HTTPRoute{ + LoadBalancer: lb, + handler: lb, + } + httpRoutes.Store(r.LoadBalance.Link, linked) + } + r.server = loadbalancer.NewServer(string(r.Alias), r.rp.TargetURL, r.LoadBalance.Weight, r.handler, r.healthMon) + lb.AddServer(r.server) +} + func ProxyHandler(w http.ResponseWriter, r *http.Request) { mux, err := findMuxFunc(r.Host) + // Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway? + // On nginx, when route for domain does not exist, it returns StatusBadGateway. + // Then scraper / scanners will know the subdomain is invalid. + // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. if err != nil { if !middleware.ServeStaticErrorPageFile(w, r) { logrus.Error(E.Failure("request"). diff --git a/internal/route/route.go b/internal/route/route.go index 756e7680..c7180232 100755 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -1,35 +1,30 @@ package route import ( - "fmt" - "net/url" - E "github.com/yusing/go-proxy/internal/error" + url "github.com/yusing/go-proxy/internal/net/types" P "github.com/yusing/go-proxy/internal/proxy" "github.com/yusing/go-proxy/internal/types" + U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" ) type ( - Route interface { - RouteImpl - Entry() *types.RawEntry - Type() RouteType - URL() *url.URL + RouteType string + Route struct { + _ U.NoCopy + impl + Type RouteType + Entry *types.RawEntry } - Routes = F.Map[string, Route] + Routes = F.Map[string, *Route] - RouteImpl interface { + impl interface { Start() E.NestedError Stop() E.NestedError Started() bool String() string - } - RouteType string - route struct { - RouteImpl - type_ RouteType - entry *types.RawEntry + URL() url.URL } ) @@ -38,44 +33,36 @@ const ( RouteTypeReverseProxy RouteType = "reverse_proxy" ) -// function alias -var NewRoutes = F.NewMapOf[string, Route] +// function alias. +var NewRoutes = F.NewMap[Routes] -func NewRoute(en *types.RawEntry) (Route, E.NestedError) { +func NewRoute(en *types.RawEntry) (*Route, E.NestedError) { entry, err := P.ValidateEntry(en) if err != nil { return nil, err } var t RouteType - var rt RouteImpl + var rt impl + switch e := entry.(type) { case *P.StreamEntry: - rt, err = NewStreamRoute(e) t = RouteTypeStream + rt, err = NewStreamRoute(e) case *P.ReverseProxyEntry: - rt, err = NewHTTPRoute(e) t = RouteTypeReverseProxy + rt, err = NewHTTPRoute(e) default: panic("bug: should not reach here") } if err != nil { return nil, err } - return &route{RouteImpl: rt, entry: en, type_: t}, nil -} - -func (rt *route) Entry() *types.RawEntry { - return rt.entry -} - -func (rt *route) Type() RouteType { - return rt.type_ -} - -func (rt *route) URL() *url.URL { - url, _ := url.Parse(fmt.Sprintf("%s://%s:%s", rt.entry.Scheme, rt.entry.Host, rt.entry.Port)) - return url + return &Route{ + impl: rt, + Type: t, + Entry: en, + }, nil } func FromEntries(entries types.RawEntries) (Routes, E.NestedError) { @@ -85,7 +72,7 @@ func FromEntries(entries types.RawEntries) (Routes, E.NestedError) { entries.RangeAll(func(alias string, entry *types.RawEntry) { entry.Alias = alias r, err := NewRoute(entry) - if err.HasError() { + if err != nil { b.Add(err.Subject(alias)) } else { routes.Store(alias, r) diff --git a/internal/route/stream.go b/internal/route/stream.go index 2f7b174a..36c3976e 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -10,14 +10,19 @@ import ( "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/internal/error" + url "github.com/yusing/go-proxy/internal/net/types" P "github.com/yusing/go-proxy/internal/proxy" PT "github.com/yusing/go-proxy/internal/proxy/fields" + "github.com/yusing/go-proxy/internal/watcher/health" ) type StreamRoute struct { *P.StreamEntry StreamImpl `json:"-"` + url url.URL + healthMon health.HealthMonitor + wg sync.WaitGroup ctx context.Context cancel context.CancelFunc @@ -40,8 +45,14 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { if !entry.Scheme.IsCoherent() { return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) } + url, err := url.ParseURL(fmt.Sprintf("%s://%s:%d", entry.Scheme.ProxyScheme, entry.Host, entry.Port.ProxyPort)) + if err != nil { + // !! should not happen + panic(err) + } base := &StreamRoute{ StreamEntry: entry, + url: url, connCh: make(chan any, 100), } if entry.Scheme.ListeningScheme.IsTCP() { @@ -49,6 +60,9 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { } else { base.StreamImpl = NewUDPRoute(base) } + if !entry.Healthcheck.Disabled { + base.healthMon = health.NewRawHealthMonitor(base.ctx, string(entry.Alias), url, entry.Healthcheck) + } base.l = logrus.WithField("route", base.StreamImpl) return base, nil } @@ -57,6 +71,10 @@ func (r *StreamRoute) String() string { return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias) } +func (r *StreamRoute) URL() url.URL { + return r.url +} + func (r *StreamRoute) Start() E.NestedError { if r.Port.ProxyPort == PT.NoPort || r.started.Load() { return nil @@ -71,6 +89,9 @@ func (r *StreamRoute) Start() E.NestedError { r.wg.Add(2) go r.grAcceptConnections() go r.grHandleConnections() + if r.healthMon != nil { + r.healthMon.Start() + } return nil } @@ -78,7 +99,12 @@ func (r *StreamRoute) Stop() E.NestedError { if !r.started.Load() { return nil } - l := r.l + r.started.Store(false) + + if r.healthMon != nil { + r.healthMon.Stop() + } + r.cancel() r.CloseListeners() @@ -92,7 +118,7 @@ func (r *StreamRoute) Stop() E.NestedError { for { select { case <-done: - l.Debug("stopped listening") + r.l.Debug("stopped listening") return nil case <-timeout: return E.FailedWhy("stop", "timed out") diff --git a/internal/route/udp.go b/internal/route/udp.go index e3cde373..b74a2eef 100755 --- a/internal/route/udp.go +++ b/internal/route/udp.go @@ -27,7 +27,7 @@ type ( UDPConnMap = F.Map[string, *UDPConn] ) -var NewUDPConnMap = F.NewMapOf[string, *UDPConn] +var NewUDPConnMap = F.NewMap[UDPConnMap] func NewUDPRoute(base *StreamRoute) StreamImpl { return &UDPRoute{ diff --git a/internal/types/raw_entry.go b/internal/types/raw_entry.go index 405d8906..c70fc226 100644 --- a/internal/types/raw_entry.go +++ b/internal/types/raw_entry.go @@ -5,11 +5,12 @@ import ( "strings" "github.com/yusing/go-proxy/internal/common" - D "github.com/yusing/go-proxy/internal/docker" - H "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/docker" + "github.com/yusing/go-proxy/internal/homepage" "github.com/yusing/go-proxy/internal/net/http/loadbalancer" U "github.com/yusing/go-proxy/internal/utils" F "github.com/yusing/go-proxy/internal/utils/functional" + "github.com/yusing/go-proxy/internal/watcher/health" ) type ( @@ -18,18 +19,19 @@ type ( // raw entry object before validation // loaded from docker labels or yaml file - Alias string `json:"-" yaml:"-"` - Scheme string `json:"scheme" yaml:"scheme"` - Host string `json:"host" yaml:"host"` - Port string `json:"port" yaml:"port"` - NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only - PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only - LoadBalance loadbalancer.Config `json:"load_balance" yaml:"load_balance"` - Middlewares D.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"` - Homepage *H.HomePageItem `json:"homepage,omitempty" yaml:"homepage"` + Alias string `json:"-" yaml:"-"` + Scheme string `json:"scheme" yaml:"scheme"` + Host string `json:"host" yaml:"host"` + Port string `json:"port" yaml:"port"` + NoTLSVerify bool `json:"no_tls_verify,omitempty" yaml:"no_tls_verify"` // https proxy only + PathPatterns []string `json:"path_patterns,omitempty" yaml:"path_patterns"` // http(s) proxy only + HealthCheck health.HealthCheckConfig `json:"healthcheck,omitempty" yaml:"healthcheck"` + LoadBalance loadbalancer.Config `json:"load_balance,omitempty" yaml:"load_balance"` + Middlewares docker.NestedLabelMap `json:"middlewares,omitempty" yaml:"middlewares"` + Homepage *homepage.Item `json:"homepage,omitempty" yaml:"homepage"` /* Docker only */ - *D.Container `json:"container" yaml:"-"` + *docker.Container `json:"container" yaml:"-"` } RawEntries = F.Map[string, *RawEntry] @@ -40,7 +42,7 @@ var NewProxyEntries = F.NewMapOf[string, *RawEntry] func (e *RawEntry) FillMissingFields() { isDocker := e.Container != nil if !isDocker { - e.Container = &D.Container{} + e.Container = &docker.Container{} } if e.Host == "" { @@ -113,6 +115,9 @@ func (e *RawEntry) FillMissingFields() { } } + if e.HealthCheck.Interval == 0 { + e.HealthCheck.Interval = common.HealthCheckIntervalDefault + } if e.IdleTimeout == "" { e.IdleTimeout = common.IdleTimeoutDefault } diff --git a/internal/utils/functional/map.go b/internal/utils/functional/map.go index 2c0dfe73..7cd71e95 100644 --- a/internal/utils/functional/map.go +++ b/internal/utils/functional/map.go @@ -24,6 +24,10 @@ func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) { return } +func NewMap[MapType Map[KT, VT], KT comparable, VT any]() Map[KT, VT] { + return NewMapOf[KT, VT]() +} + // MapFind iterates over the map and returns the first value // that satisfies the given criteria. The iteration is stopped // once a value is found. If no value satisfies the criteria, @@ -161,7 +165,7 @@ func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError { return E.FailedWhy("unmarshal from yaml", "map is not empty") } tmp := make(map[KT]VT) - if err := E.From(yaml.Unmarshal(data, tmp)); err.HasError() { + if err := E.From(yaml.Unmarshal(data, tmp)); err != nil { return err } for k, v := range tmp { diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 7a8f0b99..577b69c1 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -8,6 +8,7 @@ import ( "reflect" "strconv" "strings" + "time" "unicode" "github.com/santhosh-tekuri/jsonschema" @@ -18,7 +19,7 @@ import ( type ( SerializedObject = map[string]any Converter interface { - ConvertFrom(value any) (any, E.NestedError) + ConvertFrom(value any) E.NestedError } ) @@ -264,23 +265,10 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError { var ok bool // check if (*T).Convertor is implemented if converter, ok = dst.Addr().Interface().(Converter); !ok { - // check if (T).Convertor is implemented - converter, ok = dst.Interface().(Converter) - if !ok { - return E.TypeError("conversion", srcT, dstT) - } + return E.TypeError("conversion", srcT, dstT) } - converted, err := converter.ConvertFrom(src.Interface()) - if err != nil { - return err - } - c := reflect.ValueOf(converted) - if c.Kind() == reflect.Ptr { - c = c.Elem() - } - dst.Set(c) - return nil + return converter.ConvertFrom(src.Interface()) } func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.NestedError) { @@ -295,6 +283,20 @@ func ConvertString(src string, dst reflect.Value) (convertible bool, convErr E.N dst.SetString(src) return } + switch dst.Type() { + case reflect.TypeFor[time.Duration](): + if src == "" { + dst.Set(reflect.Zero(dst.Type())) + return + } + d, err := time.ParseDuration(src) + if err != nil { + convErr = E.Invalid("duration", src) + return + } + dst.Set(reflect.ValueOf(d)) + return + } // primitive types / simple types switch dst.Kind() { case reflect.Bool: diff --git a/internal/utils/serialization_test.go b/internal/utils/serialization_test.go index 00213e5e..a87499d6 100644 --- a/internal/utils/serialization_test.go +++ b/internal/utils/serialization_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + E "github.com/yusing/go-proxy/internal/error" . "github.com/yusing/go-proxy/internal/utils/testing" ) @@ -102,3 +103,48 @@ func TestStringIntConvert(t *testing.T) { ExpectNoError(t, err.Error()) ExpectEqual(t, test.u64, uint64(127)) } + +type testModel struct { + Test testType +} + +type testType struct { + foo int + bar string +} + +func (c *testType) ConvertFrom(v any) E.NestedError { + switch v := v.(type) { + case string: + c.bar = v + return nil + case int: + c.foo = v + return nil + default: + return E.Invalid("input type", v) + } +} + +func TestConvertor(t *testing.T) { + t.Run("string", func(t *testing.T) { + m := new(testModel) + ExpectNoError(t, Deserialize(map[string]any{"Test": "bar"}, m).Error()) + + ExpectEqual(t, m.Test.foo, 0) + ExpectEqual(t, m.Test.bar, "bar") + }) + + t.Run("int", func(t *testing.T) { + m := new(testModel) + ExpectNoError(t, Deserialize(map[string]any{"Test": 123}, m).Error()) + + ExpectEqual(t, m.Test.foo, 123) + ExpectEqual(t, m.Test.bar, "") + }) + + t.Run("invalid", func(t *testing.T) { + m := new(testModel) + ExpectError(t, E.ErrInvalid, Deserialize(map[string]any{"Test": 123.456}, m).Error()) + }) +} diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index 8a9d9e48..46d833b5 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -26,6 +26,14 @@ type DirWatcher struct { ctx context.Context } +// NewDirectoryWatcher returns a DirWatcher instance. +// +// The DirWatcher watches the given directory for file system events. +// Currently, only events on files directly in the given directory are watched, not +// recursively. +// +// Note that the returned DirWatcher is not ready to use until the goroutine +// started by NewDirectoryWatcher has finished. func NewDirectoryWatcher(ctx context.Context, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() @@ -70,16 +78,8 @@ func (h *DirWatcher) Add(relPath string) Watcher { close(s.eventCh) close(s.errCh) }() - for { - select { - case <-h.ctx.Done(): - return - case _, ok := <-h.eventCh: - if !ok { // directory watcher closed - return - } - } - } + <-h.ctx.Done() + logrus.Debugf("file watcher %s stopped", relPath) }() h.fwMap.Store(relPath, s) return s @@ -88,6 +88,7 @@ func (h *DirWatcher) Add(relPath string) Watcher { func (h *DirWatcher) start() { defer close(h.eventCh) defer h.w.Close() + defer logrus.Debugf("directory watcher %s stopped", h.dir) for { select { @@ -121,7 +122,9 @@ func (h *DirWatcher) start() { // send event to directory watcher select { case h.eventCh <- msg: + logrus.Debugf("sent event to directory watcher %s", h.dir) default: + logrus.Debugf("failed to send event to directory watcher %s", h.dir) } // send event to file watcher too @@ -129,8 +132,12 @@ func (h *DirWatcher) start() { if ok { select { case w.eventCh <- msg: + logrus.Debugf("sent event to file watcher %s", relPath) default: + logrus.Debugf("failed to send event to file watcher %s", relPath) } + } else { + logrus.Debugf("file watcher not found: %s", relPath) } case err := <-h.w.Errors: if errors.Is(err, fsnotify.ErrClosed) { diff --git a/internal/watcher/health/healthcheck_config.go b/internal/watcher/health/healthcheck_config.go new file mode 100644 index 00000000..86a512ec --- /dev/null +++ b/internal/watcher/health/healthcheck_config.go @@ -0,0 +1,22 @@ +package health + +import ( + "time" + + "github.com/yusing/go-proxy/internal/common" +) + +type HealthCheckConfig struct { + Disabled bool `json:"disabled" yaml:"disabled"` + Path string `json:"path" yaml:"path"` + UseGet bool `json:"use_get" yaml:"use_get"` + Interval time.Duration `json:"interval" yaml:"interval"` + Timeout time.Duration `json:"timeout" yaml:"timeout"` +} + +func DefaultHealthCheckConfig() HealthCheckConfig { + return HealthCheckConfig{ + Interval: common.HealthCheckIntervalDefault, + Timeout: common.HealthCheckTimeoutDefault, + } +} diff --git a/internal/watcher/health/http.go b/internal/watcher/health/http.go new file mode 100644 index 00000000..f0ca2184 --- /dev/null +++ b/internal/watcher/health/http.go @@ -0,0 +1,63 @@ +package health + +import ( + "context" + "crypto/tls" + "errors" + "net/http" + + "github.com/yusing/go-proxy/internal/net/types" +) + +type HTTPHealthMonitor struct { + *monitor + method string + pinger *http.Client +} + +func NewHTTPHealthMonitor(ctx context.Context, name string, url types.URL, config HealthCheckConfig) HealthMonitor { + mon := new(HTTPHealthMonitor) + mon.monitor = newMonitor(ctx, name, url, &config, mon.checkHealth) + mon.pinger = &http.Client{Timeout: config.Timeout} + if config.UseGet { + mon.method = http.MethodGet + } else { + mon.method = http.MethodHead + } + return mon +} + +func (mon *HTTPHealthMonitor) checkHealth() (healthy bool, detail string, err error) { + req, reqErr := http.NewRequestWithContext( + mon.ctx, + mon.method, + mon.URL.String(), + nil, + ) + if reqErr != nil { + err = reqErr + return + } + req.Header.Set("Connection", "close") + + resp, respErr := mon.pinger.Do(req) + if respErr == nil { + resp.Body.Close() + } + + switch { + case respErr != nil: + // treat tls error as healthy + var tlsErr *tls.CertificateVerificationError + if ok := errors.As(respErr, &tlsErr); !ok { + detail = respErr.Error() + return + } + case resp.StatusCode == http.StatusServiceUnavailable: + detail = resp.Status + return + } + + healthy = true + return +} diff --git a/internal/watcher/health/logger.go b/internal/watcher/health/logger.go new file mode 100644 index 00000000..171f4a5e --- /dev/null +++ b/internal/watcher/health/logger.go @@ -0,0 +1,5 @@ +package health + +import "github.com/sirupsen/logrus" + +var logger = logrus.WithField("module", "health_mon") diff --git a/internal/watcher/health/monitor.go b/internal/watcher/health/monitor.go new file mode 100644 index 00000000..32a4c6d3 --- /dev/null +++ b/internal/watcher/health/monitor.go @@ -0,0 +1,139 @@ +package health + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/yusing/go-proxy/internal/net/types" + F "github.com/yusing/go-proxy/internal/utils/functional" +) + +type ( + HealthMonitor interface { + Start() + Stop() + IsHealthy() bool + String() string + } + HealthCheckFunc func() (healthy bool, detail string, err error) + monitor struct { + Name string + URL types.URL + Interval time.Duration + + healthy atomic.Bool + checkHealth HealthCheckFunc + + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + mu sync.Mutex + } +) + +var monMap = F.NewMapOf[string, HealthMonitor]() + +func newMonitor(parentCtx context.Context, name string, url types.URL, config *HealthCheckConfig, healthCheckFunc HealthCheckFunc) *monitor { + if parentCtx == nil { + parentCtx = context.Background() + } + ctx, cancel := context.WithCancel(parentCtx) + mon := &monitor{ + Name: name, + URL: url.JoinPath(config.Path), + Interval: config.Interval, + checkHealth: healthCheckFunc, + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + } + mon.healthy.Store(true) + monMap.Store(name, mon) + return mon +} + +func IsHealthy(name string) (healthy bool, ok bool) { + mon, ok := monMap.Load(name) + if !ok { + return + } + return mon.IsHealthy(), true +} + +func (mon *monitor) Start() { + go func() { + defer close(mon.done) + + ok := mon.checkUpdateHealth() + if !ok { + return + } + + ticker := time.NewTicker(mon.Interval) + defer ticker.Stop() + + for { + select { + case <-mon.ctx.Done(): + return + case <-ticker.C: + ok = mon.checkUpdateHealth() + if !ok { + return + } + } + } + }() + logger.Debugf("health monitor %q started", mon) +} + +func (mon *monitor) Stop() { + defer logger.Debugf("health monitor %q stopped", mon) + + monMap.Delete(mon.Name) + + mon.mu.Lock() + defer mon.mu.Unlock() + + if mon.cancel == nil { + return + } + + mon.cancel() + <-mon.done + + mon.cancel = nil +} + +func (mon *monitor) IsHealthy() bool { + return mon.healthy.Load() +} + +func (mon *monitor) String() string { + return mon.Name +} + +func (mon *monitor) checkUpdateHealth() (hasError bool) { + healthy, detail, err := mon.checkHealth() + if err != nil { + mon.healthy.Store(false) + if !errors.Is(err, context.Canceled) { + logger.Errorf("server %q failed to check health: %s", mon, err) + } + mon.Stop() + return false + } + if healthy != mon.healthy.Swap(healthy) { + if healthy { + logger.Infof("server %q is up", mon) + } else { + logger.Warnf("server %q is down: %s", mon, detail) + } + } + + return true +} diff --git a/internal/watcher/health/raw.go b/internal/watcher/health/raw.go new file mode 100644 index 00000000..4990d870 --- /dev/null +++ b/internal/watcher/health/raw.go @@ -0,0 +1,37 @@ +package health + +import ( + "context" + "net" + + "github.com/yusing/go-proxy/internal/net/types" +) + +type ( + RawHealthMonitor struct { + *monitor + dialer *net.Dialer + } +) + +func NewRawHealthMonitor(ctx context.Context, name string, url types.URL, config HealthCheckConfig) HealthMonitor { + mon := new(RawHealthMonitor) + mon.monitor = newMonitor(ctx, name, url, &config, mon.checkAvail) + mon.dialer = &net.Dialer{ + Timeout: config.Timeout, + FallbackDelay: -1, + } + return mon +} + +func (mon *RawHealthMonitor) checkAvail() (avail bool, detail string, err error) { + conn, dialErr := mon.dialer.DialContext(mon.ctx, mon.URL.Scheme, mon.URL.Host) + if dialErr != nil { + detail = dialErr.Error() + /* trunk-ignore(golangci-lint/nilerr) */ + return + } + conn.Close() + avail = true + return +} diff --git a/schema/providers.schema.json b/schema/providers.schema.json index dcd4beeb..8a23e7e3 100644 --- a/schema/providers.schema.json +++ b/schema/providers.schema.json @@ -116,6 +116,61 @@ "type": "object" } } + }, + "load_balance": { + "type": "object", + "properties": { + "link": { + "type": "string", + "description": "Name and subdomain of load-balancer", + "format": "uri" + }, + "mode": { + "enum": [ + "round_robin", + "least_conn", + "ip_hash" + ], + "description": "Load-balance mode", + "default": "roundrobin" + }, + "weight": { + "type": "integer", + "description": "Reserved for future use", + "minimum": 0, + "maximum": 100 + }, + "options": { + "type": "object", + "description": "load-balance mode specific options" + } + } + }, + "healthcheck": { + "type": "object", + "properties": { + "disabled": { + "type": "boolean", + "default": false + }, + "path": { + "type": "string", + "description": "Healthcheck path", + "default": "/", + "format": "uri" + }, + "use_get": { + "type": "boolean", + "description": "Use GET instead of HEAD", + "default": false + }, + "interval": { + "type": "string", + "description": "Interval for healthcheck (e.g. 5s, 1h25m30s)", + "pattern": "^([0-9]+(ms|s|m|h))+$", + "default": "5s" + } + } } }, "additionalProperties": false,