fix: add nil guard before entrypoint retrieval; move config from types/

This commit is contained in:
yusing
2026-02-06 12:01:09 +08:00
parent e383cd247a
commit a6fed3f221
17 changed files with 90 additions and 37 deletions

View File

@@ -74,6 +74,9 @@ func FavIcon(c *gin.Context) {
func GetFavIconFromAlias(ctx context.Context, alias string, variant icons.Variant) (iconfetch.Result, error) { func GetFavIconFromAlias(ctx context.Context, alias string, variant icons.Variant) (iconfetch.Result, error) {
// try with route.Icon // try with route.Icon
ep := entrypoint.FromCtx(ctx) ep := entrypoint.FromCtx(ctx)
if ep == nil { // impossible, but just in case
return iconfetch.FetchResultWithErrorf(http.StatusInternalServerError, "entrypoint not initialized")
}
r, ok := ep.HTTPRoutes().Get(alias) r, ok := ep.HTTPRoutes().Get(alias)
if !ok { if !ok {
return iconfetch.FetchResultWithErrorf(http.StatusNotFound, "route not found") return iconfetch.FetchResultWithErrorf(http.StatusNotFound, "route not found")

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
apitypes "github.com/yusing/goutils/apitypes"
"github.com/yusing/goutils/http/httpheaders" "github.com/yusing/goutils/http/httpheaders"
"github.com/yusing/goutils/http/websocket" "github.com/yusing/goutils/http/websocket"
@@ -25,6 +26,10 @@ import (
// @Router /health [get] // @Router /health [get]
func Health(c *gin.Context) { func Health(c *gin.Context) {
ep := entrypoint.FromCtx(c.Request.Context()) ep := entrypoint.FromCtx(c.Request.Context())
if ep == nil { // impossible, but just in case
c.JSON(http.StatusInternalServerError, apitypes.Error("entrypoint not initialized"))
return
}
if httpheaders.IsWebsocket(c.Request.Header) { if httpheaders.IsWebsocket(c.Request.Header) {
websocket.PeriodicWrite(c, 1*time.Second, func() (any, error) { websocket.PeriodicWrite(c, 1*time.Second, func() (any, error) {
return ep.GetHealthInfoSimple(), nil return ep.GetHealthInfoSimple(), nil

View File

@@ -8,6 +8,7 @@ import (
"github.com/yusing/godoxy/internal/homepage" "github.com/yusing/godoxy/internal/homepage"
_ "github.com/yusing/goutils/apitypes" _ "github.com/yusing/goutils/apitypes"
apitypes "github.com/yusing/goutils/apitypes"
) )
// @x-id "categories" // @x-id "categories"
@@ -22,6 +23,10 @@ import (
// @Router /homepage/categories [get] // @Router /homepage/categories [get]
func Categories(c *gin.Context) { func Categories(c *gin.Context) {
ep := entrypoint.FromCtx(c.Request.Context()) ep := entrypoint.FromCtx(c.Request.Context())
if ep == nil { // impossible, but just in case
c.JSON(http.StatusInternalServerError, apitypes.Error("entrypoint not initialized"))
return
}
c.JSON(http.StatusOK, HomepageCategories(ep)) c.JSON(http.StatusOK, HomepageCategories(ep))
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/yusing/godoxy/internal/route" "github.com/yusing/godoxy/internal/route"
_ "github.com/yusing/goutils/apitypes" _ "github.com/yusing/goutils/apitypes"
apitypes "github.com/yusing/goutils/apitypes"
) )
type RoutesByProvider map[string][]route.Route type RoutesByProvider map[string][]route.Route
@@ -25,5 +26,9 @@ type RoutesByProvider map[string][]route.Route
// @Router /route/by_provider [get] // @Router /route/by_provider [get]
func ByProvider(c *gin.Context) { func ByProvider(c *gin.Context) {
ep := entrypoint.FromCtx(c.Request.Context()) ep := entrypoint.FromCtx(c.Request.Context())
if ep == nil { // impossible, but just in case
c.JSON(http.StatusInternalServerError, apitypes.Error("entrypoint not initialized"))
return
}
c.JSON(http.StatusOK, ep.RoutesByProvider()) c.JSON(http.StatusOK, ep.RoutesByProvider())
} }

View File

@@ -33,6 +33,11 @@ func Route(c *gin.Context) {
} }
ep := entrypoint.FromCtx(c.Request.Context()) ep := entrypoint.FromCtx(c.Request.Context())
if ep == nil { // impossible, but just in case
c.JSON(http.StatusInternalServerError, apitypes.Error("entrypoint not initialized"))
return
}
route, ok := ep.GetRoute(request.Which) route, ok := ep.GetRoute(request.Which)
if ok { if ok {
c.JSON(http.StatusOK, route) c.JSON(http.StatusOK, route)

View File

@@ -8,7 +8,7 @@ import (
"github.com/yusing/godoxy/agent/pkg/agent" "github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/internal/acl" "github.com/yusing/godoxy/internal/acl"
"github.com/yusing/godoxy/internal/autocert" "github.com/yusing/godoxy/internal/autocert"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types" entrypoint "github.com/yusing/godoxy/internal/entrypoint"
homepage "github.com/yusing/godoxy/internal/homepage/types" homepage "github.com/yusing/godoxy/internal/homepage/types"
maxmind "github.com/yusing/godoxy/internal/maxmind/types" maxmind "github.com/yusing/godoxy/internal/maxmind/types"
"github.com/yusing/godoxy/internal/notif" "github.com/yusing/godoxy/internal/notif"

View File

@@ -27,7 +27,7 @@ type findRouteFunc func(HTTPRoutes, string) types.HTTPRoute
type Entrypoint struct { type Entrypoint struct {
task *task.Task task *task.Task
cfg *entrypoint.Config cfg *Config
middleware *middleware.Middleware middleware *middleware.Middleware
notFoundHandler http.Handler notFoundHandler http.Handler
@@ -48,9 +48,9 @@ type Entrypoint struct {
var _ entrypoint.Entrypoint = &Entrypoint{} var _ entrypoint.Entrypoint = &Entrypoint{}
var emptyCfg entrypoint.Config var emptyCfg Config
func NewEntrypoint(parent task.Parent, cfg *entrypoint.Config) *Entrypoint { func NewEntrypoint(parent task.Parent, cfg *Config) *Entrypoint {
if cfg == nil { if cfg == nil {
cfg = &emptyCfg cfg = &emptyCfg
} }
@@ -91,12 +91,23 @@ func NewEntrypoint(parent task.Parent, cfg *entrypoint.Config) *Entrypoint {
return ep return ep
} }
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher { func (ep *Entrypoint) SupportProxyProtocol() bool {
return ep.shortLinkMatcher return ep.cfg.SupportProxyProtocol
} }
func (ep *Entrypoint) Config() *entrypoint.Config { func (ep *Entrypoint) DisablePoolsLog(v bool) {
return ep.cfg ep.httpPoolDisableLog.Store(v)
// apply to all running http servers
for _, srv := range ep.servers.Range {
srv.routes.DisableLog(v)
}
// apply to other pools
ep.streamRoutes.DisableLog(v)
ep.excludedRoutes.DisableLog(v)
}
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
return ep.shortLinkMatcher
} }
func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] { func (ep *Entrypoint) HTTPRoutes() entrypoint.PoolLike[types.HTTPRoute] {
@@ -111,19 +122,12 @@ func (ep *Entrypoint) ExcludedRoutes() entrypoint.RWPoolLike[types.Route] {
return ep.excludedRoutes return ep.excludedRoutes
} }
func (ep *Entrypoint) GetServer(addr string) (*httpServer, bool) { func (ep *Entrypoint) GetServer(addr string) (http.Handler, bool) {
return ep.servers.Load(addr) return ep.servers.Load(addr)
} }
func (ep *Entrypoint) DisablePoolsLog(v bool) { func (ep *Entrypoint) PrintServers() {
ep.httpPoolDisableLog.Store(v) log.Info().Msgf("servers: %v", xsync.ToPlainMap(ep.servers))
// apply to all running http servers
for _, srv := range ep.servers.Range {
srv.routes.DisableLog(v)
}
// apply to other pools
ep.streamRoutes.DisableLog(v)
ep.excludedRoutes.DisableLog(v)
} }
func (ep *Entrypoint) SetFindRouteDomains(domains []string) { func (ep *Entrypoint) SetFindRouteDomains(domains []string) {

View File

@@ -11,6 +11,7 @@ import (
"testing" "testing"
. "github.com/yusing/godoxy/internal/entrypoint" . "github.com/yusing/godoxy/internal/entrypoint"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
"github.com/yusing/godoxy/internal/route" "github.com/yusing/godoxy/internal/route"
routeTypes "github.com/yusing/godoxy/internal/route/types" routeTypes "github.com/yusing/godoxy/internal/route/types"
"github.com/yusing/godoxy/internal/types" "github.com/yusing/godoxy/internal/types"
@@ -47,13 +48,15 @@ func (t noopTransport) RoundTrip(req *http.Request) (*http.Response, error) {
} }
func BenchmarkEntrypointReal(b *testing.B) { func BenchmarkEntrypointReal(b *testing.B) {
var ep Entrypoint task := task.NewTestTask(b)
ep := NewEntrypoint(task, nil)
req := http.Request{ req := http.Request{
Method: "GET", Method: "GET",
URL: &url.URL{Path: "/", RawPath: "/"}, URL: &url.URL{Path: "/", RawPath: "/"},
Host: "test.domain.tld", Host: "test.domain.tld",
} }
ep.SetFindRouteDomains([]string{}) ep.SetFindRouteDomains([]string{})
entrypoint.SetCtx(task, ep)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "1") w.Header().Set("Content-Length", "1")
@@ -89,7 +92,7 @@ func BenchmarkEntrypointReal(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
err = r.Start(task.NewTestTask(b)) err = r.Start(task)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -114,13 +117,15 @@ func BenchmarkEntrypointReal(b *testing.B) {
} }
func BenchmarkEntrypoint(b *testing.B) { func BenchmarkEntrypoint(b *testing.B) {
var ep Entrypoint task := task.NewTestTask(b)
ep := NewEntrypoint(task, nil)
req := http.Request{ req := http.Request{
Method: "GET", Method: "GET",
URL: &url.URL{Path: "/", RawPath: "/"}, URL: &url.URL{Path: "/", RawPath: "/"},
Host: "test.domain.tld", Host: "test.domain.tld",
} }
ep.SetFindRouteDomains([]string{}) ep.SetFindRouteDomains([]string{})
entrypoint.SetCtx(task, ep)
r := &route.Route{ r := &route.Route{
Alias: "test", Alias: "test",
@@ -139,7 +144,7 @@ func BenchmarkEntrypoint(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
err = r.Start(task.RootTask("test", false)) err = r.Start(task)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -74,11 +74,14 @@ func (ep *Entrypoint) AddRoute(r types.Route) {
func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error { func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
if port := route.ListenURL().Port(); port == "" || port == "0" { if port := route.ListenURL().Port(); port == "" || port == "0" {
host := route.ListenURL().Hostname() host := route.ListenURL().Hostname()
var httpAddr, httpsAddr string
if host == "" { if host == "" {
host = common.ProxyHTTPHost httpAddr = common.ProxyHTTPAddr
httpsAddr = common.ProxyHTTPSAddr
} else {
httpAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPPort))
httpsAddr = net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPSPort))
} }
httpAddr := net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPPort))
httpsAddr := net.JoinHostPort(host, strconv.Itoa(common.ProxyHTTPSPort))
return errors.Join(ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP), ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS)) return errors.Join(ep.addHTTPRoute(route, httpAddr, HTTPProtoHTTP), ep.addHTTPRoute(route, httpsAddr, HTTPProtoHTTPS))
} }

View File

@@ -165,7 +165,7 @@ func TestEntrypoint_ShortLinkDispatch(t *testing.T) {
ep.ShortLinkMatcher().AddRoute("app") ep.ShortLinkMatcher().AddRoute("app")
server := NewHTTPServer(ep) server := NewHTTPServer(ep)
err := server.Listen("localhost:8080", HTTPProtoHTTP) err := server.Listen("localhost:0", HTTPProtoHTTP)
require.NoError(t, err) require.NoError(t, err)
t.Run("shortlink host", func(t *testing.T) { t.Run("shortlink host", func(t *testing.T) {

View File

@@ -5,7 +5,7 @@ import (
) )
type Entrypoint interface { type Entrypoint interface {
Config() *Config SupportProxyProtocol() bool
DisablePoolsLog(v bool) DisablePoolsLog(v bool)

View File

@@ -126,7 +126,11 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
} }
} }
entrypoint.FromCtx(parent.Context()).AddRoute(s) ep := entrypoint.FromCtx(parent.Context())
if ep == nil {
return gperr.New("entrypoint not initialized")
}
ep.AddRoute(s)
return nil return nil
} }

View File

@@ -163,10 +163,15 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
} }
} }
ep := entrypoint.FromCtx(parent.Context())
if ep == nil {
return gperr.New("entrypoint not initialized")
}
if r.UseLoadBalance() { if r.UseLoadBalance() {
r.addToLoadBalancer(parent) r.addToLoadBalancer(parent, ep)
} else { } else {
entrypoint.FromCtx(parent.Context()).AddRoute(r) ep.AddRoute(r)
} }
return nil return nil
} }
@@ -178,12 +183,11 @@ func (r *ReveseProxyRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var lbLock sync.Mutex var lbLock sync.Mutex
func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) { func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.Entrypoint) {
var lb *loadbalancer.LoadBalancer var lb *loadbalancer.LoadBalancer
cfg := r.LoadBalance cfg := r.LoadBalance
lbLock.Lock() lbLock.Lock()
ep := entrypoint.FromCtx(r.task.Context())
l, ok := ep.HTTPRoutes().Get(cfg.Link) l, ok := ep.HTTPRoutes().Get(cfg.Link)
var linked *ReveseProxyRoute var linked *ReveseProxyRoute
if ok { if ok {

View File

@@ -46,7 +46,7 @@ type (
Host string `json:"host,omitempty"` Host string `json:"host,omitempty"`
Port route.Port `json:"port"` Port route.Port `json:"port"`
Bind string `json:"bind,omitempty" validate:"omitempty,dive,ip_addr" extensions:"x-nullable"` Bind string `json:"bind,omitempty" validate:"omitempty,ip_addr" extensions:"x-nullable"`
Root string `json:"root,omitempty"` Root string `json:"root,omitempty"`
SPA bool `json:"spa,omitempty"` // Single-page app mode: serves index for non-existent paths SPA bool `json:"spa,omitempty"` // Single-page app mode: serves index for non-existent paths
@@ -199,7 +199,11 @@ func (r *Route) validate() gperr.Error {
if (r.Proxmox == nil || r.Proxmox.Node == "" || r.Proxmox.VMID == nil) && r.Container == nil { if (r.Proxmox == nil || r.Proxmox.Node == "" || r.Proxmox.VMID == nil) && r.Container == nil {
wasNotNil := r.Proxmox != nil wasNotNil := r.Proxmox != nil
proxmoxProviders := config.WorkingState.Load().Value().Providers.Proxmox workingState := config.WorkingState.Load()
var proxmoxProviders []*proxmox.Config
if workingState != nil { // nil in tests
proxmoxProviders = workingState.Value().Providers.Proxmox
}
if len(proxmoxProviders) > 0 { if len(proxmoxProviders) > 0 {
// it's fine if ip is nil // it's fine if ip is nil
hostname := r.Host hostname := r.Host

View File

@@ -82,7 +82,11 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
r.l.Info().Msg("stream closed") r.l.Info().Msg("stream closed")
}) })
entrypoint.FromCtx(parent.Context()).AddRoute(r) ep := entrypoint.FromCtx(parent.Context())
if ep == nil {
return gperr.New("entrypoint not initialized")
}
ep.AddRoute(r)
return nil return nil
} }

View File

@@ -58,8 +58,10 @@ func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead netty
s.listener = acl.WrapTCP(s.listener) s.listener = acl.WrapTCP(s.listener)
} }
if proxyProto := entrypoint.FromCtx(ctx).Config().SupportProxyProtocol; proxyProto { if ep := entrypoint.FromCtx(ctx); ep != nil {
s.listener = &proxyproto.Listener{Listener: s.listener} if proxyProto := entrypoint.FromCtx(ctx).SupportProxyProtocol(); proxyProto {
s.listener = &proxyproto.Listener{Listener: s.listener}
}
} }
s.preDial = preDial s.preDial = preDial