mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
fix(stream): properly handle remote stream scheme IPv4/6
This commit is contained in:
@@ -269,16 +269,19 @@ func (r *Route) validate() gperr.Error {
|
||||
|
||||
switch r.Scheme {
|
||||
case route.SchemeFileServer:
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
|
||||
r.Host = ""
|
||||
r.Port.Proxy = 0
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
|
||||
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
|
||||
if r.Port.Listening != 0 {
|
||||
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
|
||||
}
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
case route.SchemeTCP, route.SchemeUDP:
|
||||
if !r.ShouldExclude() {
|
||||
if r.ShouldExclude() {
|
||||
// should exclude, we don't care the scheme here.
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
} else {
|
||||
if r.Bind == "" {
|
||||
r.Bind = "0.0.0.0"
|
||||
}
|
||||
@@ -286,23 +289,28 @@ func (r *Route) validate() gperr.Error {
|
||||
if bindIP == nil {
|
||||
return gperr.Errorf("invalid bind address %s", r.Bind)
|
||||
}
|
||||
var scheme string
|
||||
if bindIP.To4() == nil { // IPv6
|
||||
if r.Scheme == route.SchemeTCP {
|
||||
scheme = "tcp6"
|
||||
} else {
|
||||
scheme = "udp6"
|
||||
}
|
||||
} else {
|
||||
if r.Scheme == route.SchemeTCP {
|
||||
scheme = "tcp4"
|
||||
} else {
|
||||
scheme = "udp4"
|
||||
}
|
||||
remoteIP := net.ParseIP(r.Host)
|
||||
if remoteIP == nil {
|
||||
return gperr.Errorf("invalid remote address %s", r.Host)
|
||||
}
|
||||
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", scheme, r.Bind, r.Port.Listening))
|
||||
toNetwork := func(ip net.IP, scheme route.Scheme) string {
|
||||
if ip.To4() == nil {
|
||||
if scheme == route.SchemeTCP {
|
||||
return "tcp6"
|
||||
}
|
||||
return "udp6"
|
||||
}
|
||||
if scheme == route.SchemeTCP {
|
||||
return "tcp4"
|
||||
}
|
||||
return "udp4"
|
||||
}
|
||||
lScheme := toNetwork(bindIP, r.Scheme)
|
||||
rScheme := toNetwork(remoteIP, r.Scheme)
|
||||
|
||||
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", lScheme, r.Bind, r.Port.Listening))
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", rScheme, r.Host, r.Port.Proxy))
|
||||
}
|
||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||
}
|
||||
|
||||
if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) {
|
||||
|
||||
@@ -70,7 +70,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
||||
|
||||
r.ListenAndServe(r.task.Context(), nil, nil)
|
||||
r.l = log.With().
|
||||
Str("type", r.LisURL.Scheme).
|
||||
Str("type", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme).
|
||||
Str("name", r.Name()).
|
||||
Stringer("rurl", r.ProxyURL).
|
||||
Stringer("laddr", r.LocalAddr()).Logger()
|
||||
@@ -102,9 +102,10 @@ func (r *StreamRoute) LocalAddr() net.Addr {
|
||||
|
||||
func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
||||
lurl, rurl := r.LisURL, r.ProxyURL
|
||||
// lurl scheme is either tcp4/tcp6 -> tcp, udp4/udp6 -> udp
|
||||
// rurl scheme does not have the trailing 4/6
|
||||
if strings.TrimRight(lurl.Scheme, "46") != rurl.Scheme {
|
||||
// tcp4/tcp6 -> tcp, udp4/udp6 -> udp
|
||||
lScheme := strings.TrimRight(lurl.Scheme, "46")
|
||||
rScheme := strings.TrimRight(rurl.Scheme, "46")
|
||||
if lScheme != rScheme {
|
||||
return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme)
|
||||
}
|
||||
|
||||
@@ -113,11 +114,11 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
||||
laddr = lurl.Host
|
||||
}
|
||||
|
||||
switch rurl.Scheme {
|
||||
switch rScheme {
|
||||
case "tcp":
|
||||
return stream.NewTCPTCPStream(r.LisURL.Scheme, laddr, rurl.Host)
|
||||
return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host)
|
||||
case "udp":
|
||||
return stream.NewUDPUDPStream(r.LisURL.Scheme, laddr, rurl.Host)
|
||||
return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host)
|
||||
}
|
||||
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
||||
}
|
||||
|
||||
@@ -14,10 +14,13 @@ import (
|
||||
)
|
||||
|
||||
type TCPTCPStream struct {
|
||||
network string
|
||||
listener net.Listener
|
||||
laddr *net.TCPAddr
|
||||
dst *net.TCPAddr
|
||||
|
||||
network string
|
||||
dstNetwork string
|
||||
|
||||
laddr *net.TCPAddr
|
||||
dst *net.TCPAddr
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
@@ -25,8 +28,8 @@ type TCPTCPStream struct {
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewTCPTCPStream(network, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveTCPAddr(network, dstAddr)
|
||||
func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveTCPAddr(dstNetwork, dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -34,7 +37,7 @@ func NewTCPTCPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPTCPStream{network: network, laddr: laddr, dst: dst}, nil
|
||||
return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst}, nil
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
@@ -72,7 +75,7 @@ func (s *TCPTCPStream) LocalAddr() net.Addr {
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Str("protocol", "tcp-tcp")
|
||||
e.Str("protocol", s.network+"->"+s.dstNetwork)
|
||||
|
||||
if s.listener != nil {
|
||||
e.Str("listen", s.listener.Addr().String())
|
||||
@@ -127,7 +130,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
dstConn, err := net.DialTCP("tcp", nil, s.dst)
|
||||
dstConn, err := net.DialTCP(s.dstNetwork, nil, s.dst)
|
||||
if err != nil {
|
||||
if !s.closed.Load() {
|
||||
logErr(s, err, "failed to dial destination")
|
||||
|
||||
@@ -17,9 +17,11 @@ import (
|
||||
)
|
||||
|
||||
type UDPUDPStream struct {
|
||||
network string
|
||||
listener net.PacketConn
|
||||
|
||||
network string
|
||||
dstNetwork string
|
||||
|
||||
laddr *net.UDPAddr
|
||||
dst *net.UDPAddr
|
||||
|
||||
@@ -35,7 +37,7 @@ type UDPUDPStream struct {
|
||||
|
||||
type udpUDPConn struct {
|
||||
srcAddr *net.UDPAddr
|
||||
dstConn *net.UDPConn
|
||||
dstConn net.Conn
|
||||
listener net.PacketConn
|
||||
lastUsed atomic.Time
|
||||
closed atomic.Bool
|
||||
@@ -51,8 +53,8 @@ const (
|
||||
|
||||
var bufPool = synk.GetSizedBytesPool()
|
||||
|
||||
func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveUDPAddr(network, dstAddr)
|
||||
func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveUDPAddr(dstNetwork, dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -61,20 +63,21 @@ func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro
|
||||
return nil, err
|
||||
}
|
||||
return &UDPUDPStream{
|
||||
network: network,
|
||||
laddr: laddr,
|
||||
dst: dst,
|
||||
conns: make(map[string]*udpUDPConn),
|
||||
network: network,
|
||||
dstNetwork: dstNetwork,
|
||||
laddr: laddr,
|
||||
dst: dst,
|
||||
conns: make(map[string]*udpUDPConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
var err error
|
||||
s.listener, err = net.ListenUDP(s.network, s.laddr)
|
||||
l, err := net.ListenUDP(s.network, s.laddr)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to listen")
|
||||
return
|
||||
}
|
||||
s.listener = l
|
||||
if acl := acl.ActiveConfig.Load(); acl != nil {
|
||||
s.listener = acl.WrapUDP(s.listener)
|
||||
}
|
||||
@@ -114,7 +117,7 @@ func (s *UDPUDPStream) LocalAddr() net.Addr {
|
||||
}
|
||||
|
||||
func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) {
|
||||
e.Str("protocol", "udp-udp")
|
||||
e.Str("protocol", s.network+"->"+s.dstNetwork)
|
||||
if s.dst != nil {
|
||||
e.Str("dst", s.dst.String())
|
||||
}
|
||||
@@ -187,8 +190,12 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
||||
}
|
||||
}
|
||||
|
||||
// Create UDP connection to destination
|
||||
dstConn, err := net.DialUDP("udp", nil, s.dst)
|
||||
// Create connection to destination (direct UDP or via agent stream tunnel)
|
||||
var (
|
||||
dstConn net.Conn
|
||||
err error
|
||||
)
|
||||
dstConn, err = net.DialUDP(s.dstNetwork, nil, s.dst)
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to dial dst")
|
||||
return nil, false
|
||||
@@ -203,7 +210,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
||||
|
||||
// Send initial data before starting response handler
|
||||
if !conn.forwardToDestination(initialData) {
|
||||
dstConn.Close()
|
||||
_ = dstConn.Close()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -326,6 +333,6 @@ func (conn *udpUDPConn) Close() {
|
||||
|
||||
conn.closed.Store(true)
|
||||
|
||||
conn.dstConn.Close()
|
||||
_ = conn.dstConn.Close()
|
||||
conn.dstConn = nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user