refactor(entrypoint): propagate errors from route registration and stream serving

Changed AddRoute and ListenAndServe methods to return errors instead of logging them internally and continuing. This allows callers to properly handle and propagate errors, improving error visibility and enabling better error management across the codebase. Updated all callers in fileserver, reverse_proxy, stream routes to handle these errors appropriately.
This commit is contained in:
yusing
2026-02-07 21:06:22 +08:00
parent 65cfa90c95
commit 1705f77060
9 changed files with 43 additions and 40 deletions

View File

@@ -5,7 +5,6 @@ import (
"net" "net"
"strconv" "strconv"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/types" "github.com/yusing/godoxy/internal/types"
) )
@@ -45,22 +44,18 @@ func (ep *Entrypoint) GetRoute(alias string) (types.Route, bool) {
return nil, false return nil, false
} }
func (ep *Entrypoint) AddRoute(r types.Route) { func (ep *Entrypoint) AddRoute(r types.Route) error {
if r.ShouldExclude() { if r.ShouldExclude() {
ep.excludedRoutes.Add(r) ep.excludedRoutes.Add(r)
r.Task().OnCancel("remove_route", func() { r.Task().OnCancel("remove_route", func() {
ep.excludedRoutes.Del(r) ep.excludedRoutes.Del(r)
}) })
return return nil
} }
switch r := r.(type) { switch r := r.(type) {
case types.HTTPRoute: case types.HTTPRoute:
if err := ep.AddHTTPRoute(r); err != nil { if err := ep.AddHTTPRoute(r); err != nil {
log.Error(). return err
Err(err).
Str("route", r.Key()).
Str("listen_url", r.ListenURL().String()).
Msg("failed to add HTTP route")
} }
ep.shortLinkMatcher.AddRoute(r.Key()) ep.shortLinkMatcher.AddRoute(r.Key())
r.Task().OnCancel("remove_route", func() { r.Task().OnCancel("remove_route", func() {
@@ -68,11 +63,18 @@ func (ep *Entrypoint) AddRoute(r types.Route) {
ep.shortLinkMatcher.DelRoute(r.Key()) ep.shortLinkMatcher.DelRoute(r.Key())
}) })
case types.StreamRoute: case types.StreamRoute:
err := r.ListenAndServe(r.Task().Context(), nil, nil)
if err != nil {
return err
}
ep.streamRoutes.Add(r) ep.streamRoutes.Add(r)
r.Task().OnCancel("remove_route", func() { r.Task().OnCancel("remove_route", func() {
r.Stream().Close()
ep.streamRoutes.Del(r) ep.streamRoutes.Del(r)
}) })
} }
return nil
} }
// AddHTTPRoute adds a HTTP route to the entrypoint's server. // AddHTTPRoute adds a HTTP route to the entrypoint's server.

View File

@@ -10,7 +10,7 @@ type Entrypoint interface {
DisablePoolsLog(v bool) DisablePoolsLog(v bool)
GetRoute(alias string) (types.Route, bool) GetRoute(alias string) (types.Route, bool)
AddRoute(r types.Route) AddRoute(r types.Route) error
IterRoutes(yield func(r types.Route) bool) IterRoutes(yield func(r types.Route) bool)
NumRoutes() int NumRoutes() int
RoutesByProvider() map[string][]types.Route RoutesByProvider() map[string][]types.Route

View File

@@ -10,8 +10,8 @@ import (
var _ nettypes.Stream = (*Watcher)(nil) var _ nettypes.Stream = (*Watcher)(nil)
// ListenAndServe implements nettypes.Stream. // ListenAndServe implements nettypes.Stream.
func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) { func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) error {
w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck return w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck
return w.preDial(ctx, predial) return w.preDial(ctx, predial)
}, func(ctx context.Context) error { }, func(ctx context.Context) error {
return w.onRead(ctx, onRead) return w.onRead(ctx, onRead)

View File

@@ -6,7 +6,7 @@ import (
) )
type Stream interface { type Stream interface {
ListenAndServe(ctx context.Context, preDial, onRead HookFunc) ListenAndServe(ctx context.Context, preDial, onRead HookFunc) error
LocalAddr() net.Addr LocalAddr() net.Addr
Close() error Close() error
} }

View File

@@ -132,7 +132,11 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
s.task.Finish(err) s.task.Finish(err)
return err return err
} }
ep.AddRoute(s)
if err := ep.AddRoute(s); err != nil {
s.task.Finish(err)
return gperr.Wrap(err)
}
return nil return nil
} }

View File

