fix: add nil guard for entrypoint retrieval; rename AddRoute to StartAddRoute

This commit is contained in:
yusing
2026-02-07 21:24:24 +08:00
parent 2fcf141ae2
commit faa5b36553
7 changed files with 46 additions and 21 deletions

View File

@@ -45,7 +45,7 @@ func (ep *Entrypoint) GetRoute(alias string) (types.Route, bool) {
return nil, false return nil, false
} }
func (ep *Entrypoint) AddRoute(r types.Route) error { func (ep *Entrypoint) StartAddRoute(r types.Route) error {
if r.ShouldExclude() { if r.ShouldExclude() {
ep.excludedRoutes.Add(r) ep.excludedRoutes.Add(r)
r.Task().OnCancel("remove_route", func() { r.Task().OnCancel("remove_route", func() {
@@ -80,13 +80,9 @@ func (ep *Entrypoint) AddRoute(r types.Route) error {
return nil return nil
} }
// AddHTTPRoute adds a HTTP route to the entrypoint's server. func getAddr(route types.HTTPRoute) (httpAddr, httpsAddr string) {
//
// If the server does not exist, it will be created, started and return any error.
func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
if port := route.ListenURL().Port(); port == "" || port == "0" { if port := route.ListenURL().Port(); port == "" || port == "0" {
host := route.ListenURL().Hostname() host := route.ListenURL().Hostname()
var httpAddr, httpsAddr string
if host == "" { if host == "" {
httpAddr = common.ProxyHTTPAddr httpAddr = common.ProxyHTTPAddr
httpsAddr = common.ProxyHTTPSAddr httpsAddr = common.ProxyHTTPSAddr
@@ -94,10 +90,26 @@ func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
httpAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPPort)) httpAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPPort))
httpsAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPSPort)) httpsAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPSPort))
} }
return errors.Join(ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP), ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS)) return httpAddr, httpsAddr
} }
return ep.addHTTPRoute(route, route.ListenURL().Host, HTTPProtoHTTPS) httpsAddr = route.ListenURL().Host
return
}
// AddHTTPRoute adds a HTTP route to the entrypoint's server.
//
// If the server does not exist, it will be created, started and return any error.
func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
httpAddr, httpsAddr := getAddr(route)
var httpErr, httpsErr error
if httpAddr != "" {
httpErr = ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP)
}
if httpsAddr != "" {
httpsErr = ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS)
}
return errors.Join(httpErr, httpsErr)
} }
func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error { func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error {
@@ -117,10 +129,17 @@ func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTT
} }
func (ep *Entrypoint) delHTTPRoute(route types.HTTPRoute) { func (ep *Entrypoint) delHTTPRoute(route types.HTTPRoute) {
addr := route.ListenURL().Host httpAddr, httpsAddr := getAddr(route)
srv, _ := ep.servers.Load(addr) if httpAddr != "" {
if srv != nil { srv, _ := ep.servers.Load(httpAddr)
srv.DelRoute(route) if srv != nil {
srv.DelRoute(route)
}
}
if httpsAddr != "" {
srv, _ := ep.servers.Load(httpsAddr)
if srv != nil {
srv.DelRoute(route)
}
} }
// TODO: close if no servers left
} }

View File

