mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-25 09:48:32 +02:00
refactor(stream): update TCP and UDP stream listeners to support proxy protocol and ACL wrapping
This commit is contained in:
@@ -4,14 +4,16 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/pires/go-proxyproto"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TCPTCPStream struct {
|
type TCPTCPStream struct {
|
||||||
listener *net.TCPListener
|
listener net.Listener
|
||||||
laddr *net.TCPAddr
|
laddr *net.TCPAddr
|
||||||
dst *net.TCPAddr
|
dst *net.TCPAddr
|
||||||
|
|
||||||
@@ -34,12 +36,20 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||||
listener, err := net.ListenTCP("tcp", s.laddr)
|
var err error
|
||||||
|
s.listener, err = net.ListenTCP("tcp", s.laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(s, err, "failed to listen")
|
logErr(s, err, "failed to listen")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.listener = listener
|
|
||||||
|
if proxyProto := config.GetInstance().Value().Entrypoint.SupportProxyProtocol; proxyProto {
|
||||||
|
s.listener = &proxyproto.Listener{Listener: s.listener}
|
||||||
|
}
|
||||||
|
if acl := config.GetInstance().Value().ACL; acl != nil {
|
||||||
|
s.listener = acl.WrapTCP(s.listener)
|
||||||
|
}
|
||||||
|
|
||||||
s.preDial = preDial
|
s.preDial = preDial
|
||||||
s.onRead = onRead
|
s.onRead = onRead
|
||||||
go s.listen(ctx)
|
go s.listen(ctx)
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ package stream
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
nettypes "github.com/yusing/go-proxy/internal/net/types"
|
||||||
"github.com/yusing/go-proxy/internal/utils/synk"
|
"github.com/yusing/go-proxy/internal/utils/synk"
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
@@ -16,7 +18,7 @@ import (
|
|||||||
|
|
||||||
type UDPUDPStream struct {
|
type UDPUDPStream struct {
|
||||||
name string
|
name string
|
||||||
listener *net.UDPConn
|
listener net.PacketConn
|
||||||
|
|
||||||
laddr *net.UDPAddr
|
laddr *net.UDPAddr
|
||||||
dst *net.UDPAddr
|
dst *net.UDPAddr
|
||||||
@@ -34,7 +36,7 @@ type UDPUDPStream struct {
|
|||||||
type udpUDPConn struct {
|
type udpUDPConn struct {
|
||||||
srcAddr *net.UDPAddr
|
srcAddr *net.UDPAddr
|
||||||
dstConn *net.UDPConn
|
dstConn *net.UDPConn
|
||||||
listener *net.UDPConn
|
listener net.PacketConn
|
||||||
lastUsed atomic.Time
|
lastUsed atomic.Time
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -66,12 +68,15 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||||
listener, err := net.ListenUDP("udp", s.laddr)
|
var err error
|
||||||
|
s.listener, err = net.ListenUDP("udp", s.laddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(s, err, "failed to listen")
|
logErr(s, err, "failed to listen")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.listener = listener
|
if acl := config.GetInstance().Value().ACL; acl != nil {
|
||||||
|
s.listener = acl.WrapUDP(s.listener)
|
||||||
|
}
|
||||||
s.preDial = preDial
|
s.preDial = preDial
|
||||||
s.onRead = onRead
|
s.onRead = onRead
|
||||||
go s.listen(ctx)
|
go s.listen(ctx)
|
||||||
@@ -120,7 +125,7 @@ func (s *UDPUDPStream) listen(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
n, srcAddr, err := s.listener.ReadFromUDP(buf)
|
n, srcAddr, err := s.listener.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s.closed.Load() {
|
if s.closed.Load() {
|
||||||
return
|
return
|
||||||
@@ -129,6 +134,12 @@ func (s *UDPUDPStream) listen(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srcAddrUDP, ok := srcAddr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
logErr(s, fmt.Errorf("unexpected source address type: %T", srcAddr), "unexpected source address type")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
logDebugf(s, "read %d bytes from %s", n, srcAddr)
|
logDebugf(s, "read %d bytes from %s", n, srcAddr)
|
||||||
|
|
||||||
if s.onRead != nil {
|
if s.onRead != nil {
|
||||||
@@ -139,7 +150,7 @@ func (s *UDPUDPStream) listen(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get or create connection, passing the initial data
|
// Get or create connection, passing the initial data
|
||||||
go s.getOrCreateConnection(ctx, srcAddr, bytes.Clone(buf[:n]))
|
go s.getOrCreateConnection(ctx, srcAddrUDP, bytes.Clone(buf[:n]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -233,7 +244,7 @@ func (conn *udpUDPConn) handleResponses(ctx context.Context) {
|
|||||||
_ = conn.dstConn.SetReadDeadline(time.Time{})
|
_ = conn.dstConn.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
// Forward response back to client using the listener
|
// Forward response back to client using the listener
|
||||||
_, err = conn.listener.WriteToUDP(buf[:n], conn.srcAddr)
|
_, err = conn.listener.WriteTo(buf[:n], conn.srcAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !conn.closed.Load() {
|
if !conn.closed.Load() {
|
||||||
logErrf(conn, err, "failed to write %d bytes to client", n)
|
logErrf(conn, err, "failed to write %d bytes to client", n)
|
||||||
|
|||||||
Reference in New Issue
Block a user