diff --git a/config.example.yml b/config.example.yml index a4bf4ecf..4de8a686 100644 --- a/config.example.yml +++ b/config.example.yml @@ -52,6 +52,9 @@ entrypoint: # Note that HTTP/3 with proxy protocol is not supported yet. support_proxy_protocol: false + # To relay the client address to a TCP upstream, enable `relay_proxy_protocol_header: true` + # on that specific TCP route. UDP relay is not supported yet. + # Below define an example of middleware config # 1. set security headers # 2. block non local IP connections diff --git a/internal/route/provider/all_fields.yaml b/internal/route/provider/all_fields.yaml index 451a32ee..009ac0bd 100644 --- a/internal/route/provider/all_fields.yaml +++ b/internal/route/provider/all_fields.yaml @@ -3,6 +3,7 @@ example: # matching `example.y.z` host: 10.0.0.254 port: 80 bind: 0.0.0.0 + relay_proxy_protocol_header: false # tcp only, sends PROXY header to upstream root: /var/www/example spa: true index: index.html diff --git a/internal/route/route.go b/internal/route/route.go index 28f0c0e9..b477f62e 100644 --- a/internal/route/route.go +++ b/internal/route/route.go @@ -54,15 +54,16 @@ type ( Index string `json:"index,omitempty"` // Index file to serve for single-page app mode route.HTTPConfig - PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"` - Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"` - RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"` - HealthCheck types.HealthCheckConfig `json:"healthcheck,omitzero" extensions:"x-nullable"` // null on load-balancer routes - LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"` - Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"` - Homepage *homepage.ItemConfig `json:"homepage"` - AccessLog *accesslog.RequestLoggerConfig `json:"access_log,omitempty" extensions:"x-nullable"` - Agent string `json:"agent,omitempty"` + PathPatterns []string `json:"path_patterns,omitempty" extensions:"x-nullable"` + Rules rules.Rules `json:"rules,omitempty" extensions:"x-nullable"` + RuleFile string `json:"rule_file,omitempty" extensions:"x-nullable"` + HealthCheck types.HealthCheckConfig `json:"healthcheck,omitzero" extensions:"x-nullable"` // null on load-balancer routes + LoadBalance *types.LoadBalancerConfig `json:"load_balance,omitempty" extensions:"x-nullable"` + Middlewares map[string]types.LabelMap `json:"middlewares,omitempty" extensions:"x-nullable"` + Homepage *homepage.ItemConfig `json:"homepage"` + AccessLog *accesslog.RequestLoggerConfig `json:"access_log,omitempty" extensions:"x-nullable"` + RelayProxyProtocolHeader bool `json:"relay_proxy_protocol_header,omitempty"` // TCP only: relay PROXY protocol header to the destination + Agent string `json:"agent,omitempty"` Proxmox *proxmox.NodeConfig `json:"proxmox,omitempty" extensions:"x-nullable"` @@ -310,6 +311,9 @@ func (r *Route) validate() error { if !r.UseHealthCheck() && (r.UseLoadBalance() || r.UseIdleWatcher()) { errs.Adds("cannot disable healthcheck when loadbalancer or idle watcher is enabled") } + if r.RelayProxyProtocolHeader && r.Scheme != route.SchemeTCP { + errs.Adds("relay_proxy_protocol_header is only supported for tcp routes") + } if errs.HasError() { return errs.Error() diff --git a/internal/route/route_test.go b/internal/route/route_test.go index 8be85c1b..c5c7ab4e 100644 --- a/internal/route/route_test.go +++ b/internal/route/route_test.go @@ -78,6 +78,19 @@ func TestRouteValidate(t *testing.T) { require.NotNil(t, r.impl, "Impl should be initialized") }) + t.Run("RelayProxyProtocolHeaderTCPOnly", func(t *testing.T) { + r := &Route{ + Alias: "test-udp-relay", + Scheme: route.SchemeUDP, + Host: "127.0.0.1", + Port: route.Port{Proxy: 53, Listening: 53}, + RelayProxyProtocolHeader: true, + } + err := r.Validate() + require.Error(t, err, "Validate should reject proxy protocol relay on UDP routes") + require.ErrorContains(t, err, "relay_proxy_protocol_header is only supported for tcp routes") + }) + t.Run("DockerContainer", func(t *testing.T) { r := &Route{ Alias: "test", diff --git a/internal/route/stream.go b/internal/route/stream.go index cdaa6158..faf56fb3 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -110,7 +110,14 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) { switch rScheme { case "tcp": - return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host, r.GetAgent()) + return stream.NewTCPTCPStream( + lurl.Scheme, + rurl.Scheme, + laddr, + rurl.Host, + r.GetAgent(), + r.RelayProxyProtocolHeader, + ) case "udp": return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host, r.GetAgent()) } diff --git a/internal/route/stream/README.md b/internal/route/stream/README.md index fed29b44..86c3bb9e 100644 --- a/internal/route/stream/README.md +++ b/internal/route/stream/README.md @@ -181,6 +181,7 @@ routes: scheme: tcp4 bind: 0.0.0.0 # optional port: 2222:22 # listening port: target port + relay_proxy_protocol_header: true # optional, tcp only dns-proxy: scheme: udp4 @@ -223,6 +224,7 @@ Log context includes: `protocol`, `listen`, `dst`, `action` - ACL wrapping available for TCP and UDP listeners - PROXY protocol support for original client IP +- TCP routes can optionally emit a fresh upstream PROXY v2 header with `relay_proxy_protocol_header: true` - No protocol validation (relies on upstream) - Connection limits managed by OS diff --git a/internal/route/stream/proxyproto.go b/internal/route/stream/proxyproto.go new file mode 100644 index 00000000..2e4039e9 --- /dev/null +++ b/internal/route/stream/proxyproto.go @@ -0,0 +1,37 @@ +package stream + +import ( + "fmt" + "io" + "net" + + "github.com/pires/go-proxyproto" +) + +func writeProxyProtocolHeader(dst io.Writer, src net.Conn) error { + srcAddr, ok := src.RemoteAddr().(*net.TCPAddr) + if !ok { + return fmt.Errorf("unexpected source address type %T", src.RemoteAddr()) + } + dstAddr, ok := src.LocalAddr().(*net.TCPAddr) + if !ok { + return fmt.Errorf("unexpected destination address type %T", src.LocalAddr()) + } + + header := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: transportProtocol(srcAddr, dstAddr), + SourceAddr: srcAddr, + DestinationAddr: dstAddr, + } + _, err := header.WriteTo(dst) + return err +} + +func transportProtocol(src, dst *net.TCPAddr) proxyproto.AddressFamilyAndProtocol { + if src.IP.To4() != nil && dst.IP.To4() != nil { + return proxyproto.TCPv4 + } + return proxyproto.TCPv6 +} diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go index 16e328b0..9cbbc04d 100644 --- a/internal/route/stream/tcp_tcp.go +++ b/internal/route/stream/tcp_tcp.go @@ -25,13 +25,15 @@ type TCPTCPStream struct { dst *net.TCPAddr agent *agentpool.Agent + relayProxyProtocolHeader bool + preDial nettypes.HookFunc onRead nettypes.HookFunc closed atomic.Bool } -func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent) (nettypes.Stream, error) { +func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent, relayProxyProtocolHeader bool) (nettypes.Stream, error) { dst, err := net.ResolveTCPAddr(dstNetwork, dstAddr) if err != nil { return nil, err @@ -40,7 +42,14 @@ func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *age if err != nil { return nil, err } - return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil + return &TCPTCPStream{ + network: network, + dstNetwork: dstNetwork, + laddr: laddr, + dst: dst, + agent: agent, + relayProxyProtocolHeader: relayProxyProtocolHeader, + }, nil } func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) error { @@ -158,6 +167,14 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { if s.closed.Load() { return } + if s.relayProxyProtocolHeader { + if err := writeProxyProtocolHeader(dstConn, conn); err != nil { + if !s.closed.Load() { + logErr(s, err, "failed to write proxy protocol header") + } + return + } + } src := conn dst := dstConn diff --git a/internal/route/stream/tcp_tcp_test.go b/internal/route/stream/tcp_tcp_test.go new file mode 100644 index 00000000..b6768b72 --- /dev/null +++ b/internal/route/stream/tcp_tcp_test.go @@ -0,0 +1,148 @@ +package stream + +import ( + "bufio" + "context" + "io" + "net" + "testing" + + "github.com/pires/go-proxyproto" + entrypoint "github.com/yusing/godoxy/internal/entrypoint" + entrypointtypes "github.com/yusing/godoxy/internal/entrypoint/types" + "github.com/yusing/goutils/task" + + "github.com/stretchr/testify/require" +) + +func TestTCPTCPStreamRelayProxyProtocolHeader(t *testing.T) { + t.Run("Disabled", func(t *testing.T) { + upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer upstreamLn.Close() + + s, err := NewTCPTCPStream("tcp", "tcp", "127.0.0.1:0", upstreamLn.Addr().String(), nil, false) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + require.NoError(t, s.ListenAndServe(ctx, nil, nil)) + defer s.Close() + + client, err := net.Dial("tcp", s.LocalAddr().String()) + require.NoError(t, err) + defer client.Close() + + _, err = client.Write([]byte("ping")) + require.NoError(t, err) + + upstreamConn, err := upstreamLn.Accept() + require.NoError(t, err) + defer upstreamConn.Close() + + payload := make([]byte, 4) + _, err = io.ReadFull(upstreamConn, payload) + require.NoError(t, err) + require.Equal(t, []byte("ping"), payload) + }) + + t.Run("Enabled", func(t *testing.T) { + upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer upstreamLn.Close() + + s, err := NewTCPTCPStream("tcp", "tcp", "127.0.0.1:0", upstreamLn.Addr().String(), nil, true) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + require.NoError(t, s.ListenAndServe(ctx, nil, nil)) + defer s.Close() + + client, err := net.Dial("tcp", s.LocalAddr().String()) + require.NoError(t, err) + defer client.Close() + + _, err = client.Write([]byte("ping")) + require.NoError(t, err) + + upstreamConn, err := upstreamLn.Accept() + require.NoError(t, err) + defer upstreamConn.Close() + + reader := bufio.NewReader(upstreamConn) + header, err := proxyproto.Read(reader) + require.NoError(t, err) + require.Equal(t, proxyproto.PROXY, header.Command) + + srcAddr, ok := header.SourceAddr.(*net.TCPAddr) + require.True(t, ok) + dstAddr, ok := header.DestinationAddr.(*net.TCPAddr) + require.True(t, ok) + require.Equal(t, client.LocalAddr().String(), srcAddr.String()) + require.Equal(t, s.LocalAddr().String(), dstAddr.String()) + + payload := make([]byte, 4) + _, err = io.ReadFull(reader, payload) + require.NoError(t, err) + require.Equal(t, []byte("ping"), payload) + }) +} + +func TestTCPTCPStreamRelayProxyProtocolUsesIncomingProxyHeader(t *testing.T) { + upstreamLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer upstreamLn.Close() + + s, err := NewTCPTCPStream("tcp", "tcp", "127.0.0.1:0", upstreamLn.Addr().String(), nil, true) + require.NoError(t, err) + + parent := task.GetTestTask(t) + ep := entrypoint.NewEntrypoint(parent, &entrypoint.Config{ + SupportProxyProtocol: true, + }) + entrypointtypes.SetCtx(parent, ep) + + ctx, cancel := context.WithCancel(parent.Context()) + defer cancel() + require.NoError(t, s.ListenAndServe(ctx, nil, nil)) + defer s.Close() + + client, err := net.Dial("tcp", s.LocalAddr().String()) + require.NoError(t, err) + defer client.Close() + + downstreamHeader := &proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + TransportProtocol: proxyproto.TCPv4, + SourceAddr: &net.TCPAddr{ + IP: net.ParseIP("203.0.113.10"), + Port: 42300, + }, + DestinationAddr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: s.LocalAddr().(*net.TCPAddr).Port, + }, + } + _, err = downstreamHeader.WriteTo(client) + require.NoError(t, err) + + _, err = client.Write([]byte("pong")) + require.NoError(t, err) + + upstreamConn, err := upstreamLn.Accept() + require.NoError(t, err) + defer upstreamConn.Close() + + reader := bufio.NewReader(upstreamConn) + header, err := proxyproto.Read(reader) + require.NoError(t, err) + require.Equal(t, downstreamHeader.SourceAddr.String(), header.SourceAddr.String()) + require.Equal(t, downstreamHeader.DestinationAddr.String(), header.DestinationAddr.String()) + + payload := make([]byte, 4) + _, err = io.ReadFull(reader, payload) + require.NoError(t, err) + require.Equal(t, []byte("pong"), payload) +} diff --git a/providers.example.yml b/providers.example.yml index 63f17436..ed897ebc 100644 --- a/providers.example.yml +++ b/providers.example.yml @@ -26,3 +26,9 @@ app2: scheme: udp host: 10.0.0.2 port: 2223:dns + +ssh-with-proxy-protocol: + scheme: tcp + host: 10.0.0.3 + port: 2222:22 + relay_proxy_protocol_header: true