@@ -171,9 +171,15 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
} }
if r.UseLoadBalance() { if r.UseLoadBalance() {
r.addToLoadBalancer(parent, ep) if err := r.addToLoadBalancer(parent, ep); err != nil {
r.task.Finish(err)
return gperr.Wrap(err)
}
} else { } else {
ep.AddRoute(r) if err := ep.AddRoute(r); err != nil {
r.task.Finish(err)
return gperr.Wrap(err)
}
} }
return nil return nil
} }
@@ -185,7 +191,7 @@ func (r *ReveseProxyRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var lbLock sync.Mutex var lbLock sync.Mutex
func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.Entrypoint) { func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.Entrypoint) error {
var lb *loadbalancer.LoadBalancer var lb *loadbalancer.LoadBalancer
cfg := r.LoadBalance cfg := r.LoadBalance
lbLock.Lock() lbLock.Lock()
@@ -217,7 +223,10 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E
handler: lb, handler: lb,
} }
linked.SetHealthMonitor(lb) linked.SetHealthMonitor(lb)
ep.AddRoute(linked) if err := ep.AddRoute(linked); err != nil {
lb.Finish(err)
return err
}
lbLock.Unlock() lbLock.Unlock()
} }
r.loadBalancer = lb r.loadBalancer = lb
@@ -227,4 +236,5 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E
r.task.OnCancel("lb_remove_server", func() { r.task.OnCancel("lb_remove_server", func() {
lb.RemoveServer(server) lb.RemoveServer(server)
}) })
return nil
} }

View File

@@ -69,26 +69,16 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
} }
} }
r.ListenAndServe(r.task.Context(), nil, nil)
r.l = log.With().
Str("type", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme).
Str("name", r.Name()).
Stringer("rurl", r.ProxyURL).
Stringer("laddr", r.LocalAddr()).Logger()
r.l.Info().Msg("stream started")
r.task.OnCancel("close_stream", func() {
r.stream.Close()
r.l.Info().Msg("stream closed")
})
ep := entrypoint.FromCtx(parent.Context()) ep := entrypoint.FromCtx(parent.Context())
if ep == nil { if ep == nil {
err := gperr.New("entrypoint not initialized") err := gperr.New("entrypoint not initialized")
r.task.Finish(err) r.task.Finish(err)
return err return err
} }
ep.AddRoute(r) if err := ep.AddRoute(r); err != nil {
r.task.Finish(err)
return gperr.Wrap(err)
}
return nil return nil
} }

View File

@@ -43,16 +43,13 @@ func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age
return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil
} }
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error {
var err error var err error
s.listener, err = net.ListenTCP(s.network, s.laddr) s.listener, err = net.ListenTCP(s.network, s.laddr)
if err != nil { if err != nil {
logErr(s, err, "failed to listen") return err
return
} }
// TODO: add to entrypoint
if ep := entrypoint.FromCtx(ctx); ep != nil { if ep := entrypoint.FromCtx(ctx); ep != nil {
if proxyProto := ep.SupportProxyProtocol(); proxyProto { if proxyProto := ep.SupportProxyProtocol(); proxyProto {
s.listener = &proxyproto.Listener{Listener: s.listener} s.listener = &proxyproto.Listener{Listener: s.listener}
@@ -67,6 +64,7 @@ func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead netty
s.preDial = preDial s.preDial = preDial
s.onRead = onRead s.onRead = onRead
go s.listen(ctx) go s.listen(ctx)
return nil
} }
func (s *TCPTCPStream) Close() error { func (s *TCPTCPStream) Close() error {

View File

@@ -75,22 +75,21 @@ func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age
}, nil }, nil
} }
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error {
l, err := net.ListenUDP(s.network, s.laddr) l, err := net.ListenUDP(s.network, s.laddr)
if err != nil { if err != nil {
logErr(s, err, "failed to listen") return err
return
} }
s.listener = l s.listener = l
if acl := acl.FromCtx(ctx); acl != nil { if acl := acl.FromCtx(ctx); acl != nil {
log.Debug().Str("listener", s.listener.LocalAddr().String()).Msg("wrapping listener with ACL") log.Debug().Str("listener", s.listener.LocalAddr().String()).Msg("wrapping listener with ACL")
s.listener = acl.WrapUDP(s.listener) s.listener = acl.WrapUDP(s.listener)
} }
// TODO: add to entrypoint
s.preDial = preDial s.preDial = preDial
s.onRead = onRead s.onRead = onRead
go s.listen(ctx) go s.listen(ctx)
go s.cleanUp(ctx) go s.cleanUp(ctx)
return nil
} }
func (s *UDPUDPStream) Close() error { func (s *UDPUDPStream) Close() error {