From faa5b365534072379d60dbe309f1a98698a58931 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 7 Feb 2026 21:24:24 +0800 Subject: [PATCH] fix: add nil guard for entrypoint retrieval; rename AddRoute to StartAddRoute --- internal/entrypoint/routes.go | 45 ++++++++++++++++++------- internal/entrypoint/types/entrypoint.go | 2 +- internal/metrics/uptime/uptime.go | 8 ++++- internal/route/fileserver.go | 2 +- internal/route/reverse_proxy.go | 6 ++-- internal/route/route_test.go | 2 +- internal/route/stream.go | 2 +- 7 files changed, 46 insertions(+), 21 deletions(-) diff --git a/internal/entrypoint/routes.go b/internal/entrypoint/routes.go index 60d53048..bcd5f484 100644 --- a/internal/entrypoint/routes.go +++ b/internal/entrypoint/routes.go @@ -45,7 +45,7 @@ func (ep *Entrypoint) GetRoute(alias string) (types.Route, bool) { return nil, false } -func (ep *Entrypoint) AddRoute(r types.Route) error { +func (ep *Entrypoint) StartAddRoute(r types.Route) error { if r.ShouldExclude() { ep.excludedRoutes.Add(r) r.Task().OnCancel("remove_route", func() { @@ -80,13 +80,9 @@ func (ep *Entrypoint) AddRoute(r types.Route) error { return nil } -// AddHTTPRoute adds a HTTP route to the entrypoint's server. -// -// If the server does not exist, it will be created, started and return any error. -func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error { +func getAddr(route types.HTTPRoute) (httpAddr, httpsAddr string) { if port := route.ListenURL().Port(); port == "" || port == "0" { host := route.ListenURL().Hostname() - var httpAddr, httpsAddr string if host == "" { httpAddr = common.ProxyHTTPAddr httpsAddr = common.ProxyHTTPSAddr @@ -94,10 +90,26 @@ func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error { httpAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPPort)) httpsAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPSPort)) } - return errors.Join(ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP), ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS)) + return httpAddr, httpsAddr } - return ep.addHTTPRoute(route, route.ListenURL().Host, HTTPProtoHTTPS) + httpsAddr = route.ListenURL().Host + return +} + +// AddHTTPRoute adds a HTTP route to the entrypoint's server. +// +// If the server does not exist, it will be created, started and return any error. +func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error { + httpAddr, httpsAddr := getAddr(route) + var httpErr, httpsErr error + if httpAddr != "" { + httpErr = ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP) + } + if httpsAddr != "" { + httpsErr = ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS) + } + return errors.Join(httpErr, httpsErr) } func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error { @@ -117,10 +129,17 @@ func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTT } func (ep *Entrypoint) delHTTPRoute(route types.HTTPRoute) { - addr := route.ListenURL().Host - srv, _ := ep.servers.Load(addr) - if srv != nil { - srv.DelRoute(route) + httpAddr, httpsAddr := getAddr(route) + if httpAddr != "" { + srv, _ := ep.servers.Load(httpAddr) + if srv != nil { + srv.DelRoute(route) + } + } + if httpsAddr != "" { + srv, _ := ep.servers.Load(httpsAddr) + if srv != nil { + srv.DelRoute(route) + } } - // TODO: close if no servers left } diff --git a/internal/entrypoint/types/entrypoint.go b/internal/entrypoint/types/entrypoint.go index a58a66e0..51dc1cb1 100644 --- a/internal/entrypoint/types/entrypoint.go +++ b/internal/entrypoint/types/entrypoint.go @@ -10,7 +10,7 @@ type Entrypoint interface { DisablePoolsLog(v bool) GetRoute(alias string) (types.Route, bool) - AddRoute(r types.Route) error + StartAddRoute(r types.Route) error IterRoutes(yield func(r types.Route) bool) NumRoutes() int RoutesByProvider() map[string][]types.Route diff --git a/internal/metrics/uptime/uptime.go b/internal/metrics/uptime/uptime.go index 1df8ccd3..dba3cda1 100644 --- a/internal/metrics/uptime/uptime.go +++ b/internal/metrics/uptime/uptime.go @@ -2,6 +2,7 @@ package uptime import ( "context" + "errors" "net/url" "slices" "time" @@ -41,8 +42,12 @@ type ( var Poller = period.NewPoller("uptime", getStatuses, aggregateStatuses) func getStatuses(ctx context.Context, _ StatusByAlias) (StatusByAlias, error) { + ep := entrypoint.FromCtx(ctx) + if ep == nil { + return StatusByAlias{}, errors.New("entrypoint not found in context") + } return StatusByAlias{ - Map: entrypoint.FromCtx(ctx).GetHealthInfoWithoutDetail(), + Map: ep.GetHealthInfoWithoutDetail(), Timestamp: time.Now().Unix(), }, nil } @@ -129,6 +134,7 @@ func (rs RouteStatuses) aggregate(limit int, offset int) Aggregated { status := types.StatusUnknown if state := config.ActiveState.Load(); state != nil { + // FIXME: pass ctx to getRoute r, ok := entrypoint.FromCtx(state.Context()).GetRoute(alias) if ok { mon := r.HealthMonitor() diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index b268b9cc..460b73f3 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -133,7 +133,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error { return err } - if err := ep.AddRoute(s); err != nil { + if err := ep.StartAddRoute(s); err != nil { s.task.Finish(err) return gperr.Wrap(err) } diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index b816b67f..dccd4906 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -176,7 +176,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { return gperr.Wrap(err) } } else { - if err := ep.AddRoute(r); err != nil { + if err := ep.StartAddRoute(r); err != nil { r.task.Finish(err) return gperr.Wrap(err) } @@ -195,6 +195,7 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E var lb *loadbalancer.LoadBalancer cfg := r.LoadBalance lbLock.Lock() + defer lbLock.Unlock() l, ok := ep.HTTPRoutes().Get(cfg.Link) var linked *ReveseProxyRoute @@ -223,11 +224,10 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E handler: lb, } linked.SetHealthMonitor(lb) - if err := ep.AddRoute(linked); err != nil { + if err := ep.StartAddRoute(linked); err != nil { lb.Finish(err) return err } - lbLock.Unlock() } r.loadBalancer = lb diff --git a/internal/route/route_test.go b/internal/route/route_test.go index 500e40c4..8be85c1b 100644 --- a/internal/route/route_test.go +++ b/internal/route/route_test.go @@ -217,7 +217,7 @@ func TestRouteBindField(t *testing.T) { Port: route.Port{Proxy: 80}, } err := r.Validate() - require.NoError(t, err, "Validate should not return error for HTTP route with bind") + require.NoError(t, err, "Validate should not return error for HTTP route without bind") require.NotNil(t, r.LisURL, "LisURL should be set") require.Equal(t, "https://:0", r.LisURL.String(), "LisURL should contain bind address") }) diff --git a/internal/route/stream.go b/internal/route/stream.go index 62b2ce34..83ec3d72 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -75,7 +75,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { r.task.Finish(err) return err } - if err := ep.AddRoute(r); err != nil { + if err := ep.StartAddRoute(r); err != nil { r.task.Finish(err) return gperr.Wrap(err) }