From b763c9264580528c33418b22b0d67df85723dde4 Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 19 Sep 2025 10:23:47 +0800 Subject: [PATCH] refactor(stream): update TCP and UDP stream listeners to support proxy protocol and ACL wrapping --- internal/route/stream/tcp_tcp.go | 16 +++++++++++++--- internal/route/stream/udp_udp.go | 25 ++++++++++++++++++------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go index 6acdbc39..66c424d3 100644 --- a/internal/route/stream/tcp_tcp.go +++ b/internal/route/stream/tcp_tcp.go @@ -4,14 +4,16 @@ import ( "context" "net" + "github.com/pires/go-proxyproto" "github.com/rs/zerolog" + config "github.com/yusing/go-proxy/internal/config/types" nettypes "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils" "go.uber.org/atomic" ) type TCPTCPStream struct { - listener *net.TCPListener + listener net.Listener laddr *net.TCPAddr dst *net.TCPAddr @@ -34,12 +36,20 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { } func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { - listener, err := net.ListenTCP("tcp", s.laddr) + var err error + s.listener, err = net.ListenTCP("tcp", s.laddr) if err != nil { logErr(s, err, "failed to listen") return } - s.listener = listener + + if proxyProto := config.GetInstance().Value().Entrypoint.SupportProxyProtocol; proxyProto { + s.listener = &proxyproto.Listener{Listener: s.listener} + } + if acl := config.GetInstance().Value().ACL; acl != nil { + s.listener = acl.WrapTCP(s.listener) + } + s.preDial = preDial s.onRead = onRead go s.listen(ctx) diff --git a/internal/route/stream/udp_udp.go b/internal/route/stream/udp_udp.go index 4e7a98b9..393e4a7d 100644 --- a/internal/route/stream/udp_udp.go +++ b/internal/route/stream/udp_udp.go @@ -3,12 +3,14 @@ package stream import ( "bytes" "context" + "fmt" "maps" "net" "sync" "time" "github.com/rs/zerolog" + config "github.com/yusing/go-proxy/internal/config/types" nettypes "github.com/yusing/go-proxy/internal/net/types" "github.com/yusing/go-proxy/internal/utils/synk" "go.uber.org/atomic" @@ -16,7 +18,7 @@ import ( type UDPUDPStream struct { name string - listener *net.UDPConn + listener net.PacketConn laddr *net.UDPAddr dst *net.UDPAddr @@ -34,7 +36,7 @@ type UDPUDPStream struct { type udpUDPConn struct { srcAddr *net.UDPAddr dstConn *net.UDPConn - listener *net.UDPConn + listener net.PacketConn lastUsed atomic.Time closed atomic.Bool mu sync.Mutex @@ -66,12 +68,15 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { } func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { - listener, err := net.ListenUDP("udp", s.laddr) + var err error + s.listener, err = net.ListenUDP("udp", s.laddr) if err != nil { logErr(s, err, "failed to listen") return } - s.listener = listener + if acl := config.GetInstance().Value().ACL; acl != nil { + s.listener = acl.WrapUDP(s.listener) + } s.preDial = preDial s.onRead = onRead go s.listen(ctx) @@ -120,7 +125,7 @@ func (s *UDPUDPStream) listen(ctx context.Context) { case <-ctx.Done(): return default: - n, srcAddr, err := s.listener.ReadFromUDP(buf) + n, srcAddr, err := s.listener.ReadFrom(buf) if err != nil { if s.closed.Load() { return @@ -129,6 +134,12 @@ func (s *UDPUDPStream) listen(ctx context.Context) { continue } + srcAddrUDP, ok := srcAddr.(*net.UDPAddr) + if !ok { + logErr(s, fmt.Errorf("unexpected source address type: %T", srcAddr), "unexpected source address type") + continue + } + logDebugf(s, "read %d bytes from %s", n, srcAddr) if s.onRead != nil { @@ -139,7 +150,7 @@ func (s *UDPUDPStream) listen(ctx context.Context) { } // Get or create connection, passing the initial data - go s.getOrCreateConnection(ctx, srcAddr, bytes.Clone(buf[:n])) + go s.getOrCreateConnection(ctx, srcAddrUDP, bytes.Clone(buf[:n])) } } } @@ -233,7 +244,7 @@ func (conn *udpUDPConn) handleResponses(ctx context.Context) { _ = conn.dstConn.SetReadDeadline(time.Time{}) // Forward response back to client using the listener - _, err = conn.listener.WriteToUDP(buf[:n], conn.srcAddr) + _, err = conn.listener.WriteTo(buf[:n], conn.srcAddr) if err != nil { if !conn.closed.Load() { logErrf(conn, err, "failed to write %d bytes to client", n)