mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 17:28:53 +02:00
fix(stream): properly handle remote stream scheme IPv4/6
This commit is contained in:
@@ -270,16 +270,19 @@ func (r *Route) validate() gperr.Error {
|
|||||||
|
|
||||||
switch r.Scheme {
|
switch r.Scheme {
|
||||||
case route.SchemeFileServer:
|
case route.SchemeFileServer:
|
||||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
|
|
||||||
r.Host = ""
|
r.Host = ""
|
||||||
r.Port.Proxy = 0
|
r.Port.Proxy = 0
|
||||||
|
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
|
||||||
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
|
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
|
||||||
if r.Port.Listening != 0 {
|
if r.Port.Listening != 0 {
|
||||||
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
|
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
|
||||||
}
|
}
|
||||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||||
case route.SchemeTCP, route.SchemeUDP:
|
case route.SchemeTCP, route.SchemeUDP:
|
||||||
if !r.ShouldExclude() {
|
if r.ShouldExclude() {
|
||||||
|
// should exclude, we don't care the scheme here.
|
||||||
|
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
||||||
|
} else {
|
||||||
if r.Bind == "" {
|
if r.Bind == "" {
|
||||||
r.Bind = "0.0.0.0"
|
r.Bind = "0.0.0.0"
|
||||||
}
|
}
|
||||||
@@ -287,23 +290,28 @@ func (r *Route) validate() gperr.Error {
|
|||||||
if bindIP == nil {
|
if bindIP == nil {
|
||||||
return gperr.Errorf("invalid bind address %s", r.Bind)
|
return gperr.Errorf("invalid bind address %s", r.Bind)
|
||||||
}
|
}
|
||||||
var scheme string
|
remoteIP := net.ParseIP(r.Host)
|
||||||
if bindIP.To4() == nil { // IPv6
|
if remoteIP == nil {
|
||||||
if r.Scheme == route.SchemeTCP {
|
return gperr.Errorf("invalid remote address %s", r.Host)
|
||||||
scheme = "tcp6"
|
|
||||||
} else {
|
|
||||||
scheme = "udp6"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if r.Scheme == route.SchemeTCP {
|
|
||||||
scheme = "tcp4"
|
|
||||||
} else {
|
|
||||||
scheme = "udp4"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", scheme, r.Bind, r.Port.Listening))
|
toNetwork := func(ip net.IP, scheme route.Scheme) string {
|
||||||
|
if ip.To4() == nil {
|
||||||
|
if scheme == route.SchemeTCP {
|
||||||
|
return "tcp6"
|
||||||
|
}
|
||||||
|
return "udp6"
|
||||||
|
}
|
||||||
|
if scheme == route.SchemeTCP {
|
||||||
|
return "tcp4"
|
||||||
|
}
|
||||||
|
return "udp4"
|
||||||
|
}
|
||||||
|
lScheme := toNetwork(bindIP, r.Scheme)
|
||||||
|
rScheme := toNetwork(remoteIP, r.Scheme)
|
||||||
|
|
||||||
|
r.LisURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", lScheme, r.Bind, r.Port.Listening))
|
||||||
|
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", rScheme, r.Host, r.Port.Proxy))
|
||||||
}
|
}
|
||||||
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, fmt.Sprintf("%s://%s:%d", r.Scheme, r.Host, r.Port.Proxy))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) {
|
if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) {
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func (r *StreamRoute) Start(parent task.Parent) gperr.Error {
|
|||||||
|
|
||||||
r.ListenAndServe(r.task.Context(), nil, nil)
|
r.ListenAndServe(r.task.Context(), nil, nil)
|
||||||
r.l = log.With().
|
r.l = log.With().
|
||||||
Str("type", r.LisURL.Scheme).
|
Str("type", r.LisURL.Scheme+"->"+r.ProxyURL.Scheme).
|
||||||
Str("name", r.Name()).
|
Str("name", r.Name()).
|
||||||
Stringer("rurl", r.ProxyURL).
|
Stringer("rurl", r.ProxyURL).
|
||||||
Stringer("laddr", r.LocalAddr()).Logger()
|
Stringer("laddr", r.LocalAddr()).Logger()
|
||||||
@@ -102,9 +102,10 @@ func (r *StreamRoute) LocalAddr() net.Addr {
|
|||||||
|
|
||||||
func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
||||||
lurl, rurl := r.LisURL, r.ProxyURL
|
lurl, rurl := r.LisURL, r.ProxyURL
|
||||||
// lurl scheme is either tcp4/tcp6 -> tcp, udp4/udp6 -> udp
|
// tcp4/tcp6 -> tcp, udp4/udp6 -> udp
|
||||||
// rurl scheme does not have the trailing 4/6
|
lScheme := strings.TrimRight(lurl.Scheme, "46")
|
||||||
if strings.TrimRight(lurl.Scheme, "46") != rurl.Scheme {
|
rScheme := strings.TrimRight(rurl.Scheme, "46")
|
||||||
|
if lScheme != rScheme {
|
||||||
return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme)
|
return nil, fmt.Errorf("incoherent scheme is not yet supported: %s != %s", lurl.Scheme, rurl.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,11 +114,11 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
|||||||
laddr = lurl.Host
|
laddr = lurl.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
switch rurl.Scheme {
|
switch rScheme {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return stream.NewTCPTCPStream(r.LisURL.Scheme, laddr, rurl.Host)
|
return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host)
|
||||||
case "udp":
|
case "udp":
|
||||||
return stream.NewUDPUDPStream(r.LisURL.Scheme, laddr, rurl.Host)
|
return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TCPTCPStream struct {
|
type TCPTCPStream struct {
|
||||||
network string
|
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
laddr *net.TCPAddr
|
|
||||||
dst *net.TCPAddr
|
network string
|
||||||
|
dstNetwork string
|
||||||
|
|
||||||
|
laddr *net.TCPAddr
|
||||||
|
dst *net.TCPAddr
|
||||||
|
|
||||||
preDial nettypes.HookFunc
|
preDial nettypes.HookFunc
|
||||||
onRead nettypes.HookFunc
|
onRead nettypes.HookFunc
|
||||||
@@ -25,8 +28,8 @@ type TCPTCPStream struct {
|
|||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTCPTCPStream(network, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||||
dst, err := net.ResolveTCPAddr(network, dstAddr)
|
dst, err := net.ResolveTCPAddr(dstNetwork, dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -34,7 +37,7 @@ func NewTCPTCPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &TCPTCPStream{network: network, laddr: laddr, dst: dst}, nil
|
return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||||
@@ -72,7 +75,7 @@ func (s *TCPTCPStream) LocalAddr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
|
func (s *TCPTCPStream) MarshalZerologObject(e *zerolog.Event) {
|
||||||
e.Str("protocol", "tcp-tcp")
|
e.Str("protocol", s.network+"->"+s.dstNetwork)
|
||||||
|
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
e.Str("listen", s.listener.Addr().String())
|
e.Str("listen", s.listener.Addr().String())
|
||||||
@@ -127,7 +130,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dstConn, err := net.DialTCP("tcp", nil, s.dst)
|
dstConn, err := net.DialTCP(s.dstNetwork, nil, s.dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !s.closed.Load() {
|
if !s.closed.Load() {
|
||||||
logErr(s, err, "failed to dial destination")
|
logErr(s, err, "failed to dial destination")
|
||||||
|
|||||||
@@ -17,9 +17,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type UDPUDPStream struct {
|
type UDPUDPStream struct {
|
||||||
network string
|
|
||||||
listener net.PacketConn
|
listener net.PacketConn
|
||||||
|
|
||||||
|
network string
|
||||||
|
dstNetwork string
|
||||||
|
|
||||||
laddr *net.UDPAddr
|
laddr *net.UDPAddr
|
||||||
dst *net.UDPAddr
|
dst *net.UDPAddr
|
||||||
|
|
||||||
@@ -35,7 +37,7 @@ type UDPUDPStream struct {
|
|||||||
|
|
||||||
type udpUDPConn struct {
|
type udpUDPConn struct {
|
||||||
srcAddr *net.UDPAddr
|
srcAddr *net.UDPAddr
|
||||||
dstConn *net.UDPConn
|
dstConn net.Conn
|
||||||
listener net.PacketConn
|
listener net.PacketConn
|
||||||
lastUsed atomic.Time
|
lastUsed atomic.Time
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
@@ -51,8 +53,8 @@ const (
|
|||||||
|
|
||||||
var bufPool = synk.GetSizedBytesPool()
|
var bufPool = synk.GetSizedBytesPool()
|
||||||
|
|
||||||
func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||||
dst, err := net.ResolveUDPAddr(network, dstAddr)
|
dst, err := net.ResolveUDPAddr(dstNetwork, dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -61,20 +63,21 @@ func NewUDPUDPStream(network, listenAddr, dstAddr string) (nettypes.Stream, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &UDPUDPStream{
|
return &UDPUDPStream{
|
||||||
network: network,
|
network: network,
|
||||||
laddr: laddr,
|
dstNetwork: dstNetwork,
|
||||||
dst: dst,
|
laddr: laddr,
|
||||||
conns: make(map[string]*udpUDPConn),
|
dst: dst,
|
||||||
|
conns: make(map[string]*udpUDPConn),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
func (s *UDPUDPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||||
var err error
|
l, err := net.ListenUDP(s.network, s.laddr)
|
||||||
s.listener, err = net.ListenUDP(s.network, s.laddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(s, err, "failed to listen")
|
logErr(s, err, "failed to listen")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
s.listener = l
|
||||||
if acl := acl.ActiveConfig.Load(); acl != nil {
|
if acl := acl.ActiveConfig.Load(); acl != nil {
|
||||||
s.listener = acl.WrapUDP(s.listener)
|
s.listener = acl.WrapUDP(s.listener)
|
||||||
}
|
}
|
||||||
@@ -114,7 +117,7 @@ func (s *UDPUDPStream) LocalAddr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) {
|
func (s *UDPUDPStream) MarshalZerologObject(e *zerolog.Event) {
|
||||||
e.Str("protocol", "udp-udp")
|
e.Str("protocol", s.network+"->"+s.dstNetwork)
|
||||||
if s.dst != nil {
|
if s.dst != nil {
|
||||||
e.Str("dst", s.dst.String())
|
e.Str("dst", s.dst.String())
|
||||||
}
|
}
|
||||||
@@ -187,8 +190,12 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create UDP connection to destination
|
// Create connection to destination (direct UDP or via agent stream tunnel)
|
||||||
dstConn, err := net.DialUDP("udp", nil, s.dst)
|
var (
|
||||||
|
dstConn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
dstConn, err = net.DialUDP(s.dstNetwork, nil, s.dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(s, err, "failed to dial dst")
|
logErr(s, err, "failed to dial dst")
|
||||||
return nil, false
|
return nil, false
|
||||||
@@ -203,7 +210,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
|||||||
|
|
||||||
// Send initial data before starting response handler
|
// Send initial data before starting response handler
|
||||||
if !conn.forwardToDestination(initialData) {
|
if !conn.forwardToDestination(initialData) {
|
||||||
dstConn.Close()
|
_ = dstConn.Close()
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -326,6 +333,6 @@ func (conn *udpUDPConn) Close() {
|
|||||||
|
|
||||||
conn.closed.Store(true)
|
conn.closed.Store(true)
|
||||||
|
|
||||||
conn.dstConn.Close()
|
_ = conn.dstConn.Close()
|
||||||
conn.dstConn = nil
|
conn.dstConn = nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user