diff --git a/internal/entrypoint/routes.go b/internal/entrypoint/routes.go index 475edc7c..b7834b4d 100644 --- a/internal/entrypoint/routes.go +++ b/internal/entrypoint/routes.go @@ -5,7 +5,6 @@ import ( "net" "strconv" - "github.com/rs/zerolog/log" "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/types" ) @@ -45,22 +44,18 @@ func (ep *Entrypoint) GetRoute(alias string) (types.Route, bool) { return nil, false } -func (ep *Entrypoint) AddRoute(r types.Route) { +func (ep *Entrypoint) AddRoute(r types.Route) error { if r.ShouldExclude() { ep.excludedRoutes.Add(r) r.Task().OnCancel("remove_route", func() { ep.excludedRoutes.Del(r) }) - return + return nil } switch r := r.(type) { case types.HTTPRoute: if err := ep.AddHTTPRoute(r); err != nil { - log.Error(). - Err(err). - Str("route", r.Key()). - Str("listen_url", r.ListenURL().String()). - Msg("failed to add HTTP route") + return err } ep.shortLinkMatcher.AddRoute(r.Key()) r.Task().OnCancel("remove_route", func() { @@ -68,11 +63,18 @@ func (ep *Entrypoint) AddRoute(r types.Route) { ep.shortLinkMatcher.DelRoute(r.Key()) }) case types.StreamRoute: + err := r.ListenAndServe(r.Task().Context(), nil, nil) + if err != nil { + return err + } ep.streamRoutes.Add(r) + r.Task().OnCancel("remove_route", func() { + r.Stream().Close() ep.streamRoutes.Del(r) }) } + return nil } // AddHTTPRoute adds a HTTP route to the entrypoint's server. diff --git a/internal/entrypoint/types/entrypoint.go b/internal/entrypoint/types/entrypoint.go index f2ed5ca3..a58a66e0 100644 --- a/internal/entrypoint/types/entrypoint.go +++ b/internal/entrypoint/types/entrypoint.go @@ -10,7 +10,7 @@ type Entrypoint interface { DisablePoolsLog(v bool) GetRoute(alias string) (types.Route, bool) - AddRoute(r types.Route) + AddRoute(r types.Route) error IterRoutes(yield func(r types.Route) bool) NumRoutes() int RoutesByProvider() map[string][]types.Route diff --git a/internal/idlewatcher/handle_stream.go b/internal/idlewatcher/handle_stream.go index f24648ba..ec4f5df5 100644 --- a/internal/idlewatcher/handle_stream.go +++ b/internal/idlewatcher/handle_stream.go @@ -10,8 +10,8 @@ import ( var _ nettypes.Stream = (*Watcher)(nil) // ListenAndServe implements nettypes.Stream. -func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) { - w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck +func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) error { + return w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck return w.preDial(ctx, predial) }, func(ctx context.Context) error { return w.onRead(ctx, onRead) diff --git a/internal/net/types/stream.go b/internal/net/types/stream.go index 25ba81ef..f555b0b5 100644 --- a/internal/net/types/stream.go +++ b/internal/net/types/stream.go @@ -6,7 +6,7 @@ import ( ) type Stream interface { - ListenAndServe(ctx context.Context, preDial, onRead HookFunc) + ListenAndServe(ctx context.Context, preDial, onRead HookFunc) error LocalAddr() net.Addr Close() error } diff --git a/internal/route/fileserver.go b/internal/route/fileserver.go index d80b5918..b268b9cc 100644 --- a/internal/route/fileserver.go +++ b/internal/route/fileserver.go @@ -132,7 +132,11 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error { s.task.Finish(err) return err } - ep.AddRoute(s) + + if err := ep.AddRoute(s); err != nil { + s.task.Finish(err) + return gperr.Wrap(err) + } return nil } diff --git a/internal/route/reverse_proxy.go b/internal/route/reverse_proxy.go index 44c3c9a9..b816b67f 100755 --- a/internal/route/reverse_proxy.go +++ b/internal/route/reverse_proxy.go @@ -171,9 +171,15 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error { } if r.UseLoadBalance() { - r.addToLoadBalancer(parent, ep) + if err := r.addToLoadBalancer(parent, ep); err != nil { + r.task.Finish(err) + return gperr.Wrap(err) + } } else { - ep.AddRoute(r) + if err := ep.AddRoute(r); err != nil { + r.task.Finish(err) + return gperr.Wrap(err) + } } return nil } @@ -185,7 +191,7 @@ func (r *ReveseProxyRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) { var lbLock sync.Mutex -func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.Entrypoint) { +func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.Entrypoint) error { var lb *loadbalancer.LoadBalancer cfg := r.LoadBalance lbLock.Lock() @@ -217,7 +223,10 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E handler: lb, } linked.SetHealthMonitor(lb) - ep.AddRoute(linked) + if err := ep.AddRoute(linked); err != nil { + lb.Finish(err) + return err + } lbLock.Unlock() } r.loadBalancer = lb @@ -227,4 +236,5 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent, ep entrypoint.E r.task.OnCancel("lb_remove_server", func() { lb.RemoveServer(server) }) + return nil } diff --git a/internal/route/stream.go b/internal/route/stream.go index ea26ed5a..62b2ce34 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -69,26 +69,16 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { } } - r.ListenAndServe(r.task.Context(), nil, nil) - r.l = log.With(). - Str("type", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme). - Str("name", r.Name()). - Stringer("rurl", r.ProxyURL). - Stringer("laddr", r.LocalAddr()).Logger() - r.l.Info().Msg("stream started") - - r.task.OnCancel("close_stream", func() { - r.stream.Close() - r.l.Info().Msg("stream closed") - }) - ep := entrypoint.FromCtx(parent.Context()) if ep == nil { err := gperr.New("entrypoint not initialized") r.task.Finish(err) return err } - ep.AddRoute(r) + if err := ep.AddRoute(r); err != nil { + r.task.Finish(err) + return gperr.Wrap(err) + } return nil } diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go index a12a3b67..c65db5f4 100644 --- a/internal/route/stream/tcp_tcp.go +++ b/internal/route/stream/tcp_tcp.go @@ -43,16 +43,13 @@ func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil } -func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { +func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error { var err error s.listener, err = net.ListenTCP(s.network, s.laddr) if err != nil { - logErr(s, err, "failed to listen") - return + return err } - // TODO: add to entrypoint - if ep := entrypoint.FromCtx(ctx); ep != nil { if proxyProto := ep.SupportProxyProtocol(); proxyProto { s.listener = &proxyproto.Listener{Listener: s.listener} @@ -67,6 +64,7 @@ func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead netty s.preDial = preDial s.onRead = onRead go s.listen(ctx) + return nil } func (s *TCPTCPStream) Close() error { diff --git a/internal/route/stream/udp_udp.go b/internal/route/stream/udp_udp.go index 00aa2a61..6e2b67bc 100644 --- a/internal/route/stream/udp_udp.go +++ b/internal/route/stream/udp_udp.go @@ -75,22 +75,21 @@ func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age }, nil } -func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { +func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error { l, err := net.ListenUDP(s.network, s.laddr) if err != nil { - logErr(s, err, "failed to listen") - return + return err } s.listener = l if acl := acl.FromCtx(ctx); acl != nil { log.Debug().Str("listener", s.listener.LocalAddr().String()).Msg("wrapping listener with ACL") s.listener = acl.WrapUDP(s.listener) } - // TODO: add to entrypoint s.preDial = preDial s.onRead = onRead go s.listen(ctx) go s.cleanUp(ctx) + return nil } func (s *UDPUDPStream) Close() error {