Files
godoxy/internal/route/stream/udp_udp.go
yusing 5c8126c2e6 refactor(route/logging): streamline log messages with EmbedObject for improved clarity
Updated logging statements across multiple files to utilize EmbedObject for enhanced context in log messages. This change improves the readability and consistency of log outputs, particularly in health monitoring and route validation processes.
2026-02-08 09:20:45 +08:00

348 lines
7.3 KiB
Go

package stream
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"maps"
	"net"
	"sync"
	"time"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	acl "github.com/yusing/godoxy/internal/acl/types"
	"github.com/yusing/godoxy/internal/agentpool"
	nettypes "github.com/yusing/godoxy/internal/net/types"
	"github.com/yusing/goutils/synk"
	"go.uber.org/atomic"
)
// UDPUDPStream relays UDP datagrams received on a local listener to a UDP
// destination, either by dialing it directly or, when agent is set, through
// the agent's UDP client. Each client source address gets its own
// udpUDPConn relay, tracked in conns.
type UDPUDPStream struct {
	// listener is bound in ListenAndServe and may be wrapped with an ACL.
	listener net.PacketConn
	// network is used to resolve laddr and to bind the listener.
	network string
	// dstNetwork is used to resolve dst.
	dstNetwork string
	laddr      *net.UDPAddr
	dst        *net.UDPAddr
	// agent, when non-nil, is used to reach dst instead of a direct dial.
	agent *agentpool.Agent
	// preDial runs before dialing dst for a new client.
	preDial nettypes.HookFunc
	// onRead runs for every datagram read from the listener.
	onRead        nettypes.HookFunc
	cleanUpTicker *time.Ticker
	// conns maps srcAddr.String() to that client's relay; guarded by mu.
	conns  map[string]*udpUDPConn
	closed atomic.Bool
	mu     sync.Mutex
}
// udpUDPConn is the relay state for a single client source address.
// Datagrams from srcAddr are written to dstConn, and datagrams read from
// dstConn are sent back to srcAddr through the shared listener.
type udpUDPConn struct {
	srcAddr *net.UDPAddr
	dstConn net.Conn
	// listener is the owning stream's listener, shared by all relays.
	listener net.PacketConn
	// lastUsed is refreshed on every successful forward in either
	// direction; Expired compares it against udpIdleTimeout.
	lastUsed atomic.Time
	closed   atomic.Bool
	// mu serializes writes to dstConn with Close.
	mu sync.Mutex
}
// Tunables for the UDP relay.
const (
	// udpBufferSize is the size of each pooled read buffer.
	udpBufferSize = 16 << 10
	// udpIdleTimeout is how long a relay may go without traffic before it is
	// considered expired. Longer timeout for game sessions.
	udpIdleTimeout = 5 * time.Minute
	// udpCleanupInterval is how often expired relays are swept.
	udpCleanupInterval = time.Minute
	// udpReadTimeout is the deadline applied to each read from the destination.
	udpReadTimeout = 30 * time.Second
)
var (
	// bufPool supplies the reusable read buffers used by the listener read
	// loop and by each relay's response loop.
	bufPool = synk.GetSizedBytesPool()
)
// NewUDPUDPStream resolves dstAddr (using dstNetwork) and listenAddr (using
// network) and returns a stream that relays datagrams between them. The
// destination is resolved first, so its error is reported first. When agent
// is non-nil, the destination is reached through the agent rather than by a
// direct dial. Nothing is bound until ListenAndServe is called.
func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent) (nettypes.Stream, error) {
	dstUDP, resolveErr := net.ResolveUDPAddr(dstNetwork, dstAddr)
	if resolveErr != nil {
		return nil, resolveErr
	}
	listenUDP, resolveErr := net.ResolveUDPAddr(network, listenAddr)
	if resolveErr != nil {
		return nil, resolveErr
	}

	stream := &UDPUDPStream{
		network:    network,
		dstNetwork: dstNetwork,
		laddr:      listenUDP,
		dst:        dstUDP,
		agent:      agent,
		conns:      make(map[string]*udpUDPConn),
	}
	return stream, nil
}
// ListenAndServe binds the UDP listener and wraps it with the ACL carried by
// ctx, if any. It then stores the hooks and starts the read loop and the
// cleanup loop in background goroutines. It returns as soon as the listener
// is bound.
//
// preDial runs before a destination connection is dialed for a new client;
// onRead runs for every datagram read from the listener.
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error {
	udpListener, err := net.ListenUDP(s.network, s.laddr)
	if err != nil {
		return err
	}
	s.listener = udpListener

	// The local is not named "acl" so it does not shadow the package.
	if checker := acl.FromCtx(ctx); checker != nil {
		log.Debug().EmbedObject(s).Msg("wrapping listener with ACL")
		s.listener = checker.WrapUDP(s.listener)
	}

	s.preDial, s.onRead = preDial, onRead

	go s.listen(ctx)
	go s.cleanUp(ctx)
	return nil
}
// Close marks the stream closed and returns nil right away if it was already
// closed or no listener was ever bound. Otherwise it closes every tracked
// relay in parallel, waits for all of them, and then closes the listener,
// returning that error.
func (s *UDPUDPStream) Close() error {
	if alreadyClosed := s.closed.Swap(true); alreadyClosed || s.listener == nil {
		return nil
	}

	s.mu.Lock()
	pending := make([]*udpUDPConn, 0, len(s.conns))
	for _, c := range s.conns {
		pending = append(pending, c)
	}
	clear(s.conns)
	s.mu.Unlock()

	var wg sync.WaitGroup
	wg.Add(len(pending))
	for _, c := range pending {
		go func(relay *udpUDPConn) {
			defer wg.Done()
			relay.Close()
		}(c)
	}
	wg.Wait()

	return s.listener.Close()
}
// LocalAddr returns the bound listener's address, or the resolved listen
// address when ListenAndServe has not bound a listener yet.
func (s *UDPUDPStream) LocalAddr() net.Addr {
	if l := s.listener; l != nil {
		return l.LocalAddr()
	}
	return s.laddr
}
// MarshalZerologObject adds the "listen->destination" protocol pair and, when
// resolved, the destination address to a log event.
func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) {
	protocol := s.network + "->" + s.dstNetwork
	e.Str("protocol", protocol)
	if dst := s.dst; dst != nil {
		e.Str("dst", dst.String())
	}
}
// listen is the listener read loop. For each datagram it:
//   - validates that the source is a *net.UDPAddr,
//   - runs the onRead hook, dropping the datagram if the hook fails,
//   - hands a private copy of the payload to getOrCreateConnection in its
//     own goroutine.
//
// It exits when ctx is done (checked between reads), when the stream has been
// closed, or when the listener reports net.ErrClosed.
func (s *UDPUDPStream) listen(ctx context.Context) {
	buf := bufPool.GetSized(udpBufferSize)
	defer bufPool.Put(buf)

	for ctx.Err() == nil {
		n, srcAddr, err := s.listener.ReadFrom(buf)
		if err != nil {
			// A closed listener never recovers. Without the ErrClosed check the
			// loop would spin logging the same error forever whenever the
			// listener is closed without s.closed being set.
			if s.closed.Load() || errors.Is(err, net.ErrClosed) {
				return
			}
			logErr(s, err, "failed to read from listener")
			continue
		}

		srcUDP, ok := srcAddr.(*net.UDPAddr)
		if !ok {
			logErr(s, fmt.Errorf("unexpected source address type: %T", srcAddr), "unexpected source address type")
			continue
		}
		logDebugf(s, "read %d bytes from %s", n, srcAddr)

		if s.onRead != nil {
			if err := s.onRead(ctx); err != nil {
				logErr(s, err, "failed to on read")
				continue
			}
		}

		// buf is reused by the next ReadFrom, so the goroutine gets its own copy.
		go s.getOrCreateConnection(ctx, srcUDP, bytes.Clone(buf[:n]))
	}
}
// getOrCreateConnection routes one datagram from srcAddr. If a live relay for
// srcAddr exists, the datagram is forwarded on it. Otherwise a new relay is
// created, with initialData as its first datagram, and registered in s.conns.
//
// A relay that has already closed itself (e.g. after a read or write error)
// is replaced rather than reused. Otherwise every later datagram from that
// client would be dropped until the cleanup sweep removed the stale entry.
// Nothing is registered once the stream is closed, so a relay created
// concurrently with Close cannot be leaked.
func (s *UDPUDPStream) getOrCreateConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) {
	key := srcAddr.String()

	s.mu.Lock()
	if s.closed.Load() {
		s.mu.Unlock()
		return
	}
	if conn, ok := s.conns[key]; ok && !conn.closed.Load() {
		s.mu.Unlock()
		// Already running in its own goroutine (see listen); forward directly.
		conn.forwardToDestination(initialData)
		return
	}
	// Either no relay exists for this client or it is closed: drop any stale
	// entry before creating a fresh relay.
	delete(s.conns, key)
	defer s.mu.Unlock()

	// The lock stays held while dialing so concurrent datagrams from the same
	// client cannot create duplicate relays.
	conn, ok := s.createConnection(ctx, srcAddr, initialData)
	if ok && !conn.closed.Load() {
		s.conns[key] = conn
	}
}
// createConnection builds a relay for srcAddr in four steps:
//  1. run the preDial hook, if set;
//  2. open the destination connection (through the agent when configured,
//     otherwise a direct UDP dial);
//  3. write initialData to the destination;
//  4. start the goroutine that relays responses back to srcAddr.
//
// It returns (nil, false) if any step fails. The destination connection is
// closed if the initial write fails.
func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAddr, initialData []byte) (*udpUDPConn, bool) {
	if hook := s.preDial; hook != nil {
		if err := hook(ctx); err != nil {
			logErr(s, err, "failed to pre-dial")
			return nil, false
		}
	}

	var (
		dstConn net.Conn
		err     error
	)
	switch {
	case s.agent != nil:
		dstConn, err = s.agent.NewUDPClient(s.dst.String())
	default:
		dstConn, err = net.DialUDP(s.dst.Network(), nil, s.dst)
	}
	if err != nil {
		logErr(s, err, "failed to dial dst")
		return nil, false
	}

	conn := &udpUDPConn{srcAddr: srcAddr, dstConn: dstConn, listener: s.listener}
	conn.lastUsed.Store(time.Now())

	// The first datagram is written before the response relay is started.
	if sent := conn.forwardToDestination(initialData); !sent {
		_ = dstConn.Close()
		return nil, false
	}
	go conn.handleResponses(ctx)

	logDebugf(s, "created new connection from %s", srcAddr.String())
	return conn, true
}
// MarshalZerologObject adds the client ("src") address and, when a
// destination connection is present, its remote ("dst") address to a log
// event. The nil check keeps logging from panicking on a relay whose
// destination connection has been cleared.
func (conn *udpUDPConn) MarshalZerologObject(e *zerolog.Event) {
	e.Stringer("src", conn.srcAddr)
	if dst := conn.dstConn; dst != nil {
		e.Stringer("dst", dst.RemoteAddr())
	}
}
// handleResponses relays datagrams read from the destination back to the
// client through the shared listener. It stops when the relay is closed, ctx
// is done, or a non-timeout read or write error occurs, and closes the relay
// on exit.
//
// Each read has a udpReadTimeout deadline, so the closed flag and ctx are
// re-checked periodically. A read timeout alone is not a failure: the
// destination may be silent while the client keeps sending. The loop
// therefore continues until the relay has been idle past udpIdleTimeout
// (see Expired).
func (conn *udpUDPConn) handleResponses(ctx context.Context) {
	buf := bufPool.GetSized(udpBufferSize)
	defer bufPool.Put(buf)
	defer conn.Close()

	// Read the destination connection once under conn.mu so the loop does
	// not race with writers of conn.dstConn.
	conn.mu.Lock()
	dst := conn.dstConn
	conn.mu.Unlock()
	if dst == nil {
		return
	}

	for !conn.closed.Load() && ctx.Err() == nil {
		_ = dst.SetReadDeadline(time.Now().Add(udpReadTimeout))
		n, err := dst.Read(buf)
		if err != nil {
			if conn.closed.Load() {
				return
			}
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				if !conn.Expired() {
					continue
				}
				// Idle longer than udpIdleTimeout: expire without logging an error.
				return
			}
			logErr(conn, err, "failed to read from dst")
			return
		}

		if _, err := conn.listener.WriteTo(buf[:n], conn.srcAddr); err != nil {
			if !conn.closed.Load() {
				logErrf(conn, err, "failed to write %d bytes to client", n)
			}
			return
		}
		conn.lastUsed.Store(time.Now())
		logDebugf(conn, "forwarded response to client, %d bytes", n)
	}
}
// cleanUp sweeps s.conns every udpCleanupInterval until ctx is done. Each
// sweep closes relays idle for longer than udpIdleTimeout and removes them
// from s.conns, together with relays that have already closed themselves.
func (s *UDPUDPStream) cleanUp(ctx context.Context) {
	s.cleanUpTicker = time.NewTicker(udpCleanupInterval)
	defer s.cleanUpTicker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-s.cleanUpTicker.C:
			// Work on a snapshot. Closing a relay waits on its conn.mu, so
			// s.mu is not held while closing.
			s.mu.Lock()
			snapshot := maps.Clone(s.conns)
			s.mu.Unlock()

			stale := make(map[string]*udpUDPConn)
			for key, conn := range snapshot {
				if conn.closed.Load() || conn.Expired() {
					conn.Close()
					stale[key] = conn
				}
			}
			if len(stale) == 0 {
				continue
			}

			s.mu.Lock()
			for key, conn := range stale {
				// The key may have been re-registered with a fresh relay since the
				// snapshot; only drop the entry if it still points at this relay.
				if s.conns[key] == conn {
					logDebugf(s, "cleaning up expired connection: %s", key)
					delete(s.conns, key)
				}
			}
			s.mu.Unlock()
		}
	}
}
// forwardToDestination writes one client datagram to the destination and
// refreshes lastUsed on success. It returns false if the relay is already
// closed or the write fails. conn.mu is held for the whole write, so the
// write cannot interleave with Close.
func (conn *udpUDPConn) forwardToDestination(data []byte) bool {
	conn.mu.Lock()
	defer conn.mu.Unlock()

	if conn.closed.Load() {
		return false
	}
	if _, writeErr := conn.dstConn.Write(data); writeErr != nil {
		logErrf(conn, writeErr, "failed to write %d bytes to dst", len(data))
		return false
	}

	conn.lastUsed.Store(time.Now())
	logDebugf(conn, "forwarded %d bytes to dst", len(data))
	return true
}
// Expired reports whether the relay has gone unused for longer than
// udpIdleTimeout.
func (conn *udpUDPConn) Expired() bool {
	idleFor := time.Since(conn.lastUsed.Load())
	return idleFor > udpIdleTimeout
}
// Close marks the relay closed and closes its destination connection. It is
// idempotent: only the first call has an effect. Because it takes conn.mu, it
// waits for any in-flight forwardToDestination write to finish.
//
// dstConn is intentionally left set after closing. Other methods read it
// without holding conn.mu, and clearing it would make them race with this
// write and risk a nil dereference.
func (conn *udpUDPConn) Close() {
	conn.mu.Lock()
	defer conn.mu.Unlock()
	if conn.closed.Swap(true) {
		return
	}
	_ = conn.dstConn.Close()
}