diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 307422ed..5112e91a 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -1,21 +1,31 @@ package main import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "net" + "net/http" "os" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/yusing/godoxy/agent/pkg/agent" + "github.com/yusing/godoxy/agent/pkg/agent/stream" "github.com/yusing/godoxy/agent/pkg/env" - "github.com/yusing/godoxy/agent/pkg/server" + "github.com/yusing/godoxy/agent/pkg/handler" "github.com/yusing/godoxy/internal/metrics/systeminfo" socketproxy "github.com/yusing/godoxy/socketproxy/pkg" + gperr "github.com/yusing/goutils/errs" httpServer "github.com/yusing/goutils/server" strutils "github.com/yusing/goutils/strings" "github.com/yusing/goutils/task" "github.com/yusing/goutils/version" ) +// TODO: support IPv6 + func main() { writer := zerolog.ConsoleWriter{ Out: os.Stderr, @@ -52,16 +62,84 @@ func main() { Tips: 1. To change the agent name, you can set the AGENT_NAME environment variable. 2. To change the agent port, you can set the AGENT_PORT environment variable. -`) + `) t := task.RootTask("agent", false) - opts := server.Options{ - CACert: caCert, - ServerCert: srvCert, - Port: env.AgentPort, + + // One TCP listener on AGENT_PORT, then multiplex by TLS ALPN: + // - Stream ALPN: route to TCP stream tunnel handler (via http.Server.TLSNextProto) + // - Otherwise: route to HTTPS API handler + tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort}) + if err != nil { + gperr.LogFatal("failed to listen on port", err) } - server.StartAgentServer(t, opts) + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert.Leaf) + + muxTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{*srvCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + // Keep HTTP limited to HTTP/1.1 (matching current agent server behavior) + // and add the stream tunnel ALPN for multiplexing. + NextProtos: []string{"http/1.1", stream.StreamALPN}, + } + if env.AgentSkipClientCertCheck { + muxTLSConfig.ClientAuth = tls.NoClientCert + } + + // TLS listener feeds the HTTP server. ALPN stream connections are intercepted + // using http.Server.TLSNextProto. + tlsLn := tls.NewListener(tcpListener, muxTLSConfig) + + streamSrv := stream.NewTCPServerHandler(t.Context()) + + httpSrv := &http.Server{ + Handler: handler.NewAgentHandler(), + BaseContext: func(net.Listener) context.Context { + return t.Context() + }, + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ + // When a client negotiates StreamALPN, net/http will call this hook instead + // of treating the connection as HTTP. + stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + // ServeConn blocks until the tunnel finishes. + streamSrv.ServeConn(conn) + }, + }, + } + { + subtask := t.Subtask("agent-http", true) + t.OnCancel("stop_http", func() { + _ = streamSrv.Close() + _ = httpSrv.Close() + _ = tlsLn.Close() + }) + go func() { + err := httpSrv.Serve(tlsLn) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Error().Err(err).Msg("agent HTTP server stopped with error") + } + subtask.Finish(err) + }() + log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started (ALPN mux enabled)") + } + log.Info().Int("port", env.AgentPort).Msg("TCP stream handler started (via TLSNextProto)") + + { + udpServer := stream.NewUDPServer(t.Context(), "udp", &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert) + subtask := t.Subtask("agent-stream-udp", true) + t.OnCancel("stop_stream_udp", func() { + _ = udpServer.Close() + }) + go func() { + err := udpServer.Start() + subtask.Finish(err) + }() + log.Info().Int("port", env.AgentPort).Msg("UDP stream server started") + } if socketproxy.ListenAddr != "" { runtime := strutils.Title(string(env.Runtime)) diff --git a/agent/go.mod b/agent/go.mod index a37eb1bc..f0eabd4e 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -18,6 +18,8 @@ require ( github.com/bytedance/sonic v1.14.2 github.com/gin-gonic/gin v1.11.0 github.com/gorilla/websocket v1.5.3 + github.com/pion/dtls/v3 v3.0.10 + github.com/pion/transport/v3 v3.1.1 github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.11.1 github.com/yusing/godoxy v0.0.0-00010101000000-000000000000 @@ -72,6 +74,8 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect github.com/pires/go-proxyproto v0.8.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect diff --git a/agent/go.sum b/agent/go.sum index 26ba7305..97118415 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -146,6 +146,14 @@ github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5 github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pion/dtls/v3 v3.0.10 h1:k9ekkq1kaZoxnNEbyLKI8DI37j/Nbk1HWmMuywpQJgg= +github.com/pion/dtls/v3 v3.0.10/go.mod h1:YEmmBYIoBsY3jmG56dsziTv/Lca9y4Om83370CXfqJ8= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/agent/pkg/agent/common/common.go b/agent/pkg/agent/common/common.go new file mode 100644 index 00000000..882226f2 --- /dev/null +++ b/agent/pkg/agent/common/common.go @@ -0,0 +1,3 @@ +package common + +const CertsDNSName = "godoxy.agent" diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index bf6b102b..d7128f39 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -4,6 +4,8 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/json" + "encoding/pem" "errors" "fmt" "io" @@ -16,31 +18,51 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/yusing/godoxy/agent/pkg/agent/common" + agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream" "github.com/yusing/godoxy/agent/pkg/certs" + gperr "github.com/yusing/goutils/errs" httputils "github.com/yusing/goutils/http" "github.com/yusing/goutils/version" ) type AgentConfig struct { - Addr string `json:"addr"` - Name string `json:"name"` - Version version.Version `json:"version" swaggertype:"string"` - Runtime ContainerRuntime `json:"runtime"` + AgentInfo + + Addr string `json:"addr"` + IsTCPStreamSupported bool `json:"supports_tcp_stream"` + IsUDPStreamSupported bool `json:"supports_udp_stream"` + + // for stream + caCert *x509.Certificate + clientCert *tls.Certificate tlsConfig tls.Config - l zerolog.Logger + + l zerolog.Logger } // @name Agent +type AgentInfo struct { + Version version.Version `json:"version" swaggertype:"string"` + Name string `json:"name"` + Runtime ContainerRuntime `json:"runtime"` +} + +// Deprecated. Replaced by EndpointInfo const ( - EndpointVersion = "/version" - EndpointName = "/name" - EndpointRuntime = "/runtime" + EndpointVersion = "/version" + EndpointName = "/name" + EndpointRuntime = "/runtime" +) + +const ( + EndpointInfo = "/info" EndpointProxyHTTP = "/proxy/http" EndpointHealth = "/health" EndpointLogs = "/logs" EndpointSystemInfo = "/system_info" - AgentHost = CertsDNSName + AgentHost = common.CertsDNSName APIEndpointBase = "/godoxy/agent" APIBaseURL = "https://" + AgentHost + APIEndpointBase @@ -90,6 +112,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte) if err != nil { return err } + cfg.clientCert = &clientCert // create tls config caCertPool := x509.NewCertPool() @@ -97,58 +120,105 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte) if !ok { return errors.New("invalid ca certificate") } + // Keep the CA leaf for stream client dialing. + if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" { + return errors.New("invalid ca certificate") + } else if cert, err := x509.ParseCertificate(block.Bytes); err != nil { + return err + } else { + cfg.caCert = cert + } cfg.tlsConfig = tls.Config{ Certificates: []tls.Certificate{clientCert}, RootCAs: caCertPool, - ServerName: CertsDNSName, + ServerName: common.CertsDNSName, + MinVersion: tls.VersionTLS12, } ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - // get agent name - name, _, err := cfg.fetchString(ctx, EndpointName) + status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo) if err != nil { return err } - cfg.Name = name + var streamUnsupportedErrs gperr.Builder + + if status == http.StatusOK { + // test stream server connection + const fakeAddress = "localhost:8080" // it won't be used, just for testing + // test TCP stream support + err := agentstream.TCPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert) + if err != nil { + streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err) + } else { + cfg.IsTCPStreamSupported = true + } + + // test UDP stream support + err = agentstream.UDPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert) + if err != nil { + streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err) + } else { + cfg.IsUDPStreamSupported = true + } + } else { + // old agent does not support EndpointInfo + // fallback with old logic + cfg.IsTCPStreamSupported = false + cfg.IsUDPStreamSupported = false + streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling") + + // get agent name + name, _, err := cfg.fetchString(ctx, EndpointName) + if err != nil { + return err + } + + cfg.Name = name + + // check agent version + agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion) + if err != nil { + return err + } + + cfg.Version = version.Parse(agentVersion) + + // check agent runtime + runtime, status, err := cfg.fetchString(ctx, EndpointRuntime) + if err != nil { + return err + } + + switch status { + case http.StatusOK: + switch runtime { + case "docker": + cfg.Runtime = ContainerRuntimeDocker + // case "nerdctl": + // cfg.Runtime = ContainerRuntimeNerdctl + case "podman": + cfg.Runtime = ContainerRuntimePodman + default: + return fmt.Errorf("invalid agent runtime: %s", runtime) + } + case http.StatusNotFound: + // backward compatibility, old agent does not have runtime endpoint + cfg.Runtime = ContainerRuntimeDocker + default: + return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime) + } + } cfg.l = log.With().Str("agent", cfg.Name).Logger() - // check agent version - agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion) - if err != nil { - return err + if err := streamUnsupportedErrs.Error(); err != nil { + gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l) } - // check agent runtime - runtime, status, err := cfg.fetchString(ctx, EndpointRuntime) - if err != nil { - return err - } - switch status { - case http.StatusOK: - switch runtime { - case "docker": - cfg.Runtime = ContainerRuntimeDocker - // case "nerdctl": - // cfg.Runtime = ContainerRuntimeNerdctl - case "podman": - cfg.Runtime = ContainerRuntimePodman - default: - return fmt.Errorf("invalid agent runtime: %s", runtime) - } - case http.StatusNotFound: - // backward compatibility, old agent does not have runtime endpoint - cfg.Runtime = ContainerRuntimeDocker - default: - return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime) - } - - cfg.Version = version.Parse(agentVersion) - if serverVersion.IsNewerThanMajor(cfg.Version) { log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version) } @@ -177,6 +247,38 @@ func (cfg *AgentConfig) Init(ctx context.Context) error { return cfg.InitWithCerts(ctx, ca, crt, key) } +// NewTCPClient creates a new TCP client for the agent. +// +// It returns an error if +// - the agent is not initialized +// - the agent does not support TCP stream tunneling +// - the agent stream server address is not initialized +func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) { + if cfg.caCert == nil || cfg.clientCert == nil { + return nil, errors.New("agent is not initialized") + } + if !cfg.IsTCPStreamSupported { + return nil, errors.New("agent does not support TCP stream tunneling") + } + return agentstream.NewTCPClient(cfg.Addr, targetAddress, cfg.caCert, cfg.clientCert) +} + +// NewUDPClient creates a new UDP client for the agent. +// +// It returns an error if +// - the agent is not initialized +// - the agent does not support UDP stream tunneling +// - the agent stream server address is not initialized +func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) { + if cfg.caCert == nil || cfg.clientCert == nil { + return nil, errors.New("agent is not initialized") + } + if !cfg.IsUDPStreamSupported { + return nil, errors.New("agent does not support UDP stream tunneling") + } + return agentstream.NewUDPClient(cfg.Addr, targetAddress, cfg.caCert, cfg.clientCert) +} + func (cfg *AgentConfig) Transport() *http.Transport { return &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -232,3 +334,31 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin release(data) return ret, resp.StatusCode, nil } + +// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct +// +// It will return the status code of the response, and error if any. +// If the status code is not http.StatusOK, out will be unchanged but error will still be nil. +func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) { + resp, err := cfg.do(ctx, "GET", endpoint, nil) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + data, release, err := httputils.ReadAllBody(resp) + if err != nil { + return 0, err + } + + defer release(data) + if resp.StatusCode != http.StatusOK { + return resp.StatusCode, nil + } + + err = json.Unmarshal(data, out) + if err != nil { + return 0, err + } + return resp.StatusCode, nil +} diff --git a/agent/pkg/agent/new_agent.go b/agent/pkg/agent/new_agent.go index df0af75d..e7145622 100644 --- a/agent/pkg/agent/new_agent.go +++ b/agent/pkg/agent/new_agent.go @@ -17,10 +17,8 @@ import ( "math/big" "strings" "time" -) -const ( - CertsDNSName = "godoxy.agent" + "github.com/yusing/godoxy/agent/pkg/agent/common" ) func toPEMPair(certDER []byte, key *ecdsa.PrivateKey) *PEMPair { @@ -156,7 +154,7 @@ func NewAgent() (ca, srv, client *PEMPair, err error) { SerialNumber: caSerialNumber, Subject: pkix.Name{ Organization: []string{"GoDoxy"}, - CommonName: CertsDNSName, + CommonName: common.CertsDNSName, }, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years @@ -196,9 +194,9 @@ func NewAgent() (ca, srv, client *PEMPair, err error) { Subject: pkix.Name{ Organization: caTemplate.Subject.Organization, OrganizationalUnit: []string{"Server"}, - CommonName: CertsDNSName, + CommonName: common.CertsDNSName, }, - DNSNames: []string{CertsDNSName}, + DNSNames: []string{common.CertsDNSName}, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1000, 0, 0), // Add validity period KeyUsage: x509.KeyUsageDigitalSignature, @@ -228,9 +226,9 @@ func NewAgent() (ca, srv, client *PEMPair, err error) { Subject: pkix.Name{ Organization: caTemplate.Subject.Organization, OrganizationalUnit: []string{"Client"}, - CommonName: CertsDNSName, + CommonName: common.CertsDNSName, }, - DNSNames: []string{CertsDNSName}, + DNSNames: []string{common.CertsDNSName}, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1000, 0, 0), KeyUsage: x509.KeyUsageDigitalSignature, diff --git a/agent/pkg/agent/new_agent_test.go b/agent/pkg/agent/new_agent_test.go index 14a34b78..e446f9ae 100644 --- a/agent/pkg/agent/new_agent_test.go +++ b/agent/pkg/agent/new_agent_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/agent/pkg/agent/common" ) func TestNewAgent(t *testing.T) { @@ -72,7 +73,7 @@ func TestServerClient(t *testing.T) { clientTLSConfig := &tls.Config{ Certificates: []tls.Certificate{*clientTLS}, RootCAs: caPool, - ServerName: CertsDNSName, + ServerName: common.CertsDNSName, } server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/agent/pkg/agent/stream/README.md b/agent/pkg/agent/stream/README.md new file mode 100644 index 00000000..720d713f --- /dev/null +++ b/agent/pkg/agent/stream/README.md @@ -0,0 +1,197 @@ +# Stream proxy protocol + +This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination. It supports both TCP-over-TLS and UDP-over-DTLS transports. + +## Overview + +```mermaid +graph TD + subgraph Client + TC[TCPClient] -->|TLS| TSS[TCPServer] + UC[UDPClient] -->|DTLS| USS[UDPServer] + end + + subgraph Stream Protocol + H[StreamRequestHeader] + end + + TSS -->|Redirect| DST1[Destination TCP] + USS -->|Forward UDP| DST2[Destination UDP] +``` + +## Header + +The on-wire header is a fixed-size binary blob: + +- `Version` (8 bytes) +- `HostLength` (1 byte) +- `Host` (255 bytes, NUL padded) +- `PortLength` (1 byte) +- `Port` (5 bytes, NUL padded) +- `Flag` (1 byte, protocol flags) +- `Checksum` (4 bytes, big-endian CRC32) + +Total: `headerSize = 8 + 1 + 255 + 1 + 5 + 1 + 4 = 275` bytes. + +Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`. + +### Flags + +The `Flag` field is a bitmask of protocol flags defined by `FlagType`: + +| Flag | Value | Purpose | +| ---------------------- | ----- | ---------------------------------------------------------------------- | +| `FlagCloseImmediately` | `1` | Health check probe - server closes immediately after validating header | + +See [`FlagType`](header.go:26) and [`FlagCloseImmediately`](header.go:28). + +See [`StreamRequestHeader`](header.go:30). + +## File Structure + +| File | Purpose | +| ----------------------------------- | ------------------------------------------------------------ | +| [`header.go`](header.go) | Stream request header structure and validation. | +| [`tcp_client.go`](tcp_client.go:12) | TCP client implementation with TLS transport. | +| [`tcp_server.go`](tcp_server.go:13) | TCP server implementation for handling stream requests. | +| [`udp_client.go`](udp_client.go:13) | UDP client implementation with DTLS transport. | +| [`udp_server.go`](udp_server.go:17) | UDP server implementation for handling DTLS stream requests. | +| [`common.go`](common.go:11) | Connection manager and shared constants. | + +## Constants + +| Constant | Value | Purpose | +| ---------------------- | ------------------------- | ------------------------------------------------------- | +| `StreamALPN` | `"godoxy-agent-stream/1"` | TLS ALPN protocol for stream multiplexing. | +| `headerSize` | `275` bytes | Total size of the stream request header. | +| `dialTimeout` | `10s` | Timeout for establishing destination connections. | +| `readDeadline` | `10s` | Read timeout for UDP destination sockets. | +| `FlagCloseImmediately` | `1` | Flag for health check probe - server closes immediately | + +See [`common.go`](common.go:11). + +## Public API + +### Types + +#### `StreamRequestHeader` + +Represents the on-wire protocol header used to negotiate a stream tunnel. + +```go +type StreamRequestHeader struct { + Version [8]byte // Fixed to "0.1.0" with NUL padding + HostLength byte // Actual host name length (0-255) + Host [255]byte // NUL-padded host name + PortLength byte // Actual port string length (0-5) + Port [5]byte // NUL-padded port string + Flag FlagType // Protocol flags (e.g., FlagCloseImmediately) + Checksum [4]byte // CRC32 checksum of header without checksum +} +``` + +**Methods:** + +- `NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error)` - Creates a header for the given host and port. Returns error if host exceeds 255 bytes or port exceeds 5 bytes. +- `NewStreamHealthCheckHeader() *StreamRequestHeader` - Creates a header with `FlagCloseImmediately` set for health check probes. +- `Validate() bool` - Validates the version and checksum. +- `GetHostPort() (string, string)` - Extracts the host and port from the header. +- `ShouldCloseImmediately() bool` - Returns true if `FlagCloseImmediately` is set. + +### TCP Functions + +- [`NewTCPClient()`](tcp_client.go:26) - Creates a TLS client connection and sends the stream header. +- [`NewTCPServerHandler()`](tcp_server.go:24) - Creates a handler for ALPN-multiplexed connections (no listener). +- [`NewTCPServerFromListener()`](tcp_server.go:36) - Wraps an existing TLS listener. +- [`NewTCPServer()`](tcp_server.go:45) - Creates a fully-configured TCP server with TLS listener. + +### UDP Functions + +- [`NewUDPClient()`](udp_client.go:27) - Creates a DTLS client connection and sends the stream header. +- [`NewUDPServer()`](udp_server.go:26) - Creates a DTLS server listening on the given UDP address. + +## Health Check Probes + +The protocol supports health check probes using the `FlagCloseImmediately` flag. When a client sends a header with this flag set, the server validates the header and immediately closes the connection without establishing a destination tunnel. + +This is useful for: + +- Connectivity testing between agent and server +- Verifying TLS/DTLS handshake and mTLS authentication +- Monitoring stream protocol availability + +**Usage:** + +```go +header := stream.NewStreamHealthCheckHeader() +// Send header over TLS/DTLS connection +// Server will validate and close immediately +``` + +Both TCP and UDP servers silently handle health check probes without logging errors. + +See [`NewStreamHealthCheckHeader()`](header.go:66) and [`FlagCloseImmediately`](header.go:28). + +## TCP behavior + +1. Client establishes a TLS connection to the stream server. +2. Client sends exactly one header as a handshake. +3. After the handshake, both sides proxy raw TCP bytes between client and destination. + +Server reads the header using `io.ReadFull` to avoid dropping bytes. + +See [`NewTCPClient()`](tcp_client.go:26) and [`(*TCPServer).redirect()`](tcp_server.go:116). + +## UDP-over-DTLS behavior + +1. Client establishes a DTLS connection to the stream server. +2. Client sends exactly one header as a handshake. +3. After the handshake, both sides proxy raw UDP datagrams: + - client -> destination: DTLS payload is written to destination `UDPConn` + - destination -> client: destination payload is written back to the DTLS connection + +Responses do **not** include a header. + +The UDP server uses a bidirectional forwarding model: + +- One goroutine forwards from client to destination +- Another goroutine forwards from destination to client + +The destination reader uses `readDeadline` to periodically wake up and check for context cancellation. Timeouts do not terminate the session. + +See [`NewUDPClient()`](udp_client.go:27) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:89). + +## Connection Management + +Both `TCPServer` and `UDPServer` create a dedicated destination connection per incoming stream session and close it when the session ends (no destination connection reuse). + +## Error Handling + +| Error | Description | +| --------------------- | ----------------------------------------------- | +| `ErrInvalidHeader` | Header validation failed (version or checksum). | +| `ErrCloseImmediately` | Health check probe - server closed immediately. | + +Errors from connection creation are propagated to the caller. + +See [`header.go`](header.go:23). + +## Integration + +This package is used by the agent to provide stream tunneling capabilities. See the parent [`agent`](../README.md) package for integration details with the GoDoxy server. + +### Certificate Requirements + +Both TCP and UDP servers require: + +- CA certificate for client verification +- Server certificate for TLS/DTLS termination + +Both clients require: + +- CA certificate for server verification +- Client certificate for mTLS authentication + +### ALPN Protocol + +The `StreamALPN` constant (`"godoxy-agent-stream/1"`) is used to multiplex stream tunnel traffic and HTTPS API traffic on the same port. Connections negotiating this ALPN are routed to the stream handler. diff --git a/agent/pkg/agent/stream/common.go b/agent/pkg/agent/stream/common.go new file mode 100644 index 00000000..0a209079 --- /dev/null +++ b/agent/pkg/agent/stream/common.go @@ -0,0 +1,24 @@ +package stream + +import ( + "time" + + "github.com/pion/dtls/v3" + "github.com/yusing/goutils/synk" +) + +const ( + dialTimeout = 10 * time.Second + readDeadline = 10 * time.Second +) + +// StreamALPN is the TLS ALPN protocol id used to multiplex the TCP stream tunnel +// and the HTTPS API on the same TCP port. +// +// When a client negotiates this ALPN, the agent will route the connection to the +// stream tunnel handler instead of the HTTP handler. +const StreamALPN = "godoxy-agent-stream/1" + +var dTLSCipherSuites = []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256} + +var sizedPool = synk.GetSizedBytesPool() diff --git a/agent/pkg/agent/stream/header.go b/agent/pkg/agent/stream/header.go new file mode 100644 index 00000000..2acdfeaf --- /dev/null +++ b/agent/pkg/agent/stream/header.go @@ -0,0 +1,117 @@ +package stream + +import ( + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "reflect" + "unsafe" +) + +const ( + versionSize = 8 + hostSize = 255 + portSize = 5 + flagSize = 1 + checksumSize = 4 // crc32 checksum + + headerSize = versionSize + 1 + hostSize + 1 + portSize + flagSize + checksumSize +) + +var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0} + +var ErrInvalidHeader = errors.New("invalid header") +var ErrCloseImmediately = errors.New("close immediately") + +type FlagType uint8 + +const FlagCloseImmediately FlagType = 1 << iota + +type StreamRequestHeader struct { + Version [versionSize]byte + + HostLength byte + Host [hostSize]byte + + PortLength byte + Port [portSize]byte + + Flag FlagType + Checksum [checksumSize]byte +} + +func init() { + if headerSize != reflect.TypeFor[StreamRequestHeader]().Size() { + panic("headerSize does not match the size of StreamRequestHeader") + } +} + +func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) { + if len(host) > hostSize { + return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host)) + } + if len(port) > portSize { + return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port)) + } + header := &StreamRequestHeader{} + copy(header.Version[:], version[:]) + header.HostLength = byte(len(host)) + copy(header.Host[:], host) + header.PortLength = byte(len(port)) + copy(header.Port[:], port) + header.updateChecksum() + return header, nil +} + +func NewStreamHealthCheckHeader() *StreamRequestHeader { + header := &StreamRequestHeader{} + copy(header.Version[:], version[:]) + header.Flag |= FlagCloseImmediately + header.updateChecksum() + return header +} + +// ToHeader converts header byte array to a copy of itself as a StreamRequestHeader. +func ToHeader(buf *[headerSize]byte) StreamRequestHeader { + return *(*StreamRequestHeader)(unsafe.Pointer(buf)) +} + +func (h *StreamRequestHeader) GetHostPort() (string, string) { + return string(h.Host[:h.HostLength]), string(h.Port[:h.PortLength]) +} + +func (h *StreamRequestHeader) Validate() bool { + if h.Version != version { + return false + } + if h.HostLength > hostSize { + return false + } + if h.PortLength > portSize { + return false + } + return h.validateChecksum() +} + +func (h *StreamRequestHeader) ShouldCloseImmediately() bool { + return h.Flag&FlagCloseImmediately != 0 +} + +func (h *StreamRequestHeader) updateChecksum() { + checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum()) + binary.BigEndian.PutUint32(h.Checksum[:], checksum) +} + +func (h *StreamRequestHeader) validateChecksum() bool { + checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum()) + return checksum == binary.BigEndian.Uint32(h.Checksum[:]) +} + +func (h *StreamRequestHeader) BytesWithoutChecksum() []byte { + return (*[headerSize - checksumSize]byte)(unsafe.Pointer(h))[:] +} + +func (h *StreamRequestHeader) Bytes() []byte { + return (*[headerSize]byte)(unsafe.Pointer(h))[:] +} diff --git a/agent/pkg/agent/stream/payload_test.go b/agent/pkg/agent/stream/payload_test.go new file mode 100644 index 00000000..1c5d7c32 --- /dev/null +++ b/agent/pkg/agent/stream/payload_test.go @@ -0,0 +1,26 @@ +package stream + +import ( + "testing" +) + +func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) { + h, err := NewStreamRequestHeader("example.com", "443") + if err != nil { + t.Fatalf("NewStreamRequestHeader: %v", err) + } + if !h.Validate() { + t.Fatalf("expected header to validate") + } + + var buf [headerSize]byte + copy(buf[:], h.Bytes()) + h2 := ToHeader(&buf) + if !h2.Validate() { + t.Fatalf("expected round-tripped header to validate") + } + host, port := h2.GetHostPort() + if host != "example.com" || port != "443" { + t.Fatalf("unexpected host/port: %q:%q", host, port) + } +} diff --git a/agent/pkg/agent/stream/tcp_client.go b/agent/pkg/agent/stream/tcp_client.go new file mode 100644 index 00000000..3a9db398 --- /dev/null +++ b/agent/pkg/agent/stream/tcp_client.go @@ -0,0 +1,122 @@ +package stream + +import ( + "crypto/tls" + "crypto/x509" + "net" + "time" + + "github.com/yusing/godoxy/agent/pkg/agent/common" +) + +type TCPClient struct { + conn net.Conn +} + +// NewTCPClient creates a new TCP client for the agent. +// +// It will establish a TLS connection and send a stream request header to the server. +// +// It returns an error if +// - the target address is invalid +// - the stream request header is invalid +// - the TLS configuration is invalid +// - the TLS connection fails +// - the stream request header is not sent +func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { + host, port, err := net.SplitHostPort(targetAddress) + if err != nil { + return nil, err + } + + header, err := NewStreamRequestHeader(host, port) + if err != nil { + return nil, err + } + + return newTCPClientWIthHeader(serverAddr, header, caCert, clientCert) +} + +func TCPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { + header := NewStreamHealthCheckHeader() + + conn, err := newTCPClientWIthHeader(serverAddr, header, caCert, clientCert) + if err != nil { + return err + } + + conn.Close() + return nil +} + +func newTCPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { + // Setup TLS configuration + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*clientCert}, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + NextProtos: []string{StreamALPN}, + ServerName: common.CertsDNSName, + } + + // Establish TLS connection + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig) + if err != nil { + return nil, err + } + // Send the stream header once as a handshake. + if _, err := conn.Write(header.Bytes()); err != nil { + _ = conn.Close() + return nil, err + } + + return &TCPClient{ + conn: conn, + }, nil +} + +func (c *TCPClient) Read(p []byte) (n int, err error) { + return c.conn.Read(p) +} + +func (c *TCPClient) Write(p []byte) (n int, err error) { + return c.conn.Write(p) +} + +func (c *TCPClient) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *TCPClient) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *TCPClient) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *TCPClient) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *TCPClient) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *TCPClient) Close() error { + return c.conn.Close() +} + +// ConnectionState exposes the underlying TLS connection state when the client is +// backed by *tls.Conn. +// +// This is primarily used by tests and diagnostics. +func (c *TCPClient) ConnectionState() tls.ConnectionState { + if tc, ok := c.conn.(*tls.Conn); ok { + return tc.ConnectionState() + } + return tls.ConnectionState{} +} diff --git a/agent/pkg/agent/stream/tcp_server.go b/agent/pkg/agent/stream/tcp_server.go new file mode 100644 index 00000000..9432746c --- /dev/null +++ b/agent/pkg/agent/stream/tcp_server.go @@ -0,0 +1,176 @@ +package stream + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + ioutils "github.com/yusing/goutils/io" +) + +type TCPServer struct { + ctx context.Context + listener net.Listener +} + +// NewTCPServerHandler creates a TCP stream server that can serve already-accepted +// connections (e.g. handed off by an ALPN multiplexer). +// +// This variant does not require a listener. Use TCPServer.ServeConn to handle +// each incoming stream connection. +func NewTCPServerHandler(ctx context.Context) *TCPServer { + s := &TCPServer{ctx: ctx} + return s +} + +// NewTCPServerFromListener creates a TCP stream server from an already-prepared +// listener. +// +// The listener is expected to yield connections that are already secured (e.g. +// a TLS/mTLS listener, or pre-handshaked *tls.Conn). This is used when the agent +// multiplexes HTTPS and stream-tunnel traffic on the same port. +func NewTCPServerFromListener(ctx context.Context, listener net.Listener) *TCPServer { + s := &TCPServer{ + ctx: ctx, + listener: listener, + } + return s +} + +func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer { + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + NextProtos: []string{StreamALPN}, + } + + tcpListener := tls.NewListener(listener, tlsConfig) + return NewTCPServerFromListener(ctx, tcpListener) +} + +func (s *TCPServer) Start() error { + if s.listener == nil { + return net.ErrClosed + } + context.AfterFunc(s.ctx, func() { + _ = s.listener.Close() + }) + for { + conn, err := s.listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) && s.ctx.Err() != nil { + return s.ctx.Err() + } + return err + } + go s.handle(conn) + } +} + +// ServeConn serves a single stream connection. +// +// The provided connection is expected to be already secured (TLS/mTLS) and to +// speak the stream protocol (i.e. the client will send the stream header first). +// +// This method blocks until the stream finishes. +func (s *TCPServer) ServeConn(conn net.Conn) { + s.handle(conn) +} + +func (s *TCPServer) Addr() net.Addr { + if s.listener == nil { + return nil + } + return s.listener.Addr() +} + +func (s *TCPServer) Close() error { + if s.listener == nil { + return nil + } + return s.listener.Close() +} + +func (s *TCPServer) logger(clientConn net.Conn) *zerolog.Logger { + ev := log.With().Str("protocol", "tcp"). + Str("remote", clientConn.RemoteAddr().String()) + if s.listener != nil { + ev = ev.Str("addr", s.listener.Addr().String()) + } + l := ev.Logger() + return &l +} + +func (s *TCPServer) loggerWithDst(dstConn net.Conn, clientConn net.Conn) *zerolog.Logger { + ev := log.With().Str("protocol", "tcp"). + Str("remote", clientConn.RemoteAddr().String()). + Str("dst", dstConn.RemoteAddr().String()) + if s.listener != nil { + ev = ev.Str("addr", s.listener.Addr().String()) + } + l := ev.Logger() + return &l +} + +func (s *TCPServer) handle(conn net.Conn) { + defer conn.Close() + dst, err := s.redirect(conn) + if err != nil { + // Health check probe: close connection + if errors.Is(err, ErrCloseImmediately) { + s.logger(conn).Info().Msg("Health check received") + return + } + s.logger(conn).Err(err).Msg("failed to redirect connection") + return + } + + defer dst.Close() + pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst) + err = pipe.Start() + if err != nil { + s.loggerWithDst(dst, conn).Err(err).Msg("failed to start bidirectional pipe") + return + } +} + +func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) { + // Read the stream header once as a handshake. + var headerBuf [headerSize]byte + if _, err := io.ReadFull(conn, headerBuf[:]); err != nil { + return nil, err + } + + header := ToHeader(&headerBuf) + if !header.Validate() { + return nil, ErrInvalidHeader + } + + // Health check: close immediately if FlagCloseImmediately is set + if header.ShouldCloseImmediately() { + return nil, ErrCloseImmediately + } + + // get destination connection + host, port := header.GetHostPort() + return s.createDestConnection(host, port) +} + +func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) { + addr := net.JoinHostPort(host, port) + conn, err := net.DialTimeout("tcp", addr, dialTimeout) + if err != nil { + return nil, err + } + return conn, nil +} diff --git a/agent/pkg/agent/stream/tests/healthcheck_test.go b/agent/pkg/agent/stream/tests/healthcheck_test.go new file mode 100644 index 00000000..320e29a6 --- /dev/null +++ b/agent/pkg/agent/stream/tests/healthcheck_test.go @@ -0,0 +1,26 @@ +package stream_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/agent/pkg/agent/stream" +) + +func TestTCPHealthCheck(t *testing.T) { + certs := genTestCerts(t) + + srv := startTCPServer(t, certs) + + err := stream.TCPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert) + require.NoError(t, err, "health check") +} + +func TestUDPHealthCheck(t *testing.T) { + certs := genTestCerts(t) + + srv := startUDPServer(t, certs) + + err := stream.UDPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert) + require.NoError(t, err, "health check") +} diff --git a/agent/pkg/agent/stream/tests/mux_test.go b/agent/pkg/agent/stream/tests/mux_test.go new file mode 100644 index 00000000..8bb5b455 --- /dev/null +++ b/agent/pkg/agent/stream/tests/mux_test.go @@ -0,0 +1,94 @@ +package stream_test + +import ( + "bufio" + "context" + "crypto/tls" + "crypto/x509" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/agent/pkg/agent/common" + "github.com/yusing/godoxy/agent/pkg/agent/stream" +) + +func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) { + certs := genTestCerts(t) + + baseLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err, "listen tcp") + defer baseLn.Close() + baseAddr := baseLn.Addr().String() + + caCertPool := x509.NewCertPool() + caCertPool.AddCert(certs.CaCert) + + serverTLS := &tls.Config{ + Certificates: []tls.Certificate{*certs.SrvCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + NextProtos: []string{"http/1.1", stream.StreamALPN}, + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + streamSrv := stream.NewTCPServerHandler(ctx) + defer func() { _ = streamSrv.Close() }() + + tlsLn := tls.NewListener(baseLn, serverTLS) + defer func() { _ = tlsLn.Close() }() + + // HTTP server + httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + }), + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ + stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + streamSrv.ServeConn(conn) + }, + }, + } + go func() { _ = httpSrv.Serve(tlsLn) }() + defer func() { _ = httpSrv.Close() }() + + // Stream destination + dstAddr, closeDst := startTCPEcho(t) + defer closeDst() + + // HTTP client over the same port + clientTLS := &tls.Config{ + Certificates: []tls.Certificate{*certs.ClientCert}, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + NextProtos: []string{"http/1.1"}, + ServerName: common.CertsDNSName, + } + hc, err := tls.Dial("tcp", baseAddr, clientTLS) + require.NoError(t, err, "dial https") + defer hc.Close() + _ = hc.SetDeadline(time.Now().Add(2 * time.Second)) + _, err = hc.Write([]byte("GET / HTTP/1.1\r\nHost: godoxy-agent\r\n\r\n")) + require.NoError(t, err, "write http request") + r := bufio.NewReader(hc) + statusLine, err := r.ReadString('\n') + require.NoError(t, err, "read status line") + require.Contains(t, statusLine, "200", "expected 200") + + // Stream client over the same port + client := NewTCPClient(t, baseAddr, dstAddr, certs) + defer client.Close() + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + msg := []byte("ping over mux") + _, err = client.Write(msg) + require.NoError(t, err, "write stream payload") + buf := make([]byte, len(msg)) + _, err = io.ReadFull(client, buf) + require.NoError(t, err, "read stream payload") + require.Equal(t, msg, buf) +} diff --git a/agent/pkg/agent/stream/tests/server_flow_test.go b/agent/pkg/agent/stream/tests/server_flow_test.go new file mode 100644 index 00000000..f797a2b4 --- /dev/null +++ b/agent/pkg/agent/stream/tests/server_flow_test.go @@ -0,0 +1,201 @@ +package stream_test + +import ( + "crypto/tls" + "fmt" + "io" + "sync" + "testing" + "time" + + "github.com/pion/dtls/v3" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/agent/pkg/agent" + "github.com/yusing/godoxy/agent/pkg/agent/stream" +) + +func TestTCPServer_FullFlow(t *testing.T) { + certs := genTestCerts(t) + + dstAddr, closeDst := startTCPEcho(t) + defer closeDst() + + srv := startTCPServer(t, certs) + + client := NewTCPClient(t, srv.Addr.String(), dstAddr, certs) + defer client.Close() + + // Ensure ALPN is negotiated as expected (required for multiplexing). + withState, ok := client.(interface{ ConnectionState() tls.ConnectionState }) + require.True(t, ok, "tcp client should expose TLS connection state") + require.Equal(t, stream.StreamALPN, withState.ConnectionState().NegotiatedProtocol) + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + msg := []byte("ping over tcp") + _, err := client.Write(msg) + require.NoError(t, err, "write to client") + + buf := make([]byte, len(msg)) + _, err = io.ReadFull(client, buf) + require.NoError(t, err, "read from client") + require.Equal(t, string(msg), string(buf), "unexpected echo") +} + +func TestTCPServer_ConcurrentConnections(t *testing.T) { + certs := genTestCerts(t) + + dstAddr, closeDst := startTCPEcho(t) + defer closeDst() + + srv := startTCPServer(t, certs) + + const nClients = 25 + + errs := make(chan error, nClients) + var wg sync.WaitGroup + wg.Add(nClients) + + for i := range nClients { + go func() { + defer wg.Done() + + client := NewTCPClient(t, srv.Addr.String(), dstAddr, certs) + defer client.Close() + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + msg := fmt.Appendf(nil, "ping over tcp %d", i) + if _, err := client.Write(msg); err != nil { + errs <- fmt.Errorf("write to client: %w", err) + return + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(client, buf); err != nil { + errs <- fmt.Errorf("read from client: %w", err) + return + } + if string(msg) != string(buf) { + errs <- fmt.Errorf("unexpected echo: got=%q want=%q", string(buf), string(msg)) + return + } + }() + } + + wg.Wait() + close(errs) + for err := range errs { + require.NoError(t, err) + } +} + +func TestUDPServer_RejectInvalidClient(t *testing.T) { + certs := genTestCerts(t) + + // Generate a self-signed client cert that is NOT signed by the CA + _, _, invalidClientPEM, err := agent.NewAgent() + require.NoError(t, err, "generate invalid client certs") + invalidClientCert, err := invalidClientPEM.ToTLSCert() + require.NoError(t, err, "parse invalid client cert") + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + srv := startUDPServer(t, certs) + + + // Try to connect with a client cert from a different CA + _, err = stream.NewUDPClient(srv.Addr.String(), dstAddr, certs.CaCert, invalidClientCert) + require.Error(t, err, "expected error when connecting with client cert from different CA") + + var handshakeErr *dtls.HandshakeError + require.ErrorAs(t, err, &handshakeErr, "expected handshake error") +} + +func TestUDPServer_RejectClientWithoutCert(t *testing.T) { + certs := genTestCerts(t) + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + srv := startUDPServer(t, certs) + + time.Sleep(time.Second) + + // Try to connect without any client certificate + // Create a TLS cert without a private key to simulate no client cert + emptyCert := &tls.Certificate{} + _, err := stream.NewUDPClient(srv.Addr.String(), dstAddr, certs.CaCert, emptyCert) + require.Error(t, err, "expected error when connecting without client cert") + + require.ErrorContains(t, err, "no certificate provided", "expected no cert error") +} + +func TestUDPServer_FullFlow(t *testing.T) { + certs := genTestCerts(t) + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + srv := startUDPServer(t, certs) + + client := NewUDPClient(t, srv.Addr.String(), dstAddr, certs) + defer client.Close() + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + msg := []byte("ping over udp") + _, err := client.Write(msg) + require.NoError(t, err, "write to client") + + buf := make([]byte, 2048) + n, err := client.Read(buf) + require.NoError(t, err, "read from client") + require.Equal(t, string(msg), string(buf[:n]), "unexpected echo") +} + +func TestUDPServer_ConcurrentConnections(t *testing.T) { + certs := genTestCerts(t) + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + srv := startUDPServer(t, certs) + + const nClients = 25 + + errs := make(chan error, nClients) + var wg sync.WaitGroup + wg.Add(nClients) + + for i := range nClients { + go func() { + defer wg.Done() + + client := NewUDPClient(t, srv.Addr.String(), dstAddr, certs) + defer client.Close() + + _ = client.SetDeadline(time.Now().Add(5 * time.Second)) + msg := fmt.Appendf(nil, "ping over udp %d", i) + if _, err := client.Write(msg); err != nil { + errs <- fmt.Errorf("write to client: %w", err) + return + } + + buf := make([]byte, 2048) + n, err := client.Read(buf) + if err != nil { + errs <- fmt.Errorf("read from client: %w", err) + return + } + if string(msg) != string(buf[:n]) { + errs <- fmt.Errorf("unexpected echo: got=%q want=%q", string(buf[:n]), string(msg)) + return + } + }() + } + + wg.Wait() + close(errs) + for err := range errs { + require.NoError(t, err) + } +} diff --git a/agent/pkg/agent/stream/tests/testutils_test.go b/agent/pkg/agent/stream/tests/testutils_test.go new file mode 100644 index 00000000..c8000f87 --- /dev/null +++ b/agent/pkg/agent/stream/tests/testutils_test.go @@ -0,0 +1,177 @@ +package stream_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/pion/transport/v3/udp" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/agent/pkg/agent" + "github.com/yusing/godoxy/agent/pkg/agent/stream" +) + +// CertBundle holds all certificates needed for testing. +type CertBundle struct { + CaCert *x509.Certificate + SrvCert *tls.Certificate + ClientCert *tls.Certificate +} + +// genTestCerts generates certificates for testing and returns them as a CertBundle. +func genTestCerts(t *testing.T) CertBundle { + t.Helper() + + caPEM, srvPEM, clientPEM, err := agent.NewAgent() + require.NoError(t, err, "generate agent certs") + + caCert, err := caPEM.ToTLSCert() + require.NoError(t, err, "parse CA cert") + srvCert, err := srvPEM.ToTLSCert() + require.NoError(t, err, "parse server cert") + clientCert, err := clientPEM.ToTLSCert() + require.NoError(t, err, "parse client cert") + + return CertBundle{ + CaCert: caCert.Leaf, + SrvCert: srvCert, + ClientCert: clientCert, + } +} + +// startTCPEcho starts a TCP echo server and returns its address and close function. +func startTCPEcho(t *testing.T) (addr string, closeFn func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err, "listen tcp") + + done := make(chan struct{}) + go func() { + defer close(done) + for { + c, err := ln.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(c) + } + }() + + return ln.Addr().String(), func() { + _ = ln.Close() + <-done + } +} + +// startUDPEcho starts a UDP echo server and returns its address and close function. +func startUDPEcho(t *testing.T) (addr string, closeFn func()) { + t.Helper() + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err, "listen udp") + uc := pc.(*net.UDPConn) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 65535) + for { + n, raddr, err := uc.ReadFromUDP(buf) + if err != nil { + return + } + _, _ = uc.WriteToUDP(buf[:n], raddr) + } + }() + + return uc.LocalAddr().String(), func() { + _ = uc.Close() + <-done + } +} + +// TestServer wraps a server with its startup goroutine for cleanup. +type TestServer struct { + Server interface{ Close() error } + Addr net.Addr +} + +// startTCPServer starts a TCP server and returns a TestServer for cleanup. +func startTCPServer(t *testing.T, certs CertBundle) TestServer { + t.Helper() + + tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err, "listen tcp") + + ctx, cancel := context.WithCancel(t.Context()) + + srv := stream.NewTCPServer(ctx, tcpLn, certs.CaCert, certs.SrvCert) + + errCh := make(chan error, 1) + go func() { errCh <- srv.Start() }() + + t.Cleanup(func() { + cancel() + _ = srv.Close() + err := <-errCh + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) { + t.Logf("tcp server exit: %v", err) + } + }) + + return TestServer{ + Server: srv, + Addr: srv.Addr(), + } +} + +// startUDPServer starts a UDP server and returns a TestServer for cleanup. +func startUDPServer(t *testing.T, certs CertBundle) TestServer { + t.Helper() + + ctx, cancel := context.WithCancel(t.Context()) + + srv := stream.NewUDPServer(ctx, "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, certs.CaCert, certs.SrvCert) + + errCh := make(chan error, 1) + go func() { errCh <- srv.Start() }() + + time.Sleep(100 * time.Millisecond) + + t.Cleanup(func() { + cancel() + _ = srv.Close() + err := <-errCh + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) && !errors.Is(err, udp.ErrClosedListener) { + t.Logf("udp server exit: %v", err) + } + }) + + return TestServer{ + Server: srv, + Addr: srv.Addr(), + } +} + +// NewTCPClient creates a TCP client connected to the server with test certificates. +func NewTCPClient(t *testing.T, serverAddr, targetAddress string, certs CertBundle) net.Conn { + t.Helper() + client, err := stream.NewTCPClient(serverAddr, targetAddress, certs.CaCert, certs.ClientCert) + require.NoError(t, err, "create tcp client") + return client +} + +// NewUDPClient creates a UDP client connected to the server with test certificates. +func NewUDPClient(t *testing.T, serverAddr, targetAddress string, certs CertBundle) net.Conn { + t.Helper() + client, err := stream.NewUDPClient(serverAddr, targetAddress, certs.CaCert, certs.ClientCert) + require.NoError(t, err, "create udp client") + return client +} diff --git a/agent/pkg/agent/stream/udp_client.go b/agent/pkg/agent/stream/udp_client.go new file mode 100644 index 00000000..4d372be8 --- /dev/null +++ b/agent/pkg/agent/stream/udp_client.go @@ -0,0 +1,118 @@ +package stream + +import ( + "crypto/tls" + "crypto/x509" + "net" + "time" + + "github.com/pion/dtls/v3" + "github.com/yusing/godoxy/agent/pkg/agent/common" +) + +type UDPClient struct { + conn net.Conn +} + +// NewUDPClient creates a new UDP client for the agent. +// +// It will establish a DTLS connection and send a stream request header to the server. +// +// It returns an error if +// - the target address is invalid +// - the stream request header is invalid +// - the DTLS configuration is invalid +// - the DTLS connection fails +// - the stream request header is not sent +func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { + host, port, err := net.SplitHostPort(targetAddress) + if err != nil { + return nil, err + } + + header, err := NewStreamRequestHeader(host, port) + if err != nil { + return nil, err + } + + return newUDPClientWIthHeader(serverAddr, header, caCert, clientCert) +} + +func newUDPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { + // Setup DTLS configuration + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + dtlsConfig := &dtls.Config{ + Certificates: []tls.Certificate{*clientCert}, + RootCAs: caCertPool, + InsecureSkipVerify: false, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + ServerName: common.CertsDNSName, + CipherSuites: dTLSCipherSuites, + } + + raddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + return nil, err + } + + // Establish DTLS connection + conn, err := dtls.Dial("udp", raddr, dtlsConfig) + if err != nil { + return nil, err + } + // Send the stream header once as a handshake. + if _, err := conn.Write(header.Bytes()); err != nil { + _ = conn.Close() + return nil, err + } + + return &UDPClient{ + conn: conn, + }, nil +} + +func UDPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error { + header := NewStreamHealthCheckHeader() + + conn, err := newUDPClientWIthHeader(serverAddr, header, caCert, clientCert) + if err != nil { + return err + } + + conn.Close() + return nil +} + +func (c *UDPClient) Read(p []byte) (n int, err error) { + return c.conn.Read(p) +} + +func (c *UDPClient) Write(p []byte) (n int, err error) { + return c.conn.Write(p) +} + +func (c *UDPClient) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *UDPClient) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *UDPClient) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *UDPClient) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *UDPClient) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *UDPClient) Close() error { + return c.conn.Close() +} diff --git a/agent/pkg/agent/stream/udp_server.go b/agent/pkg/agent/stream/udp_server.go new file mode 100644 index 00000000..0ceb7124 --- /dev/null +++ b/agent/pkg/agent/stream/udp_server.go @@ -0,0 +1,205 @@ +package stream + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "time" + + "github.com/pion/dtls/v3" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +type UDPServer struct { + ctx context.Context + network string + laddr *net.UDPAddr + listener net.Listener + + dtlsConfig *dtls.Config +} + +func NewUDPServer(ctx context.Context, network string, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer { + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + dtlsConfig := &dtls.Config{ + Certificates: []tls.Certificate{*serverCert}, + ClientCAs: caCertPool, + ClientAuth: dtls.RequireAndVerifyClientCert, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + CipherSuites: dTLSCipherSuites, + } + + s := &UDPServer{ + ctx: ctx, + network: network, + laddr: laddr, + dtlsConfig: dtlsConfig, + } + return s +} + +func (s *UDPServer) Start() error { + listener, err := dtls.Listen(s.network, s.laddr, s.dtlsConfig) + if err != nil { + return err + } + s.listener = listener + + context.AfterFunc(s.ctx, func() { + _ = s.listener.Close() + }) + + for { + conn, err := s.listener.Accept() + if err != nil { + // Expected error when context cancelled + if errors.Is(err, net.ErrClosed) && s.ctx.Err() != nil { + return s.ctx.Err() + } + return err + } + go s.handleDTLSConnection(conn) + } +} + +func (s *UDPServer) Addr() net.Addr { + if s.listener != nil { + return s.listener.Addr() + } + return s.laddr +} + +func (s *UDPServer) Close() error { + if s.listener != nil { + return s.listener.Close() + } + return nil +} + +func (s *UDPServer) logger(clientConn net.Conn) *zerolog.Logger { + l := log.With().Str("protocol", "udp"). + Str("addr", s.Addr().String()). + Str("remote", clientConn.RemoteAddr().String()).Logger() + return &l +} + +func (s *UDPServer) loggerWithDst(clientConn net.Conn, dstConn *net.UDPConn) *zerolog.Logger { + l := log.With().Str("protocol", "udp"). + Str("addr", s.Addr().String()). + Str("remote", clientConn.RemoteAddr().String()). + Str("dst", dstConn.RemoteAddr().String()).Logger() + return &l +} + +func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) { + defer clientConn.Close() + + // Read the stream header once as a handshake. + var headerBuf [headerSize]byte + if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil { + s.logger(clientConn).Err(err).Msg("failed to read stream header") + return + } + header := ToHeader(&headerBuf) + if !header.Validate() { + s.logger(clientConn).Error().Bytes("header", headerBuf[:]).Msg("invalid stream header received") + return + } + + // Health check probe: close connection + if header.ShouldCloseImmediately() { + s.logger(clientConn).Info().Msg("Health check received") + return + } + + host, port := header.GetHostPort() + dstConn, err := s.createDestConnection(host, port) + if err != nil { + s.logger(clientConn).Err(err).Msg("failed to get or create destination connection") + return + } + defer dstConn.Close() + + go s.forwardFromDestination(dstConn, clientConn) + + buf := sizedPool.GetSized(65535) + defer sizedPool.Put(buf) + + for { + select { + case <-s.ctx.Done(): + return + default: + n, err := clientConn.Read(buf) + // Per net.Conn contract, Read may return (n > 0, err == io.EOF). + // Always forward any bytes we got before acting on the error. + if n > 0 { + if _, werr := dstConn.Write(buf[:n]); werr != nil { + s.logger(clientConn).Err(werr).Msgf("failed to write %d bytes to destination", n) + return + } + } + if err != nil { + // Expected shutdown paths. + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + return + } + s.logger(clientConn).Err(err).Msg("failed to read from client") + return + } + } + } +} + +func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) { + addr := net.JoinHostPort(host, port) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + dstConn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, err + } + + return dstConn, nil +} + +func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) { + buffer := sizedPool.GetSized(65535) + defer sizedPool.Put(buffer) + + for { + select { + case <-s.ctx.Done(): + return + default: + _ = dstConn.SetReadDeadline(time.Now().Add(readDeadline)) + n, err := dstConn.Read(buffer) + if err != nil { + // The destination socket can be closed when the client disconnects (e.g. during + // the stream support probe in AgentConfig.StartWithCerts). Treat that as a + // normal exit and avoid noisy logs. + if errors.Is(err, net.ErrClosed) { + return + } + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + s.loggerWithDst(clientConn, dstConn).Err(err).Msg("failed to read from destination") + return + } + if _, err := clientConn.Write(buffer[:n]); err != nil { + s.loggerWithDst(clientConn, dstConn).Err(err).Msgf("failed to write %d bytes to client", n) + return + } + } + } +} diff --git a/agent/pkg/agent/templates/agent.compose.yml b/agent/pkg/agent/templates/agent.compose.yml deleted file mode 100644 index bb10d1f7..00000000 --- a/agent/pkg/agent/templates/agent.compose.yml +++ /dev/null @@ -1,44 +0,0 @@ -services: - agent: - image: "{{.Image}}" - container_name: godoxy-agent - restart: always - network_mode: host # do not change this - environment: - AGENT_NAME: "{{.Name}}" - AGENT_PORT: "{{.Port}}" - AGENT_CA_CERT: "{{.CACert}}" - AGENT_SSL_CERT: "{{.SSLCert}}" - # use agent as a docker socket proxy: [host]:port - # set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375) - LISTEN_ADDR: - POST: false - ALLOW_RESTARTS: false - ALLOW_START: false - ALLOW_STOP: false - AUTH: false - BUILD: false - COMMIT: false - CONFIGS: false - CONTAINERS: false - DISTRIBUTION: false - EVENTS: true - EXEC: false - GRPC: false - IMAGES: false - INFO: false - NETWORKS: false - NODES: false - PING: true - PLUGINS: false - SECRETS: false - SERVICES: false - SESSION: false - SWARM: false - SYSTEM: false - TASKS: false - VERSION: true - VOLUMES: false - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - ./data:/app/data diff --git a/agent/pkg/agent/templates/agent.compose.yml.tmpl b/agent/pkg/agent/templates/agent.compose.yml.tmpl index cc0864c0..c0d9fc4c 100644 --- a/agent/pkg/agent/templates/agent.compose.yml.tmpl +++ b/agent/pkg/agent/templates/agent.compose.yml.tmpl @@ -5,7 +5,8 @@ services: restart: always {{ if eq .ContainerRuntime "podman" -}} ports: - - "{{.Port}}:{{.Port}}" + - "{{.Port}}:{{.Port}}/tcp" + - "{{.Port}}:{{.Port}}/udp" {{ else -}} network_mode: host # do not change this {{ end -}} diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go index 31401cf7..9ddf3225 100644 --- a/agent/pkg/handler/handler.go +++ b/agent/pkg/handler/handler.go @@ -1,9 +1,9 @@ package handler import ( - "fmt" "net/http" + "github.com/bytedance/sonic" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/yusing/godoxy/agent/pkg/agent" @@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler { } mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP) - mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, version.Get()) - }) - mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, env.AgentName) - }) - mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, env.Runtime) + mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) { + agentInfo := agent.AgentInfo{ + Version: version.Get(), + Name: env.AgentName, + Runtime: env.Runtime, + } + w.Header().Set("Content-Type", "application/json") + sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo) }) mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth) mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP) diff --git a/go.mod b/go.mod index ec2fc7ad..2aa57a13 100644 --- a/go.mod +++ b/go.mod @@ -172,6 +172,9 @@ require ( github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect + github.com/pion/dtls/v3 v3.0.10 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/pquerna/otp v1.5.0 // indirect github.com/stretchr/objx v0.5.3 // indirect diff --git a/go.sum b/go.sum index 5b290d61..349745b4 100644 --- a/go.sum +++ b/go.sum @@ -247,6 +247,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pion/dtls/v3 v3.0.10 h1:k9ekkq1kaZoxnNEbyLKI8DI37j/Nbk1HWmMuywpQJgg= +github.com/pion/dtls/v3 v3.0.10/go.mod h1:YEmmBYIoBsY3jmG56dsziTv/Lca9y4Om83370CXfqJ8= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= diff --git a/internal/api/v1/docs/swagger.json b/internal/api/v1/docs/swagger.json index 03ea87f4..d37822a4 100644 --- a/internal/api/v1/docs/swagger.json +++ b/internal/api/v1/docs/swagger.json @@ -2356,6 +2356,16 @@ "x-nullable": false, "x-omitempty": false }, + "supports_tcp_stream": { + "type": "boolean", + "x-nullable": false, + "x-omitempty": false + }, + "supports_udp_stream": { + "type": "boolean", + "x-nullable": false, + "x-omitempty": false + }, "version": { "type": "string", "x-nullable": false, @@ -2439,7 +2449,7 @@ "type": "object", "properties": { "agent": { - "$ref": "#/definitions/Agent", + "$ref": "#/definitions/agentpool.Agent", "x-nullable": false, "x-omitempty": false }, @@ -4909,6 +4919,43 @@ "x-nullable": false, "x-omitempty": false }, + "agentpool.Agent": { + "type": "object", + "properties": { + "addr": { + "type": "string", + "x-nullable": false, + "x-omitempty": false + }, + "name": { + "type": "string", + "x-nullable": false, + "x-omitempty": false + }, + "runtime": { + "$ref": "#/definitions/agent.ContainerRuntime", + "x-nullable": false, + "x-omitempty": false + }, + "supports_tcp_stream": { + "type": "boolean", + "x-nullable": false, + "x-omitempty": false + }, + "supports_udp_stream": { + "type": "boolean", + "x-nullable": false, + "x-omitempty": false + }, + "version": { + "type": "string", + "x-nullable": false, + "x-omitempty": false + } + }, + "x-nullable": false, + "x-omitempty": false + }, "auth.UserPassAuthCallbackRequest": { "type": "object", "properties": { diff --git a/internal/api/v1/docs/swagger.yaml b/internal/api/v1/docs/swagger.yaml index 38cb6d7c..a7f2c0ae 100644 --- a/internal/api/v1/docs/swagger.yaml +++ b/internal/api/v1/docs/swagger.yaml @@ -8,6 +8,10 @@ definitions: type: string runtime: $ref: '#/definitions/agent.ContainerRuntime' + supports_tcp_stream: + type: boolean + supports_udp_stream: + type: boolean version: type: string type: object @@ -48,7 +52,7 @@ definitions: Container: properties: agent: - $ref: '#/definitions/Agent' + $ref: '#/definitions/agentpool.Agent' aliases: items: type: string @@ -1236,6 +1240,21 @@ definitions: x-enum-varnames: - ContainerRuntimeDocker - ContainerRuntimePodman + agentpool.Agent: + properties: + addr: + type: string + name: + type: string + runtime: + $ref: '#/definitions/agent.ContainerRuntime' + supports_tcp_stream: + type: boolean + supports_udp_stream: + type: boolean + version: + type: string + type: object auth.UserPassAuthCallbackRequest: properties: password: diff --git a/internal/route/stream.go b/internal/route/stream.go index 0b6184b7..a9cd14ca 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -116,9 +116,9 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) { switch rScheme { case "tcp": - return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host) + return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host, r.GetAgent()) case "udp": - return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host) + return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host, r.GetAgent()) } return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme) } diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go index 620f4d4f..83f85262 100644 --- a/internal/route/stream/tcp_tcp.go +++ b/internal/route/stream/tcp_tcp.go @@ -7,6 +7,7 @@ import ( "github.com/pires/go-proxyproto" "github.com/rs/zerolog" "github.com/yusing/godoxy/internal/acl" + "github.com/yusing/godoxy/internal/agentpool" "github.com/yusing/godoxy/internal/entrypoint" nettypes "github.com/yusing/godoxy/internal/net/types" ioutils "github.com/yusing/goutils/io" @@ -21,6 +22,7 @@ type TCPTCPStream struct { laddr *net.TCPAddr dst *net.TCPAddr + agent *agentpool.Agent preDial nettypes.HookFunc onRead nettypes.HookFunc @@ -28,7 +30,7 @@ type TCPTCPStream struct { closed atomic.Bool } -func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) { +func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent) (nettypes.Stream, error) { dst, err := net.ResolveTCPAddr(dstNetwork, dstAddr) if err != nil { return nil, err @@ -37,7 +39,7 @@ func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes. if err != nil { return nil, err } - return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst}, nil + return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil } func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { @@ -130,7 +132,15 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { return } - dstConn, err := net.DialTCP(s.dstNetwork, nil, s.dst) + var ( + dstConn net.Conn + err error + ) + if s.agent != nil { + dstConn, err = s.agent.NewTCPClient(s.dst.String()) + } else { + dstConn, err = net.DialTCP(s.dstNetwork, nil, s.dst) + } if err != nil { if !s.closed.Load() { logErr(s, err, "failed to dial destination") @@ -144,7 +154,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { } src := conn - dst := net.Conn(dstConn) + dst := dstConn if s.onRead != nil { src = &wrapperConn{ Conn: conn, diff --git a/internal/route/stream/udp_udp.go b/internal/route/stream/udp_udp.go index 7b8615ec..39f96360 100644 --- a/internal/route/stream/udp_udp.go +++ b/internal/route/stream/udp_udp.go @@ -11,6 +11,7 @@ import ( "github.com/rs/zerolog" "github.com/yusing/godoxy/internal/acl" + "github.com/yusing/godoxy/internal/agentpool" nettypes "github.com/yusing/godoxy/internal/net/types" "github.com/yusing/goutils/synk" "go.uber.org/atomic" @@ -24,6 +25,7 @@ type UDPUDPStream struct { laddr *net.UDPAddr dst *net.UDPAddr + agent *agentpool.Agent preDial nettypes.HookFunc onRead nettypes.HookFunc @@ -53,7 +55,7 @@ const ( var bufPool = synk.GetSizedBytesPool() -func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) { +func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent) (nettypes.Stream, error) { dst, err := net.ResolveUDPAddr(dstNetwork, dstAddr) if err != nil { return nil, err @@ -67,6 +69,7 @@ func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes. dstNetwork: dstNetwork, laddr: laddr, dst: dst, + agent: agent, conns: make(map[string]*udpUDPConn), }, nil } @@ -195,7 +198,11 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd dstConn net.Conn err error ) - dstConn, err = net.DialUDP(s.dstNetwork, nil, s.dst) + if s.agent != nil { + dstConn, err = s.agent.NewUDPClient(s.dst.String()) + } else { + dstConn, err = net.DialUDP(s.dst.Network(), nil, s.dst) + } if err != nil { logErr(s, err, "failed to dial dst") return nil, false