From b5328fe5e72be72b454cda61764b8f100cfe0ded Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 9 Jun 2025 22:20:26 +0800 Subject: [PATCH] feat(idlesleep): support idlesleep for stream routes, rewritten and fixed stream implementation --- go.mod | 2 +- internal/idlewatcher/common.go | 6 +- internal/idlewatcher/handle_http.go | 4 +- internal/idlewatcher/handle_stream.go | 81 ++++--- internal/idlewatcher/watcher.go | 2 +- internal/net/types/stream.go | 4 +- internal/route/route.go | 18 +- internal/route/routes/route.go | 1 + internal/route/stream.go | 100 ++++---- internal/route/stream/debug_debug.go | 12 + internal/route/stream/debug_prod.go | 7 + internal/route/stream/errors.go | 41 ++++ internal/route/stream/tcp_tcp.go | 162 +++++++++++++ internal/route/stream/udp_udp.go | 316 ++++++++++++++++++++++++++ internal/route/stream_impl.go | 129 ----------- internal/route/udp_forwarder.go | 204 ----------------- 16 files changed, 659 insertions(+), 430 deletions(-) create mode 100644 internal/route/stream/debug_debug.go create mode 100644 internal/route/stream/debug_prod.go create mode 100644 internal/route/stream/errors.go create mode 100644 internal/route/stream/tcp_tcp.go create mode 100644 internal/route/stream/udp_udp.go delete mode 100644 internal/route/stream_impl.go delete mode 100644 internal/route/udp_forwarder.go diff --git a/go.mod b/go.mod index 695f84b2..67a8dfed 100644 --- a/go.mod +++ b/go.mod @@ -222,7 +222,7 @@ require ( go.opentelemetry.io/otel v1.36.0 // indirect go.opentelemetry.io/otel/metric v1.36.0 // indirect go.opentelemetry.io/otel/trace v1.36.0 // indirect - go.uber.org/atomic v1.11.0 // indirect + go.uber.org/atomic v1.11.0 go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/mock v0.5.2 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/internal/idlewatcher/common.go b/internal/idlewatcher/common.go index ecc93bb9..5e13f8a1 100644 --- a/internal/idlewatcher/common.go +++ b/internal/idlewatcher/common.go @@ -1,8 +1,10 @@ package idlewatcher -import "context" +import ( + "context" +) -func (w *Watcher) cancelled(reqCtx context.Context) bool { +func (w *Watcher) canceled(reqCtx context.Context) bool { select { case <-reqCtx.Done(): w.l.Debug().AnErr("cause", context.Cause(reqCtx)).Msg("wake canceled") diff --git a/internal/idlewatcher/handle_http.go b/internal/idlewatcher/handle_http.go index 9c2d9f4b..e1b63266 100644 --- a/internal/idlewatcher/handle_http.go +++ b/internal/idlewatcher/handle_http.go @@ -92,7 +92,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN } ctx := r.Context() - if w.cancelled(ctx) { + if w.canceled(ctx) { w.redirectToStartEndpoint(rw, r) return false } @@ -107,7 +107,7 @@ func (w *Watcher) wakeFromHTTP(rw http.ResponseWriter, r *http.Request) (shouldN for { w.resetIdleTimer() - if w.cancelled(ctx) { + if w.canceled(ctx) { w.redirectToStartEndpoint(rw, r) return false } diff --git a/internal/idlewatcher/handle_stream.go b/internal/idlewatcher/handle_stream.go index 96810afa..4efa7ee9 100644 --- a/internal/idlewatcher/handle_stream.go +++ b/internal/idlewatcher/handle_stream.go @@ -5,45 +5,51 @@ import ( "net" "time" - gpnet "github.com/yusing/go-proxy/internal/net/types" + nettypes "github.com/yusing/go-proxy/internal/net/types" ) -// Setup implements types.Stream. -func (w *Watcher) Addr() net.Addr { - return w.stream.Addr() +var _ nettypes.Stream = (*Watcher)(nil) + +// ListenAndServe implements nettypes.Stream. +func (w *Watcher) ListenAndServe(ctx context.Context, predial, onRead nettypes.HookFunc) { + w.stream.ListenAndServe(ctx, func(ctx context.Context) error { //nolint:contextcheck + return w.preDial(ctx, predial) + }, func(ctx context.Context) error { + return w.onRead(ctx, onRead) + }) } -// Setup implements types.Stream. -func (w *Watcher) Setup() error { - return w.stream.Setup() -} - -// Accept implements types.Stream. -func (w *Watcher) Accept() (conn gpnet.StreamConn, err error) { - conn, err = w.stream.Accept() - if err != nil { - return - } - if wakeErr := w.wakeFromStream(); wakeErr != nil { - w.l.Err(wakeErr).Msg("error waking container") - } - return -} - -// Handle implements types.Stream. -func (w *Watcher) Handle(conn gpnet.StreamConn) error { - if err := w.wakeFromStream(); err != nil { - return err - } - return w.stream.Handle(conn) -} - -// Close implements types.Stream. +// Close implements nettypes.Stream. func (w *Watcher) Close() error { return w.stream.Close() } -func (w *Watcher) wakeFromStream() error { +// LocalAddr implements nettypes.Stream. +func (w *Watcher) LocalAddr() net.Addr { + return w.stream.LocalAddr() +} + +func (w *Watcher) preDial(ctx context.Context, predial nettypes.HookFunc) error { + if predial != nil { + if err := predial(ctx); err != nil { + return err + } + } + + return w.wakeFromStream(ctx) +} + +func (w *Watcher) onRead(ctx context.Context, onRead nettypes.HookFunc) error { + w.resetIdleTimer() + if onRead != nil { + if err := onRead(ctx); err != nil { + return err + } + } + return nil +} + +func (w *Watcher) wakeFromStream(ctx context.Context) error { w.resetIdleTimer() // pass through if container is already ready @@ -52,18 +58,27 @@ func (w *Watcher) wakeFromStream() error { } w.l.Debug().Msg("wake signal received") - err := w.Wake(context.Background()) + err := w.Wake(ctx) if err != nil { return err } for { + w.resetIdleTimer() + + if w.canceled(ctx) { + return nil + } + + if !w.waitStarted(ctx) { + return nil + } + ready, err := w.checkUpdateState() if err != nil { return err } if ready { - w.resetIdleTimer() w.l.Debug().Stringer("url", w.hc.URL()).Msg("container is ready, passing through") return nil } diff --git a/internal/idlewatcher/watcher.go b/internal/idlewatcher/watcher.go index 0f092129..7b42a622 100644 --- a/internal/idlewatcher/watcher.go +++ b/internal/idlewatcher/watcher.go @@ -261,7 +261,7 @@ func NewWatcher(parent task.Parent, r routes.Route, cfg *idlewatcher.Config) (*W case routes.ReverseProxyRoute: w.rp = r.ReverseProxy() case routes.StreamRoute: - w.stream = r + w.stream = r.Stream() default: w.provider.Close() return nil, w.newWatcherError(gperr.Errorf("unexpected route type: %T", r)) diff --git a/internal/net/types/stream.go b/internal/net/types/stream.go index 003d5549..25ba81ef 100644 --- a/internal/net/types/stream.go +++ b/internal/net/types/stream.go @@ -6,9 +6,9 @@ import ( ) type Stream interface { - ListenAndServe(ctx context.Context, preDial PreDialFunc) + ListenAndServe(ctx context.Context, preDial, onRead HookFunc) LocalAddr() net.Addr Close() error } -type PreDialFunc func(ctx context.Context) error +type HookFunc func(ctx context.Context) error diff --git a/internal/route/route.go b/internal/route/route.go index 743224d4..8f6c89ad 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -16,7 +16,7 @@ import ( "github.com/yusing/go-proxy/internal/homepage" idlewatcher "github.com/yusing/go-proxy/internal/idlewatcher/types" netutils "github.com/yusing/go-proxy/internal/net" - net "github.com/yusing/go-proxy/internal/net/types" + nettypes "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/proxmox" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -39,7 +39,7 @@ type ( Alias string `json:"alias"` Scheme route.Scheme `json:"scheme,omitempty"` Host string `json:"host,omitempty"` - Port route.Port `json:"port,omitempty"` + Port route.Port `json:"port"` Root string `json:"root,omitempty"` route.HTTPConfig @@ -64,8 +64,8 @@ type ( Provider string `json:"provider,omitempty"` // for backward compatibility // private fields - LisURL *net.URL `json:"lurl,omitempty"` - ProxyURL *net.URL `json:"purl,omitempty"` + LisURL *nettypes.URL `json:"lurl,omitempty"` + ProxyURL *nettypes.URL `json:"purl,omitempty"` Excluded *bool `json:"excluded"` @@ -195,19 +195,19 @@ func (r *Route) Validate() gperr.Error { switch r.Scheme { case route.SchemeFileServer: - r.ProxyURL = gperr.Collect(errs, net.ParseURL, "file://"+r.Root) + r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, "file://"+r.Root) r.Host = "" r.Port.Proxy = 0 case route.SchemeHTTP, route.SchemeHTTPS: if r.Port.Listening != 0 { errs.Addf("unexpected listening port for %s scheme", r.Scheme) } - r.ProxyURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) + r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) case route.SchemeTCP, route.SchemeUDP: if !r.ShouldExclude() { - r.LisURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://:%d", r.Scheme, r.Port.Listening)) + r.LisURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://:%d", r.Scheme, r.Port.Listening)) } - r.ProxyURL = gperr.Collect(errs, net.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) + r.ProxyURL = gperr.Collect(errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) } if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) { @@ -309,7 +309,7 @@ func (r *Route) ProviderName() string { return r.Provider } -func (r *Route) TargetURL() *net.URL { +func (r *Route) TargetURL() *nettypes.URL { return r.ProxyURL } diff --git a/internal/route/routes/route.go b/internal/route/routes/route.go index 30e1b7f4..babd8fa0 100644 --- a/internal/route/routes/route.go +++ b/internal/route/routes/route.go @@ -58,6 +58,7 @@ type ( StreamRoute interface { Route nettypes.Stream + Stream() nettypes.Stream } Provider interface { GetRoute(alias string) (r Route, ok bool) diff --git a/internal/route/stream.go b/internal/route/stream.go index 31e5c075..dddeae0e 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -2,7 +2,8 @@ package route import ( "context" - "errors" + "fmt" + "net" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -10,6 +11,7 @@ import ( "github.com/yusing/go-proxy/internal/idlewatcher" nettypes "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/route/routes" + "github.com/yusing/go-proxy/internal/route/stream" "github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/watcher/health/monitor" ) @@ -17,7 +19,7 @@ import ( // TODO: support stream load balance. type StreamRoute struct { *Route - nettypes.Stream `json:"-"` + stream nettypes.Stream l zerolog.Logger } @@ -33,10 +35,19 @@ func NewStreamRoute(base *Route) (routes.Route, gperr.Error) { }, nil } +func (r *StreamRoute) Stream() nettypes.Stream { + return r.stream +} + // Start implements task.TaskStarter. func (r *StreamRoute) Start(parent task.Parent) gperr.Error { + stream, err := r.initStream() + if err != nil { + return gperr.Wrap(err) + } + r.stream = stream + r.task = parent.Subtask("stream."+r.Name(), !r.ShouldExclude()) - r.Stream = NewStream(r) switch { case r.UseIdleWatcher(): @@ -45,20 +56,12 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { r.task.Finish(err) return gperr.Wrap(err, "idlewatcher error") } - r.Stream = waker + r.stream = waker r.HealthMon = waker case r.UseHealthCheck(): r.HealthMon = monitor.NewMonitor(r) } - if !r.ShouldExclude() { - if err := r.Setup(); err != nil { - r.task.Finish(err) - return gperr.Wrap(err) - } - r.l.Info().Int("port", r.Port.Listening).Msg("listening") - } - if r.HealthMon != nil { if err := r.HealthMon.Start(r.task); err != nil { gperr.LogWarn("health monitor error", err, &r.l) @@ -73,7 +76,14 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { return err } - go r.acceptConnections() + r.ListenAndServe(r.task.Context(), nil, nil) + r.l = r.l.With().Stringer("rurl", r.ProxyURL).Stringer("laddr", r.LocalAddr()).Logger() + r.l.Info().Msg("stream started") + + r.task.OnCancel("close_stream", func() { + r.stream.Close() + r.l.Info().Msg("stream closed") + }) routes.Stream.Add(r) r.task.OnCancel("remove_route_from_stream", func() { @@ -82,38 +92,34 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { return nil } -func (r *StreamRoute) acceptConnections() { - defer r.task.Finish("listener closed") - - go func() { - <-r.task.Context().Done() - r.Close() - }() - - for { - select { - case <-r.task.Context().Done(): - return - default: - conn, err := r.Accept() - if err != nil { - select { - case <-r.task.Context().Done(): - default: - gperr.LogError("accept connection error", err, &r.l) - } - r.task.Finish(err) - return - } - if conn == nil { - panic("connection is nil") - } - go func() { - err := r.Handle(conn) - if err != nil && !errors.Is(err, context.Canceled) { - gperr.LogError("handle connection error", err, &r.l) - } - }() - } - } +func (r *StreamRoute) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { + r.stream.ListenAndServe(ctx, preDial, onRead) +} + +func (r *StreamRoute) Close() error { + return r.stream.Close() +} + +func (r *StreamRoute) LocalAddr() net.Addr { + return r.stream.LocalAddr() +} + +func (r *StreamRoute) initStream() (nettypes.Stream, error) { + lurl, rurl := r.LisURL, r.ProxyURL + if lurl != nil && lurl.Scheme != rurl.Scheme { + return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme) + } + + laddr := ":0" + if lurl != nil { + laddr = lurl.Host + } + + switch rurl.Scheme { + case "tcp": + return stream.NewTCPTCPStream(laddr, rurl.Host) + case "udp": + return stream.NewUDPUDPStream(laddr, rurl.Host) + } + return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme) } diff --git a/internal/route/stream/debug_debug.go b/internal/route/stream/debug_debug.go new file mode 100644 index 00000000..26e814dc --- /dev/null +++ b/internal/route/stream/debug_debug.go @@ -0,0 +1,12 @@ +//go:build debug + +package stream + +import ( + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) { + log.Debug().Object("stream", stream).Msgf(format, v...) +} diff --git a/internal/route/stream/debug_prod.go b/internal/route/stream/debug_prod.go new file mode 100644 index 00000000..a8483ccb --- /dev/null +++ b/internal/route/stream/debug_prod.go @@ -0,0 +1,7 @@ +//go:build !debug + +package stream + +import "github.com/rs/zerolog" + +func logDebugf(stream zerolog.LogObjectMarshaler, format string, v ...any) {} diff --git a/internal/route/stream/errors.go b/internal/route/stream/errors.go new file mode 100644 index 00000000..79a5dbff --- /dev/null +++ b/internal/route/stream/errors.go @@ -0,0 +1,41 @@ +package stream + +import ( + "context" + "errors" + "io" + "syscall" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +func convertErr(err error) error { + if err == nil { + return nil + } + switch { + case errors.Is(err, context.Canceled), + errors.Is(err, io.ErrClosedPipe), + errors.Is(err, syscall.ECONNRESET): + return nil + default: + return err + } +} + +func logErr(stream zerolog.LogObjectMarshaler, err error, msg string) { + err = convertErr(err) + if err == nil { + return + } + log.Err(err).Object("stream", stream).Msg(msg) +} + +func logErrf(stream zerolog.LogObjectMarshaler, err error, format string, v ...any) { + err = convertErr(err) + if err == nil { + return + } + log.Err(err).Object("stream", stream).Msgf(format, v...) +} diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go new file mode 100644 index 00000000..6acdbc39 --- /dev/null +++ b/internal/route/stream/tcp_tcp.go @@ -0,0 +1,162 @@ +package stream + +import ( + "context" + "net" + + "github.com/rs/zerolog" + nettypes "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils" + "go.uber.org/atomic" +) + +type TCPTCPStream struct { + listener *net.TCPListener + laddr *net.TCPAddr + dst *net.TCPAddr + + preDial nettypes.HookFunc + onRead nettypes.HookFunc + + closed atomic.Bool +} + +func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { + dst, err := net.ResolveTCPAddr("tcp", dstAddr) + if err != nil { + return nil, err + } + laddr, err := net.ResolveTCPAddr("tcp", listenAddr) + if err != nil { + return nil, err + } + return &TCPTCPStream{laddr: laddr, dst: dst}, nil +} + +func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { + listener, err := net.ListenTCP("tcp", s.laddr) + if err != nil { + logErr(s, err, "failed to listen") + return + } + s.listener = listener + s.preDial = preDial + s.onRead = onRead + go s.listen(ctx) +} + +func (s *TCPTCPStream) Close() error { + if s.closed.Swap(true) || s.listener == nil { + return nil + } + return s.listener.Close() +} + +func (s *TCPTCPStream) LocalAddr() net.Addr { + if s.listener == nil { + return s.laddr + } + return s.listener.Addr() +} + +func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) { + e.Str("protocol", "tcp-tcp").Str("listen", s.listener.Addr().String()).Str("dst", s.dst.String()) +} + +func (s *TCPTCPStream) listen(ctx context.Context) { + for { + if s.closed.Load() { + return + } + + select { + case <-ctx.Done(): + return + default: + conn, err := s.listener.Accept() + if err != nil { + if s.closed.Load() { + return + } + logErr(s, err, "failed to accept connection") + continue + } + if s.onRead != nil { + if err := s.onRead(ctx); err != nil { + logErr(s, err, "failed to on read") + continue + } + } + go s.handle(ctx, conn) + } + } +} + +func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + if s.preDial != nil { + if err := s.preDial(ctx); err != nil { + if !s.closed.Load() { + logErr(s, err, "failed to pre-dial") + } + return + } + } + + if s.closed.Load() { + return + } + + dstConn, err := net.DialTCP("tcp", nil, s.dst) + if err != nil { + if !s.closed.Load() { + logErr(s, err, "failed to dial destination") + } + return + } + defer dstConn.Close() + + if s.closed.Load() { + return + } + + src := conn + dst := net.Conn(dstConn) + if s.onRead != nil { + src = &wrapperConn{ + Conn: conn, + ctx: ctx, + onRead: s.onRead, + } + dst = &wrapperConn{ + Conn: dstConn, + ctx: ctx, + onRead: s.onRead, + } + } + + pipe := utils.NewBidirectionalPipe(ctx, src, dst) + if err := pipe.Start(); err != nil && !s.closed.Load() { + logErr(s, err, "error in bidirectional pipe") + } +} + +type wrapperConn struct { + net.Conn + ctx context.Context + onRead nettypes.HookFunc +} + +func (w *wrapperConn) Read(b []byte) (n int, err error) { + n, err = w.Conn.Read(b) + if err != nil { + return + } + if w.onRead != nil { + if err = w.onRead(w.ctx); err != nil { + return + } + } + return +} diff --git a/internal/route/stream/udp_udp.go b/internal/route/stream/udp_udp.go new file mode 100644 index 00000000..8734c94c --- /dev/null +++ b/internal/route/stream/udp_udp.go @@ -0,0 +1,316 @@ +package stream + +import ( + "bytes" + "context" + "maps" + "net" + "sync" + "time" + + "github.com/rs/zerolog" + nettypes "github.com/yusing/go-proxy/internal/net/types" + "github.com/yusing/go-proxy/internal/utils/synk" + "go.uber.org/atomic" +) + +type UDPUDPStream struct { + name string + listener *net.UDPConn + + laddr *net.UDPAddr + dst *net.UDPAddr + + preDial nettypes.HookFunc + onRead nettypes.HookFunc + + cleanUpTicker *time.Ticker + + conns map[string]*udpUDPConn + closed atomic.Bool + mu sync.Mutex +} + +type udpUDPConn struct { + srcAddr *net.UDPAddr + dstConn *net.UDPConn + listener *net.UDPConn + lastUsed atomic.Time + closed atomic.Bool + mu sync.Mutex +} + +const ( + udpBufferSize = 16 * 1024 + udpIdleTimeout = 5 * time.Minute // Longer timeout for game sessions + udpCleanupInterval = 1 * time.Minute + udpReadTimeout = 30 * time.Second +) + +var bufPool = synk.NewBytesPool() + +func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { + dst, err := net.ResolveUDPAddr("udp", dstAddr) + if err != nil { + return nil, err + } + laddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return nil, err + } + return &UDPUDPStream{ + laddr: laddr, + dst: dst, + conns: make(map[string]*udpUDPConn), + }, nil +} + +func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { + listener, err := net.ListenUDP("udp", s.laddr) + if err != nil { + logErr(s, err, "failed to listen") + return + } + s.listener = listener + s.preDial = preDial + s.onRead = onRead + go s.listen(ctx) + go s.cleanUp(ctx) +} + +func (s *UDPUDPStream) Close() error { + if s.closed.Swap(true) || s.listener == nil { + return nil + } + + var wg sync.WaitGroup + s.mu.Lock() + for _, conn := range s.conns { + wg.Add(1) + go func(c *udpUDPConn) { + defer wg.Done() + c.Close() + }(conn) + } + clear(s.conns) + s.mu.Unlock() + + wg.Wait() + + return s.listener.Close() +} + +func (s *UDPUDPStream) LocalAddr() net.Addr { + if s.listener == nil { + return s.laddr + } + return s.listener.LocalAddr() +} + +func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) { + e.Str("protocol", "udp-udp").Str("name", s.name).Str("dst", s.dst.String()) +} + +func (s *UDPUDPStream) listen(ctx context.Context) { + buf := bufPool.GetSized(udpBufferSize) + defer bufPool.Put(buf) + + for { + select { + case <-ctx.Done(): + return + default: + n, srcAddr, err := s.listener.ReadFromUDP(buf) + if err != nil { + if s.closed.Load() { + return + } + logErr(s, err, "failed to read from listener") + continue + } + + logDebugf(s, "read %d bytes from %s", n, srcAddr) + + if s.onRead != nil { + if err := s.onRead(ctx); err != nil { + logErr(s, err, "failed to on read") + continue + } + } + + // Get or create connection, passing the initial data + go s.getOrCreateConnection(ctx, srcAddr, bytes.Clone(buf[:n])) + } + } +} + +func (s *UDPUDPStream) getOrCreateConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) { + key := srcAddr.String() + + s.mu.Lock() + if conn, ok := s.conns[key]; ok { + s.mu.Unlock() + // Forward packet for existing connection + go conn.forwardToDestination(initialData) + return + } + + defer s.mu.Unlock() + // Create new connection with initial data + conn, ok := s.createConnection(ctx, srcAddr, initialData) + if ok && !conn.closed.Load() { + s.conns[key] = conn + } +} + +func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) (*udpUDPConn, bool) { + // Apply pre-dial if configured + if s.preDial != nil { + if err := s.preDial(ctx); err != nil { + logErr(s, err, "failed to pre-dial") + return nil, false + } + } + + // Create UDP connection to destination + dstConn, err := net.DialUDP("udp", nil, s.dst) + if err != nil { + logErr(s, err, "failed to dial dst") + return nil, false + } + + conn := &udpUDPConn{ + srcAddr: srcAddr, + dstConn: dstConn, + listener: s.listener, + } + conn.lastUsed.Store(time.Now()) + + // Send initial data before starting response handler + if !conn.forwardToDestination(initialData) { + dstConn.Close() + return nil, false + } + + // Start response handler after initial data is sent + go conn.handleResponses(ctx) + + logDebugf(s, "created new connection from %s", srcAddr.String()) + return conn, true +} + +func (conn *udpUDPConn) MarshalZerologObject(e *zerolog.Event) { + e.Stringer("src", conn.srcAddr).Stringer("dst", conn.dstConn.RemoteAddr()) +} + +func (conn *udpUDPConn) handleResponses(ctx context.Context) { + buf := bufPool.GetSized(udpBufferSize) + defer bufPool.Put(buf) + + defer conn.Close() + + for { + if conn.closed.Load() { + return + } + + select { + case <-ctx.Done(): + return + default: + // Set a reasonable timeout for reads + _ = conn.dstConn.SetReadDeadline(time.Now().Add(udpReadTimeout)) + + n, err := conn.dstConn.Read(buf) + if err != nil { + if !conn.closed.Load() { + logErr(conn, err, "failed to read from dst") + } + return + } + + // Clear deadline after successful read + _ = conn.dstConn.SetReadDeadline(time.Time{}) + + // Forward response back to client using the listener + _, err = conn.listener.WriteToUDP(buf[:n], conn.srcAddr) + if err != nil { + if !conn.closed.Load() { + logErrf(conn, err, "failed to write %d bytes to client", n) + } + return + } + + conn.lastUsed.Store(time.Now()) + logDebugf(conn, "forwarded response to client, %d bytes", n) + } + } +} + +func (s *UDPUDPStream) cleanUp(ctx context.Context) { + s.cleanUpTicker = time.NewTicker(udpCleanupInterval) + defer s.cleanUpTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-s.cleanUpTicker.C: + s.mu.Lock() + conns := maps.Clone(s.conns) + s.mu.Unlock() + + removed := []string(nil) + for key, conn := range conns { + if conn.Expired() { + conn.Close() + removed = append(removed, key) + } + } + + s.mu.Lock() + for _, key := range removed { + logDebugf(s, "cleaning up expired connection: %s", key) + delete(s.conns, key) + } + s.mu.Unlock() + } + } +} + +func (conn *udpUDPConn) forwardToDestination(data []byte) bool { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.closed.Load() { + return false + } + + _, err := conn.dstConn.Write(data) + if err != nil { + logErrf(conn, err, "failed to write %d bytes to dst", len(data)) + return false + } + + conn.lastUsed.Store(time.Now()) + logDebugf(conn, "forwarded %d bytes to dst", len(data)) + return true +} + +func (conn *udpUDPConn) Expired() bool { + return time.Since(conn.lastUsed.Load()) > udpIdleTimeout +} + +func (conn *udpUDPConn) Close() { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.closed.Load() { + return + } + + conn.closed.Store(true) + + conn.dstConn.Close() + conn.dstConn = nil +} diff --git a/internal/route/stream_impl.go b/internal/route/stream_impl.go deleted file mode 100644 index 316f6dbd..00000000 --- a/internal/route/stream_impl.go +++ /dev/null @@ -1,129 +0,0 @@ -package route - -import ( - "errors" - "fmt" - "io" - "net" - "time" - - "github.com/yusing/go-proxy/internal/net/types" - U "github.com/yusing/go-proxy/internal/utils" -) - -type ( - Stream struct { - *StreamRoute - - listener types.StreamListener - targetAddr net.Addr - } -) - -const ( - streamFirstConnBufferSize = 128 - streamDialTimeout = 5 * time.Second -) - -func NewStream(base *StreamRoute) *Stream { - return &Stream{ - StreamRoute: base, - } -} - -func (stream *Stream) Addr() net.Addr { - if stream.listener == nil { - panic("listener is nil") - } - return stream.listener.Addr() -} - -func (stream *Stream) Setup() error { - var lcfg net.ListenConfig - var err error - - ctx := stream.task.Context() - - switch stream.Scheme { - case "tcp": - stream.targetAddr, err = net.ResolveTCPAddr("tcp", stream.ProxyURL.Host) - if err != nil { - return err - } - tcpListener, err := lcfg.Listen(ctx, "tcp", stream.LisURL.Host) - if err != nil { - return err - } - // in case ListeningPort was zero, get the actual port - stream.Port.Listening = tcpListener.Addr().(*net.TCPAddr).Port - stream.listener = types.NetListener(tcpListener) - case "udp": - stream.targetAddr, err = net.ResolveUDPAddr("udp", stream.ProxyURL.Host) - if err != nil { - return err - } - udpListener, err := lcfg.ListenPacket(ctx, "udp", stream.LisURL.Host) - if err != nil { - return err - } - udpConn, ok := udpListener.(*net.UDPConn) - if !ok { - udpListener.Close() - return errors.New("udp listener is not *net.UDPConn") - } - stream.Port.Listening = udpConn.LocalAddr().(*net.UDPAddr).Port - stream.listener = NewUDPForwarder(ctx, udpConn, stream.targetAddr) - default: - panic("should not reach here") - } - - return nil -} - -func (stream *Stream) Accept() (conn types.StreamConn, err error) { - if stream.listener == nil { - return nil, errors.New("listener is nil") - } - // prevent Accept from blocking forever - done := make(chan struct{}) - go func() { - conn, err = stream.listener.Accept() - close(done) - }() - - select { - case <-stream.task.Context().Done(): - stream.Close() - return nil, stream.task.Context().Err() - case <-done: - return conn, nil - } -} - -func (stream *Stream) Handle(conn types.StreamConn) error { - switch conn := conn.(type) { - case *UDPConn: - switch stream := stream.listener.(type) { - case *UDPForwarder: - return stream.Handle(conn) - default: - return fmt.Errorf("unexpected listener type: %T", stream) - } - case io.ReadWriteCloser: - dialer := &net.Dialer{Timeout: streamDialTimeout} - dstConn, err := dialer.DialContext(stream.task.Context(), stream.targetAddr.Network(), stream.targetAddr.String()) - if err != nil { - return err - } - defer dstConn.Close() - defer conn.Close() - pipe := U.NewBidirectionalPipe(stream.task.Context(), conn, dstConn) - return pipe.Start() - default: - return fmt.Errorf("unexpected conn type: %T", conn) - } -} - -func (stream *Stream) Close() error { - return stream.listener.Close() -} diff --git a/internal/route/udp_forwarder.go b/internal/route/udp_forwarder.go deleted file mode 100644 index 581e2733..00000000 --- a/internal/route/udp_forwarder.go +++ /dev/null @@ -1,204 +0,0 @@ -package route - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/rs/zerolog/log" - "github.com/yusing/go-proxy/internal/gperr" - "github.com/yusing/go-proxy/internal/net/types" - F "github.com/yusing/go-proxy/internal/utils/functional" -) - -type ( - UDPForwarder struct { - ctx context.Context - forwarder *net.UDPConn - dstAddr net.Addr - connMap F.Map[string, *UDPConn] - mu sync.Mutex - } - UDPConn struct { - srcAddr *net.UDPAddr - conn net.Conn - buf *UDPBuf - } - UDPBuf struct { - data, oob []byte - n, oobn int - } -) - -const udpConnBufferSize = 4096 - -func NewUDPForwarder(ctx context.Context, forwarder *net.UDPConn, dstAddr net.Addr) *UDPForwarder { - return &UDPForwarder{ - ctx: ctx, - forwarder: forwarder, - dstAddr: dstAddr, - connMap: F.NewMapOf[string, *UDPConn](), - } -} - -func newUDPBuf() *UDPBuf { - return &UDPBuf{ - data: make([]byte, udpConnBufferSize), - oob: make([]byte, udpConnBufferSize), - } -} - -func (conn *UDPConn) DstAddrString() string { - return conn.conn.RemoteAddr().Network() + "://" + conn.conn.RemoteAddr().String() -} - -func (w *UDPForwarder) Addr() net.Addr { - return w.forwarder.LocalAddr() -} - -func (w *UDPForwarder) Accept() (types.StreamConn, error) { - buf := newUDPBuf() - addr, err := w.readFromListener(buf) - if err != nil { - return nil, err - } - return &UDPConn{ - srcAddr: addr, - buf: buf, - }, nil -} - -func (w *UDPForwarder) dialDst() (dstConn net.Conn, err error) { - switch dstAddr := w.dstAddr.(type) { - case *net.UDPAddr: - var laddr *net.UDPAddr - if dstAddr.IP.IsLoopback() { - laddr, _ = net.ResolveUDPAddr(dstAddr.Network(), "127.0.0.1:") - } - dstConn, err = net.DialUDP(w.dstAddr.Network(), laddr, dstAddr) - case *net.TCPAddr: - dstConn, err = net.DialTCP(w.dstAddr.Network(), nil, dstAddr) - default: - err = fmt.Errorf("unsupported network %s", w.dstAddr.Network()) - } - return -} - -func (w *UDPForwarder) readFromListener(buf *UDPBuf) (srcAddr *net.UDPAddr, err error) { - buf.n, buf.oobn, _, srcAddr, err = w.forwarder.ReadMsgUDP(buf.data, buf.oob) - if err == nil { - log.Debug().Msgf("read from listener udp://%s success (n: %d, oobn: %d)", w.Addr().String(), buf.n, buf.oobn) - } - return -} - -func (conn *UDPConn) read() (err error) { - switch dstConn := conn.conn.(type) { - case *net.UDPConn: - conn.buf.n, conn.buf.oobn, _, _, err = dstConn.ReadMsgUDP(conn.buf.data, conn.buf.oob) - default: - conn.buf.n, err = dstConn.Read(conn.buf.data[:conn.buf.n]) - conn.buf.oobn = 0 - } - if err == nil { - log.Debug().Msgf("read from dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) - } - return -} - -func (w *UDPForwarder) writeToSrc(srcAddr *net.UDPAddr, buf *UDPBuf) (err error) { - buf.n, buf.oobn, err = w.forwarder.WriteMsgUDP(buf.data[:buf.n], buf.oob[:buf.oobn], srcAddr) - if err == nil { - log.Debug().Msgf("write to src %s://%s success (n: %d, oobn: %d)", srcAddr.Network(), srcAddr.String(), buf.n, buf.oobn) - } - return -} - -func (conn *UDPConn) write() (err error) { - switch dstConn := conn.conn.(type) { - case *net.UDPConn: - conn.buf.n, conn.buf.oobn, err = dstConn.WriteMsgUDP(conn.buf.data[:conn.buf.n], conn.buf.oob[:conn.buf.oobn], nil) - if err == nil { - log.Debug().Msgf("write to dst %s success (n: %d, oobn: %d)", conn.DstAddrString(), conn.buf.n, conn.buf.oobn) - } - default: - _, err = dstConn.Write(conn.buf.data[:conn.buf.n]) - if err == nil { - log.Debug().Msgf("write to dst %s success (n: %d)", conn.DstAddrString(), conn.buf.n) - } - } - - return -} - -func (w *UDPForwarder) getInitConn(conn *UDPConn, key string) (*UDPConn, error) { - w.mu.Lock() - defer w.mu.Unlock() - - dst, ok := w.connMap.Load(key) - if !ok { - var err error - dst = conn - dst.conn, err = w.dialDst() - if err != nil { - return nil, err - } - if err := dst.write(); err != nil { - dst.conn.Close() - return nil, err - } - w.connMap.Store(key, dst) - } else { - conn.conn = dst.conn - if err := conn.write(); err != nil { - w.connMap.Delete(key) - dst.conn.Close() - return nil, err - } - } - - return dst, nil -} - -func (w *UDPForwarder) Handle(streamConn types.StreamConn) error { - conn, ok := streamConn.(*UDPConn) - if !ok { - panic("unexpected conn type") - } - - key := conn.srcAddr.String() - dst, err := w.getInitConn(conn, key) - if err != nil { - return err - } - - for { - select { - case <-w.ctx.Done(): - return nil - default: - if err := dst.read(); err != nil { - w.connMap.Delete(key) - dst.conn.Close() - return err - } - - if err := w.writeToSrc(dst.srcAddr, dst.buf); err != nil { - return err - } - } - } -} - -func (w *UDPForwarder) Close() error { - errs := gperr.NewBuilder("errors closing udp conn") - w.mu.Lock() - defer w.mu.Unlock() - w.connMap.RangeAll(func(key string, conn *UDPConn) { - errs.Add(conn.conn.Close()) - }) - w.connMap.Clear() - errs.Add(w.forwarder.Close()) - return errs.Error() -}