refactor(stream): update TCP and UDP stream listeners to support proxy protocol and ACL wrapping

This commit is contained in:
yusing
2025-09-19 10:23:47 +08:00
parent 09b14a47e9
commit b763c92645
2 changed files with 31 additions and 10 deletions

View File

@@ -4,14 +4,16 @@ import (
"context"
"net"
"github.com/pires/go-proxyproto"
"github.com/rs/zerolog"
config "github.com/yusing/go-proxy/internal/config/types"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils"
"go.uber.org/atomic"
)
type TCPTCPStream struct {
listener *net.TCPListener
listener net.Listener
laddr *net.TCPAddr
dst *net.TCPAddr
@@ -34,12 +36,20 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
}
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
listener, err := net.ListenTCP("tcp", s.laddr)
var err error
s.listener, err = net.ListenTCP("tcp", s.laddr)
if err != nil {
logErr(s, err, "failed to listen")
return
}
s.listener = listener
if proxyProto := config.GetInstance().Value().Entrypoint.SupportProxyProtocol; proxyProto {
s.listener = &proxyproto.Listener{Listener: s.listener}
}
if acl := config.GetInstance().Value().ACL; acl != nil {
s.listener = acl.WrapTCP(s.listener)
}
s.preDial = preDial
s.onRead = onRead
go s.listen(ctx)

View File

@@ -3,12 +3,14 @@ package stream
import (
"bytes"
"context"
"fmt"
"maps"
"net"
"sync"
"time"
"github.com/rs/zerolog"
config "github.com/yusing/go-proxy/internal/config/types"
nettypes "github.com/yusing/go-proxy/internal/net/types"
"github.com/yusing/go-proxy/internal/utils/synk"
"go.uber.org/atomic"
@@ -16,7 +18,7 @@ import (
type UDPUDPStream struct {
name string
listener *net.UDPConn
listener net.PacketConn
laddr *net.UDPAddr
dst *net.UDPAddr
@@ -34,7 +36,7 @@ type UDPUDPStream struct {
type udpUDPConn struct {
srcAddr *net.UDPAddr
dstConn *net.UDPConn
listener *net.UDPConn
listener net.PacketConn
lastUsed atomic.Time
closed atomic.Bool
mu sync.Mutex
@@ -66,12 +68,15 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
}
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
listener, err := net.ListenUDP("udp", s.laddr)
var err error
s.listener, err = net.ListenUDP("udp", s.laddr)
if err != nil {
logErr(s, err, "failed to listen")
return
}
s.listener = listener
if acl := config.GetInstance().Value().ACL; acl != nil {
s.listener = acl.WrapUDP(s.listener)
}
s.preDial = preDial
s.onRead = onRead
go s.listen(ctx)
@@ -120,7 +125,7 @@ func (s *UDPUDPStream) listen(ctx context.Context) {
case <-ctx.Done():
return
default:
n, srcAddr, err := s.listener.ReadFromUDP(buf)
n, srcAddr, err := s.listener.ReadFrom(buf)
if err != nil {
if s.closed.Load() {
return
@@ -129,6 +134,12 @@ func (s *UDPUDPStream) listen(ctx context.Context) {
continue
}
srcAddrUDP, ok := srcAddr.(*net.UDPAddr)
if !ok {
logErr(s, fmt.Errorf("unexpected source address type: %T", srcAddr), "unexpected source address type")
continue
}
logDebugf(s, "read %d bytes from %s", n, srcAddr)
if s.onRead != nil {
@@ -139,7 +150,7 @@ func (s *UDPUDPStream) listen(ctx context.Context) {
}
// Get or create connection, passing the initial data
go s.getOrCreateConnection(ctx, srcAddr, bytes.Clone(buf[:n]))
go s.getOrCreateConnection(ctx, srcAddrUDP, bytes.Clone(buf[:n]))
}
}
}
@@ -233,7 +244,7 @@ func (conn *udpUDPConn) handleResponses(ctx context.Context) {
_ = conn.dstConn.SetReadDeadline(time.Time{})
// Forward response back to client using the listener
_, err = conn.listener.WriteToUDP(buf[:n], conn.srcAddr)
_, err = conn.listener.WriteTo(buf[:n], conn.srcAddr)
if err != nil {
if !conn.closed.Load() {
logErrf(conn, err, "failed to write %d bytes to client", n)