fix: add nil guard for entrypoint retrieval; rename AddRoute to StartAddRoute

This commit is contained in:
yusing
2026-02-07 21:24:24 +08:00
parent 2fcf141ae2
commit faa5b36553
7 changed files with 46 additions and 21 deletions

View File

@@ -45,7 +45,7 @@ func (ep *Entrypoint) GetRoute(alias string) (types.Route, bool) {
return nil, false
}
func (ep *Entrypoint) AddRoute(r types.Route) error {
func (ep *Entrypoint) StartAddRoute(r types.Route) error {
if r.ShouldExclude() {
ep.excludedRoutes.Add(r)
r.Task().OnCancel("remove_route", func() {
@@ -80,13 +80,9 @@ func (ep *Entrypoint) AddRoute(r types.Route) error {
return nil
}
// AddHTTPRoute adds a HTTP route to the entrypoint's server.
//
// If the server does not exist, it will be created, started and return any error.
func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
func getAddr(route types.HTTPRoute) (httpAddr, httpsAddr string) {
if port := route.ListenURL().Port(); port == "" || port == "0" {
host := route.ListenURL().Hostname()
var httpAddr, httpsAddr string
if host == "" {
httpAddr = common.ProxyHTTPAddr
httpsAddr = common.ProxyHTTPSAddr
@@ -94,10 +90,26 @@ func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
httpAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPPort))
httpsAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPSPort))
}
return errors.Join(ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP), ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS))
return httpAddr, httpsAddr
}
return ep.addHTTPRoute(route, route.ListenURL().Host, HTTPProtoHTTPS)
httpsAddr = route.ListenURL().Host
return
}
// AddHTTPRoute adds a HTTP route to the entrypoint's server.
//
// If the server does not exist, it will be created, started and return any error.
func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
httpAddr, httpsAddr := getAddr(route)
var httpErr, httpsErr error
if httpAddr != "" {
httpErr = ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP)
}
if httpsAddr != "" {
httpsErr = ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS)
}
return errors.Join(httpErr, httpsErr)
}
func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error {
@@ -117,10 +129,17 @@ func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTT
}
func (ep *Entrypoint) delHTTPRoute(route types.HTTPRoute) {
addr := route.ListenURL().Host
srv, _ := ep.servers.Load(addr)
if srv != nil {
srv.DelRoute(route)
httpAddr, httpsAddr := getAddr(route)
if httpAddr != "" {
srv, _ := ep.servers.Load(httpAddr)
if srv != nil {
srv.DelRoute(route)
}
}
if httpsAddr != "" {
srv, _ := ep.servers.Load(httpsAddr)
if srv != nil {
srv.DelRoute(route)
}
}
// TODO: close if no servers left
}

View File

@@ -10,7 +10,7 @@ type Entrypoint interface {
DisablePoolsLog(v bool)
GetRoute(alias string) (types.Route, bool)
AddRoute(r types.Route) error
StartAddRoute(r types.Route) error
IterRoutes(yield func(r types.Route) bool)
NumRoutes() int
RoutesByProvider() map[string][]types.Route

View File

@@ -2,6 +2,7 @@ package uptime
import (
"context"
"errors"
"net/url"
"slices"
"time"
@@ -41,8 +42,12 @@ type (
var Poller = period.NewPoller("uptime", getStatuses, aggregateStatuses)
func getStatuses(ctx context.Context, _ StatusByAlias) (StatusByAlias, error) {
ep := entrypoint.FromCtx(ctx)
if ep == nil {
return StatusByAlias{}, errors.New("entrypoint not found in context")
}
return StatusByAlias{
Map: entrypoint.FromCtx(ctx).GetHealthInfoWithoutDetail(),
Map: ep.GetHealthInfoWithoutDetail(),
Timestamp: time.Now().Unix(),
}, nil
}
@@ -129,6 +134,7 @@ func (rs RouteStatuses) aggregate(limit int, offset int) Aggregated {
status := types.StatusUnknown
if state := config.ActiveState.Load(); state != nil {
// FIXME: pass ctx to getRoute
r, ok := entrypoint.FromCtx(state.Context()).GetRoute(alias)
if ok {
mon := r.HealthMonitor()

View File

@@ -133,7 +133,7 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
return err
}
if err := ep.AddRoute(s); err != nil {
if err := ep.StartAddRoute(s); err != nil {
s.task.Finish(err)
return gperr.Wrap(err)
}

View File

@@ -176,7 +176,7 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
return gperr.Wrap(err)
}
} else {
if err := ep.AddRoute(r); err != nil {
if err := ep.StartAddRoute(r); err != nil {
r.task.Finish(err)
return gperr.Wrap(err)
}
@@ -195,6 +195,7 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E
var lb *loadbalancer.LoadBalancer
cfg := r.LoadBalance
lbLock.Lock()
defer lbLock.Unlock()
l, ok := ep.HTTPRoutes().Get(cfg.Link)
var linked *ReveseProxyRoute
@@ -223,11 +224,10 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E
handler: lb,
}
linked.SetHealthMonitor(lb)
if err := ep.AddRoute(linked); err != nil {
if err := ep.StartAddRoute(linked); err != nil {
lb.Finish(err)
return err
}
lbLock.Unlock()
}
r.loadBalancer = lb

View File

@@ -217,7 +217,7 @@ func TestRouteBindField(t *testing.T) {
Port: route.Port{Proxy: 80},
}
err := r.Validate()
require.NoError(t, err, "Validate should not return error for HTTP route with bind")
require.NoError(t, err, "Validate should not return error for HTTP route without bind")
require.NotNil(t, r.LisURL, "LisURL should be set")
require.Equal(t, "https://:0", r.LisURL.String(), "LisURL should contain bind address")
})

View File

@@ -75,7 +75,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
r.task.Finish(err)
return err
}
if err := ep.AddRoute(r); err != nil {
if err := ep.StartAddRoute(r); err != nil {
r.task.Finish(err)
return gperr.Wrap(err)
}