diff --git a/internal/route/route.go b/internal/route/route.go index 52f93301..f08383e5 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -269,16 +269,19 @@ func (r *Route) validate() gperr.Error { switch r.Scheme { case route.SchemeFileServer: - r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root) r.Host = "" r.Port.Proxy = 0 + r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root) case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C: if r.Port.Listening != 0 { errs.Addf("unexpected listening port for %s scheme", r.Scheme) } r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) case route.SchemeTCP, route.SchemeUDP: - if !r.ShouldExclude() { + if r.ShouldExclude() { + // should exclude, we don't care the scheme here. + r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) + } else { if r.Bind == "" { r.Bind = "0.0.0.0" } @@ -286,23 +289,28 @@ func (r *Route) validate() gperr.Error { if bindIP == nil { return gperr.Errorf("invalid bind address %s", r.Bind) } - var scheme string - if bindIP.To4() == nil { // IPv6 - if r.Scheme == route.SchemeTCP { - scheme = "tcp6" - } else { - scheme = "udp6" - } - } else { - if r.Scheme == route.SchemeTCP { - scheme = "tcp4" - } else { - scheme = "udp4" - } + remoteIP := net.ParseIP(r.Host) + if remoteIP == nil { + return gperr.Errorf("invalid remote address %s", r.Host) } - r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", scheme, r.Bind, r.Port.Listening)) + toNetwork := func(ip net.IP, scheme route.Scheme) string { + if ip.To4() == nil { + if scheme == route.SchemeTCP { + return "tcp6" + } + return "udp6" + } + if scheme == route.SchemeTCP { + return "tcp4" + } + return "udp4" + } + lScheme := toNetwork(bindIP, r.Scheme) + rScheme := toNetwork(remoteIP, r.Scheme) + + r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", lScheme, r.Bind, r.Port.Listening)) + r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", rScheme, r.Host, r.Port.Proxy)) } - r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy)) } if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) { diff --git a/internal/route/stream.go b/internal/route/stream.go index f337ff16..0b6184b7 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -70,7 +70,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error { r.ListenAndServe(r.task.Context(), nil, nil) r.l = log.With(). - Str("type", r.LisURL.Scheme). + Str("type", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme). Str("name", r.Name()). Stringer("rurl", r.ProxyURL). Stringer("laddr", r.LocalAddr()).Logger() @@ -102,9 +102,10 @@ func (r *StreamRoute) LocalAddr() net.Addr { func (r *StreamRoute) initStream() (nettypes.Stream, error) { lurl, rurl := r.LisURL, r.ProxyURL - // lurl scheme is either tcp4/tcp6 -> tcp, udp4/udp6 -> udp - // rurl scheme does not have the trailing 4/6 - if strings.TrimRight(lurl.Scheme, "46") != rurl.Scheme { + // tcp4/tcp6 -> tcp, udp4/udp6 -> udp + lScheme := strings.TrimRight(lurl.Scheme, "46") + rScheme := strings.TrimRight(rurl.Scheme, "46") + if lScheme != rScheme { return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme) } @@ -113,11 +114,11 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) { laddr = lurl.Host } - switch rurl.Scheme { + switch rScheme { case "tcp": - return stream.NewTCPTCPStream(r.LisURL.Scheme, laddr, rurl.Host) + return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host) case "udp": - return stream.NewUDPUDPStream(r.LisURL.Scheme, laddr, rurl.Host) + return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host) } return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme) } diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go index 429c7cc4..620f4d4f 100644 --- a/internal/route/stream/tcp_tcp.go +++ b/internal/route/stream/tcp_tcp.go @@ -14,10 +14,13 @@ import ( ) type TCPTCPStream struct { - network string listener net.Listener - laddr *net.TCPAddr - dst *net.TCPAddr + + network string + dstNetwork string + + laddr *net.TCPAddr + dst *net.TCPAddr preDial nettypes.HookFunc onRead nettypes.HookFunc @@ -25,8 +28,8 @@ type TCPTCPStream struct { closed atomic.Bool } -func NewTCPTCPStream(network, listenAddr, dstAddr string) (nettypes.Stream, error) { - dst, err := net.ResolveTCPAddr(network, dstAddr) +func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) { + dst, err := net.ResolveTCPAddr(dstNetwork, dstAddr) if err != nil { return nil, err } @@ -34,7 +37,7 @@ func NewTCPTCPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro if err != nil { return nil, err } - return &TCPTCPStream{network: network, laddr: laddr, dst: dst}, nil + return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst}, nil } func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { @@ -72,7 +75,7 @@ func (s *TCPTCPStream) LocalAddr() net.Addr { } func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) { - e.Str("protocol", "tcp-tcp") + e.Str("protocol", s.network+"->"+s.dstNetwork) if s.listener != nil { e.Str("listen", s.listener.Addr().String()) @@ -127,7 +130,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { return } - dstConn, err := net.DialTCP("tcp", nil, s.dst) + dstConn, err := net.DialTCP(s.dstNetwork, nil, s.dst) if err != nil { if !s.closed.Load() { logErr(s, err, "failed to dial destination") diff --git a/internal/route/stream/udp_udp.go b/internal/route/stream/udp_udp.go index 8600e962..7b8615ec 100644 --- a/internal/route/stream/udp_udp.go +++ b/internal/route/stream/udp_udp.go @@ -17,9 +17,11 @@ import ( ) type UDPUDPStream struct { - network string listener net.PacketConn + network string + dstNetwork string + laddr *net.UDPAddr dst *net.UDPAddr @@ -35,7 +37,7 @@ type UDPUDPStream struct { type udpUDPConn struct { srcAddr *net.UDPAddr - dstConn *net.UDPConn + dstConn net.Conn listener net.PacketConn lastUsed atomic.Time closed atomic.Bool @@ -51,8 +53,8 @@ const ( var bufPool = synk.GetSizedBytesPool() -func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, error) { - dst, err := net.ResolveUDPAddr(network, dstAddr) +func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) { + dst, err := net.ResolveUDPAddr(dstNetwork, dstAddr) if err != nil { return nil, err } @@ -61,20 +63,21 @@ func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro return nil, err } return &UDPUDPStream{ - network: network, - laddr: laddr, - dst: dst, - conns: make(map[string]*udpUDPConn), + network: network, + dstNetwork: dstNetwork, + laddr: laddr, + dst: dst, + conns: make(map[string]*udpUDPConn), }, nil } func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { - var err error - s.listener, err = net.ListenUDP(s.network, s.laddr) + l, err := net.ListenUDP(s.network, s.laddr) if err != nil { logErr(s, err, "failed to listen") return } + s.listener = l if acl := acl.ActiveConfig.Load(); acl != nil { s.listener = acl.WrapUDP(s.listener) } @@ -114,7 +117,7 @@ func (s *UDPUDPStream) LocalAddr() net.Addr { } func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) { - e.Str("protocol", "udp-udp") + e.Str("protocol", s.network+"->"+s.dstNetwork) if s.dst != nil { e.Str("dst", s.dst.String()) } @@ -187,8 +190,12 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd } } - // Create UDP connection to destination - dstConn, err := net.DialUDP("udp", nil, s.dst) + // Create connection to destination (direct UDP or via agent stream tunnel) + var ( + dstConn net.Conn + err error + ) + dstConn, err = net.DialUDP(s.dstNetwork, nil, s.dst) if err != nil { logErr(s, err, "failed to dial dst") return nil, false @@ -203,7 +210,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd // Send initial data before starting response handler if !conn.forwardToDestination(initialData) { - dstConn.Close() + _ = dstConn.Close() return nil, false } @@ -326,6 +333,6 @@ func (conn *udpUDPConn) Close() { conn.closed.Store(true) - conn.dstConn.Close() + _ = conn.dstConn.Close() conn.dstConn = nil }