@@ -10,7 +10,7 @@ type Entrypoint interface {
DisablePoolsLog(v bool) DisablePoolsLog(v bool)
GetRoute(alias string) (types.Route, bool) GetRoute(alias string) (types.Route, bool)
AddRoute(r types.Route) error StartAddRoute(r types.Route) error
IterRoutes(yield func(r types.Route) bool) IterRoutes(yield func(r types.Route) bool)
NumRoutes() int NumRoutes() int
RoutesByProvider() map[string][]types.Route RoutesByProvider() map[string][]types.Route

View File

@@ -2,6 +2,7 @@ package uptime
import ( import (
"context" "context"
"errors"
"net/url" "net/url"
"slices" "slices"
"time" "time"
@@ -41,8 +42,12 @@ type (
var Poller = period.NewPoller("uptime", getStatuses, aggregateStatuses) var Poller = period.NewPoller("uptime", getStatuses, aggregateStatuses)
func getStatuses(ctx context.Context, _ StatusByAlias) (StatusByAlias, error) { func getStatuses(ctx context.Context, _ StatusByAlias) (StatusByAlias, error) {
ep := entrypoint.FromCtx(ctx)
if ep == nil {
return StatusByAlias{}, errors.New("entrypoint not found in context")
}
return StatusByAlias{ return StatusByAlias{
Map: entrypoint.FromCtx(ctx).GetHealthInfoWithoutDetail(), Map: ep.GetHealthInfoWithoutDetail(),
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
}, nil }, nil
} }
@@ -129,6 +134,7 @@ func (rs RouteStatuses) aggregate(limit int, offset int) Aggregated {
status := types.StatusUnknown status := types.StatusUnknown
if state := config.ActiveState.Load(); state != nil { if state := config.ActiveState.Load(); state != nil {
// FIXME: pass ctx to getRoute
r, ok := entrypoint.FromCtx(state.Context()).GetRoute(alias) r, ok := entrypoint.FromCtx(state.Context()).GetRoute(alias)
if ok { if ok {
mon := r.HealthMonitor() mon := r.HealthMonitor()

View File

@@ -133,7 +133,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
return err return err
} }
if err := ep.AddRoute(s); err != nil { if err := ep.StartAddRoute(s); err != nil {
s.task.Finish(err) s.task.Finish(err)
return gperr.Wrap(err) return gperr.Wrap(err)
} }

View File

@@ -176,7 +176,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
return gperr.Wrap(err) return gperr.Wrap(err)
} }
} else { } else {
if err := ep.AddRoute(r); err != nil { if err := ep.StartAddRoute(r); err != nil {
r.task.Finish(err) r.task.Finish(err)
return gperr.Wrap(err) return gperr.Wrap(err)
} }
@@ -195,6 +195,7 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E
var lb *loadbalancer.LoadBalancer var lb *loadbalancer.LoadBalancer
cfg := r.LoadBalance cfg := r.LoadBalance
lbLock.Lock() lbLock.Lock()
defer lbLock.Unlock()
l, ok := ep.HTTPRoutes().Get(cfg.Link) l, ok := ep.HTTPRoutes().Get(cfg.Link)
var linked *ReveseProxyRoute var linked *ReveseProxyRoute
@@ -223,11 +224,10 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E
handler: lb, handler: lb,
} }
linked.SetHealthMonitor(lb) linked.SetHealthMonitor(lb)
if err := ep.AddRoute(linked); err != nil { if err := ep.StartAddRoute(linked); err != nil {
lb.Finish(err) lb.Finish(err)
return err return err
} }
lbLock.Unlock()
} }
r.loadBalancer = lb r.loadBalancer = lb

View File

@@ -217,7 +217,7 @@ func TestRouteBindField(t *testing.T) {
Port: route.Port{Proxy: 80}, Port: route.Port{Proxy: 80},
} }
err := r.Validate() err := r.Validate()
require.NoError(t, err, "Validate should not return error for HTTP route with bind") require.NoError(t, err, "Validate should not return error for HTTP route without bind")
require.NotNil(t, r.LisURL, "LisURL should be set") require.NotNil(t, r.LisURL, "LisURL should be set")
require.Equal(t, "https://:0", r.LisURL.String(), "LisURL should contain bind address") require.Equal(t, "https://:0", r.LisURL.String(), "LisURL should contain bind address")
}) })

View File

@@ -75,7 +75,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
r.task.Finish(err) r.task.Finish(err)
return err return err
} }
if err := ep.AddRoute(r); err != nil { if err := ep.StartAddRoute(r); err != nil {
r.task.Finish(err) r.task.Finish(err)
return gperr.Wrap(err) return gperr.Wrap(err)
} }