diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 307422ed..c8502ec5 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -1,15 +1,18 @@ package main import ( + "net" "os" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/yusing/godoxy/agent/pkg/agent" + "github.com/yusing/godoxy/agent/pkg/agent/stream" "github.com/yusing/godoxy/agent/pkg/env" "github.com/yusing/godoxy/agent/pkg/server" "github.com/yusing/godoxy/internal/metrics/systeminfo" socketproxy "github.com/yusing/godoxy/socketproxy/pkg" + gperr "github.com/yusing/goutils/errs" httpServer "github.com/yusing/goutils/server" strutils "github.com/yusing/goutils/strings" "github.com/yusing/goutils/task" @@ -63,6 +66,16 @@ Tips: server.StartAgentServer(t, opts) + tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentStreamPort}) + if err != nil { + gperr.LogFatal("failed to listen on port", err) + } + tcpServer := stream.NewTCPServer(t.Context(), tcpListener, caCert.Leaf, srvCert) + go tcpServer.Start() + + udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentStreamPort}, caCert.Leaf, srvCert) + go udpServer.Start() + if socketproxy.ListenAddr != "" { runtime := strutils.Title(string(env.Runtime)) diff --git a/agent/go.mod b/agent/go.mod index 613b0f2e..8b67faf1 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -18,6 +18,7 @@ require ( github.com/bytedance/sonic v1.14.2 github.com/gin-gonic/gin v1.11.0 github.com/gorilla/websocket v1.5.3 + github.com/pion/dtls/v3 v3.0.9 github.com/puzpuzpuz/xsync/v4 v4.2.0 github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.11.1 @@ -87,6 +88,8 @@ require ( github.com/opencontainers/image-spec v1.1.1 // indirect github.com/oschwald/maxminddb-golang v1.13.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/transport/v3 v3.1.1 // indirect github.com/pires/go-proxyproto v0.8.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect diff --git a/agent/go.sum b/agent/go.sum index 573a3a70..aed1c0d5 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -217,6 +217,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM= +github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= diff --git a/agent/pkg/agent/bare_metal.go b/agent/pkg/agent/bare_metal.go index 8176ebc8..f510cdba 100644 --- a/agent/pkg/agent/bare_metal.go +++ b/agent/pkg/agent/bare_metal.go @@ -8,6 +8,7 @@ import ( var ( installScript = `AGENT_NAME="{{.Name}}" \ AGENT_PORT="{{.Port}}" \ + AGENT_STREAM_PORT="{{.StreamPort}}" \ AGENT_CA_CERT="{{.CACert}}" \ AGENT_SSL_CERT="{{.SSLCert}}" \ {{ if eq .ContainerRuntime "nerdctl" -}} diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index 95b4c28f..8a1ea324 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -4,38 +4,61 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "net" "net/http" "net/url" "os" + "strconv" "strings" "time" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" + agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream" "github.com/yusing/godoxy/agent/pkg/certs" + gperr "github.com/yusing/goutils/errs" "github.com/yusing/goutils/version" ) type AgentConfig struct { - Addr string `json:"addr"` - Name string `json:"name"` - Version version.Version `json:"version" swaggertype:"string"` - Runtime ContainerRuntime `json:"runtime"` + AgentInfo + + Addr string `json:"addr"` httpClient *http.Client fasthttpClientHealthCheck *fasthttp.Client tlsConfig tls.Config - l zerolog.Logger + + // for stream + caCert *x509.Certificate + clientCert *tls.Certificate + isTCPStreamSupported bool + isUDPStreamSupported bool + streamServerAddr string + + l zerolog.Logger } // @name Agent +type AgentInfo struct { + Version version.Version `json:"version" swaggertype:"string"` + Name string `json:"name"` + Runtime ContainerRuntime `json:"runtime"` + StreamPort int `json:"stream_port"` +} + +// Deprecated. Replaced by EndpointInfo const ( - EndpointVersion = "/version" - EndpointName = "/name" - EndpointRuntime = "/runtime" + EndpointVersion = "/version" + EndpointName = "/name" + EndpointRuntime = "/runtime" +) + +const ( + EndpointInfo = "/info" EndpointProxyHTTP = "/proxy/http" EndpointHealth = "/health" EndpointLogs = "/logs" @@ -90,6 +113,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) if err != nil { return err } + cfg.clientCert = &clientCert // create tls config caCertPool := x509.NewCertPool() @@ -97,6 +121,14 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) if !ok { return errors.New("invalid ca certificate") } + // Keep the CA leaf for stream client dialing. + if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" { + return errors.New("invalid ca certificate") + } else if cert, err := x509.ParseCertificate(block.Bytes); err != nil { + return err + } else { + cfg.caCert = cert + } cfg.tlsConfig = tls.Config{ Certificates: []tls.Certificate{clientCert}, @@ -113,48 +145,97 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - // get agent name - name, _, err := cfg.fetchString(ctx, EndpointName) + status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo) if err != nil { return err } - cfg.Name = name + var streamUnsupportedErrs gperr.Builder + + if status == http.StatusOK { + if cfg.StreamPort <= 0 { + return fmt.Errorf("invalid agent stream port: %d", cfg.StreamPort) + } + host, _, err := net.SplitHostPort(cfg.Addr) + if err != nil { + return err + } + cfg.streamServerAddr = net.JoinHostPort(host, strconv.Itoa(cfg.StreamPort)) + + // test stream server connection + const fakeAddress = "localhost:8080" // it won't be used, just for testing + // test TCP stream support + conn, err := agentstream.NewTCPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert) + if err != nil { + streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err) + } else { + conn.Close() + cfg.isTCPStreamSupported = true + } + + // test UDP stream support + conn, err = agentstream.NewUDPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert) + if err != nil { + streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err) + } else { + conn.Close() + cfg.isUDPStreamSupported = true + } + } else { + // old agent does not support EndpointInfo + // fallback with old logic + cfg.isTCPStreamSupported = false + cfg.isUDPStreamSupported = false + streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling") + + // get agent name + name, _, err := cfg.fetchString(ctx, EndpointName) + if err != nil { + return err + } + + cfg.Name = name + + // check agent version + agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion) + if err != nil { + return err + } + + cfg.Version = version.Parse(agentVersion) + + // check agent runtime + runtime, status, err := cfg.fetchString(ctx, EndpointRuntime) + if err != nil { + return err + } + + switch status { + case http.StatusOK: + switch runtime { + case "docker": + cfg.Runtime = ContainerRuntimeDocker + // case "nerdctl": + // cfg.Runtime = ContainerRuntimeNerdctl + case "podman": + cfg.Runtime = ContainerRuntimePodman + default: + return fmt.Errorf("invalid agent runtime: %s", runtime) + } + case http.StatusNotFound: + // backward compatibility, old agent does not have runtime endpoint + cfg.Runtime = ContainerRuntimeDocker + default: + return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime) + } + } cfg.l = log.With().Str("agent", cfg.Name).Logger() - // check agent version - agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion) - if err != nil { - return err + if err := streamUnsupportedErrs.Error(); err != nil { + gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l) } - // check agent runtime - runtime, status, err := cfg.fetchString(ctx, EndpointRuntime) - if err != nil { - return err - } - switch status { - case http.StatusOK: - switch runtime { - case "docker": - cfg.Runtime = ContainerRuntimeDocker - // case "nerdctl": - // cfg.Runtime = ContainerRuntimeNerdctl - case "podman": - cfg.Runtime = ContainerRuntimePodman - default: - return fmt.Errorf("invalid agent runtime: %s", runtime) - } - case http.StatusNotFound: - // backward compatibility, old agent does not have runtime endpoint - cfg.Runtime = ContainerRuntimeDocker - default: - return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime) - } - - cfg.Version = version.Parse(agentVersion) - if serverVersion.IsNewerThanMajor(cfg.Version) { log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version) } @@ -163,6 +244,53 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte) return nil } +func (cfg *AgentConfig) getStreamServerAddr() (string, error) { + if cfg.streamServerAddr == "" { + return "", errors.New("agent stream server address is not initialized") + } + return cfg.streamServerAddr, nil +} + +// NewTCPClient creates a new TCP client for the agent. +// +// It returns an error if +// - the agent is not initialized +// - the agent does not support TCP stream tunneling +// - the agent stream server address is not initialized +func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) { + if cfg.caCert == nil || cfg.clientCert == nil { + return nil, errors.New("agent is not initialized") + } + if !cfg.isTCPStreamSupported { + return nil, errors.New("agent does not support TCP stream tunneling") + } + serverAddr, err := cfg.getStreamServerAddr() + if err != nil { + return nil, err + } + return agentstream.NewTCPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert) +} + +// NewUDPClient creates a new UDP client for the agent. +// +// It returns an error if +// - the agent is not initialized +// - the agent does not support UDP stream tunneling +// - the agent stream server address is not initialized +func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) { + if cfg.caCert == nil || cfg.clientCert == nil { + return nil, errors.New("agent is not initialized") + } + if !cfg.isUDPStreamSupported { + return nil, errors.New("agent does not support UDP stream tunneling") + } + serverAddr, err := cfg.getStreamServerAddr() + if err != nil { + return nil, err + } + return agentstream.NewUDPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert) +} + func (cfg *AgentConfig) Start(ctx context.Context) error { filepath, ok := certs.AgentCertsFilepath(cfg.Addr) if !ok { diff --git a/agent/pkg/agent/env.go b/agent/pkg/agent/env.go index f7683d02..a3727e39 100644 --- a/agent/pkg/agent/env.go +++ b/agent/pkg/agent/env.go @@ -5,6 +5,7 @@ type ( AgentEnvConfig struct { Name string Port int + StreamPort int CACert string SSLCert string ContainerRuntime ContainerRuntime diff --git a/agent/pkg/agent/http_requests.go b/agent/pkg/agent/http_requests.go index aeae1221..4414cf53 100644 --- a/agent/pkg/agent/http_requests.go +++ b/agent/pkg/agent/http_requests.go @@ -87,6 +87,34 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin return ret, resp.StatusCode, nil } +// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct +// +// It will return the status code of the response, and error if any. +// If the status code is not http.StatusOK, out will be unchanged but error will still be nil. +func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) { + resp, err := cfg.Do(ctx, "GET", endpoint, nil) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + data, release, err := httputils.ReadAllBody(resp) + if err != nil { + return 0, err + } + + defer release(data) + if resp.StatusCode != http.StatusOK { + return resp.StatusCode, nil + } + + err = sonic.Unmarshal(data, out) + if err != nil { + return 0, err + } + return resp.StatusCode, nil +} + func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) { transport := cfg.Transport() dialer := websocket.Dialer{ diff --git a/agent/pkg/agent/stream/PROTOCOL.md b/agent/pkg/agent/stream/PROTOCOL.md new file mode 100644 index 00000000..d78b38a8 --- /dev/null +++ b/agent/pkg/agent/stream/PROTOCOL.md @@ -0,0 +1,40 @@ +# Stream proxy protocol + +This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination. + +## Header + +The on-wire header is a fixed-size binary blob: + +- `Version` (8 bytes) +- `Host` (255 bytes, NUL padded) +- `Port` (5 bytes, NUL padded) +- `Checksum` (4 bytes, big-endian CRC32) + +Total: `headerSize = 8 + 255 + 5 + 4 = 272` bytes. + +Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`. + +See [`StreamRequestHeader`](payload.go:26). + +## TCP behavior + +1. Client establishes a TLS connection to the stream server. +2. Client sends exactly one header as a handshake. +3. After the handshake, both sides proxy raw TCP bytes between client and destination. + +Server reads the header using `io.ReadFull` to avoid dropping bytes. + +See [`NewTCPClient()`](tcp_client.go:15) and [`(*TCPServer).redirect()`](tcp_server.go:77). + +## UDP-over-DTLS behavior + +1. Client establishes a DTLS connection to the stream server. +2. Client sends exactly one header as a handshake. +3. After the handshake, both sides proxy raw UDP datagrams: + - client → destination: DTLS payload is written to destination `UDPConn` + - destination → client: destination payload is written back to the DTLS connection + +Responses do **not** include a header. + +See [`NewUDPClient()`](udp_client.go:17) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:67). diff --git a/agent/pkg/agent/stream/common.go b/agent/pkg/agent/stream/common.go new file mode 100644 index 00000000..18124a44 --- /dev/null +++ b/agent/pkg/agent/stream/common.go @@ -0,0 +1,57 @@ +package stream + +import ( + "net" + "time" + + "github.com/puzpuzpuz/xsync/v4" + "github.com/yusing/goutils/synk" +) + +const ( + dialTimeout = 10 * time.Second + readDeadline = 10 * time.Second +) + +var sizedPool = synk.GetSizedBytesPool() + +type CreateConnFunc[Conn net.Conn] func(host, port string) (Conn, error) +type ConnectionManager[Conn net.Conn] struct { + m *xsync.Map[string, Conn] + createConnection CreateConnFunc[Conn] +} + +func NewConnectionManager[Conn net.Conn](createConnection CreateConnFunc[Conn]) *ConnectionManager[Conn] { + return &ConnectionManager[Conn]{ + m: xsync.NewMap[string, Conn](), + createConnection: createConnection, + } +} + +func (c *ConnectionManager[Conn]) GetOrCreateDestConnection(clientConn net.Conn, host, port string) (ret Conn, connErr error) { + clientKey := clientConn.RemoteAddr().String() + ret, _ = c.m.LoadOrCompute(clientKey, func() (conn Conn, cancel bool) { + conn, connErr = c.createConnection(host, port) + if connErr != nil { + cancel = true + } + return + }) + + return +} + +func (c *ConnectionManager[Conn]) DeleteDestConnection(clientConn net.Conn) { + clientKey := clientConn.RemoteAddr().String() + conn, loaded := c.m.LoadAndDelete(clientKey) + if loaded { + conn.Close() + } +} + +func (c *ConnectionManager[Conn]) CloseAllConnections() { + for _, conn := range c.m.Range { + conn.Close() + } + c.m.Clear() +} diff --git a/agent/pkg/agent/stream/payload.go b/agent/pkg/agent/stream/payload.go new file mode 100644 index 00000000..ebc6c53a --- /dev/null +++ b/agent/pkg/agent/stream/payload.go @@ -0,0 +1,109 @@ +package stream + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "io" + "unsafe" +) + +const ( + versionSize = 8 + hostSize = 255 + portSize = 5 + checksumSize = 4 // crc32 checksum + + headerSize = versionSize + hostSize + portSize + checksumSize +) + +var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0} + +var ErrInvalidHeader = errors.New("invalid header") + +type StreamRequestHeader struct { + Version [versionSize]byte + Host [hostSize]byte + Port [portSize]byte + Checksum [checksumSize]byte +} + +type StreamRequestPayload struct { + StreamRequestHeader + Data []byte +} + +func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) { + if len(host) > hostSize { + return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host)) + } + if len(port) > portSize { + return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port)) + } + header := &StreamRequestHeader{} + copy(header.Version[:], version[:]) + copy(header.Host[:], host) + copy(header.Port[:], port) + header.updateChecksum() + return header, nil +} + +func ToHeader(buf [headerSize]byte) *StreamRequestHeader { + return (*StreamRequestHeader)(unsafe.Pointer(&buf[0])) +} + +// WriteTo implements the io.WriterTo interface. +func (p *StreamRequestPayload) WriteTo(w io.Writer) (n int64, err error) { + n1, err := w.Write(p.StreamRequestHeader.Bytes()) + if err != nil { + return + } + if len(p.Data) == 0 { + return int64(n1), nil + } + + n2, err := w.Write(p.Data) + if err != nil { + return + } + return int64(n1) + int64(n2), nil +} + +func (h *StreamRequestHeader) GetHostPort() (string, string) { + hostEnd := bytes.IndexByte(h.Host[:], 0) + portEnd := bytes.IndexByte(h.Port[:], 0) + if hostEnd == -1 { + hostEnd = hostSize + } + if portEnd == -1 { + portEnd = portSize + } + return string(h.Host[:hostEnd]), string(h.Port[:portEnd]) +} + +func (h *StreamRequestHeader) Validate() bool { + if h.Version != version { + return false + } + return h.validateChecksum() +} + +func (h *StreamRequestHeader) updateChecksum() { + checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum()) + binary.BigEndian.PutUint32(h.Checksum[:], checksum) +} + +func (h *StreamRequestHeader) validateChecksum() bool { + checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum()) + return checksum == binary.BigEndian.Uint32(h.Checksum[:]) +} + +func (h *StreamRequestHeader) BytesWithoutChecksum() []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize-checksumSize) +} + +func (h *StreamRequestHeader) Bytes() []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize) +} diff --git a/agent/pkg/agent/stream/payload_test.go b/agent/pkg/agent/stream/payload_test.go new file mode 100644 index 00000000..09424bcf --- /dev/null +++ b/agent/pkg/agent/stream/payload_test.go @@ -0,0 +1,53 @@ +package stream + +import ( + "bytes" + "testing" +) + +func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) { + h, err := NewStreamRequestHeader("example.com", "443") + if err != nil { + t.Fatalf("NewStreamRequestHeader: %v", err) + } + if !h.Validate() { + t.Fatalf("expected header to validate") + } + + var buf [headerSize]byte + copy(buf[:], h.Bytes()) + h2 := ToHeader(buf) + if !h2.Validate() { + t.Fatalf("expected round-tripped header to validate") + } + host, port := h2.GetHostPort() + if host != "example.com" || port != "443" { + t.Fatalf("unexpected host/port: %q:%q", host, port) + } +} + +func TestStreamRequestPayload_WriteTo_WritesFullHeader(t *testing.T) { + h, err := NewStreamRequestHeader("127.0.0.1", "53") + if err != nil { + t.Fatalf("NewStreamRequestHeader: %v", err) + } + + p := &StreamRequestPayload{StreamRequestHeader: *h, Data: []byte("hello")} + + var out bytes.Buffer + n, err := p.WriteTo(&out) + if err != nil { + t.Fatalf("WriteTo: %v", err) + } + if int(n) != headerSize+len(p.Data) { + t.Fatalf("unexpected bytes written: got %d want %d", n, headerSize+len(p.Data)) + } + + written := out.Bytes() + if len(written) != headerSize+len(p.Data) { + t.Fatalf("unexpected output size: got %d", len(written)) + } + if !bytes.Equal(written[:headerSize], h.Bytes()) { + t.Fatalf("expected full header (including checksum) to be written") + } +} diff --git a/agent/pkg/agent/stream/server_flow_test.go b/agent/pkg/agent/stream/server_flow_test.go new file mode 100644 index 00000000..59a1636f --- /dev/null +++ b/agent/pkg/agent/stream/server_flow_test.go @@ -0,0 +1,235 @@ +package stream + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "errors" + "io" + "math/big" + "net" + "testing" + "time" +) + +func newSerial(t *testing.T) *big.Int { + t.Helper() + sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + t.Fatalf("rand serial: %v", err) + } + return sn +} + +func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: newSerial(t), + Subject: pkix.Name{CommonName: "stream-test-ca"}, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("CreateCertificate(CA): %v", err) + } + cert, err := x509.ParseCertificate(der) + if err != nil { + t.Fatalf("ParseCertificate(CA): %v", err) + } + return cert, key +} + +func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: newSerial(t), + Subject: pkix.Name{CommonName: cn}, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{eku}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey) + if err != nil { + t.Fatalf("CreateCertificate(%s): %v", cn, err) + } + leaf, err := x509.ParseCertificate(der) + if err != nil { + t.Fatalf("ParseCertificate(%s): %v", cn, err) + } + return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf} +} + +func startTCPEcho(t *testing.T) (addr string, closeFn func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen tcp: %v", err) + } + + done := make(chan struct{}) + go func() { + defer close(done) + for { + c, err := ln.Accept() + if err != nil { + return + } + go func(conn net.Conn) { + defer conn.Close() + _, _ = io.Copy(conn, conn) + }(c) + } + }() + + return ln.Addr().String(), func() { + _ = ln.Close() + <-done + } +} + +func startUDPEcho(t *testing.T) (addr string, closeFn func()) { + t.Helper() + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen udp: %v", err) + } + uc := pc.(*net.UDPConn) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 65535) + for { + n, raddr, err := uc.ReadFromUDP(buf) + if err != nil { + return + } + _, _ = uc.WriteToUDP(buf[:n], raddr) + } + }() + + return uc.LocalAddr().String(), func() { + _ = uc.Close() + <-done + } +} + +func TestTCPServer_FullFlow(t *testing.T) { + ca, caKey := genCA(t) + serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth) + clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth) + + dstAddr, closeDst := startTCPEcho(t) + defer closeDst() + + tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + t.Fatalf("ListenTCP: %v", err) + } + defer tcpLn.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv := NewTCPServer(ctx, tcpLn, ca, serverCert) + errCh := make(chan error, 1) + go func() { errCh <- srv.Start() }() + defer func() { + cancel() + _ = srv.Close() + _ = <-errCh + }() + + client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert) + if err != nil { + t.Fatalf("NewTCPClient: %v", err) + } + defer client.Close() + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + msg := []byte("ping over tcp") + if _, err := client.Write(msg); err != nil { + t.Fatalf("client.Write: %v", err) + } + + buf := make([]byte, len(msg)) + if _, err := io.ReadFull(client, buf); err != nil { + t.Fatalf("client.ReadFull: %v", err) + } + if string(buf) != string(msg) { + t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg)) + } +} + +func TestUDPServer_FullFlow(t *testing.T) { + ca, caKey := genCA(t) + serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth) + clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth) + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert) + errCh := make(chan error, 1) + go func() { errCh <- srv.Start() }() + defer func() { + cancel() + _ = srv.Close() + err := <-errCh + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) { + t.Logf("udp server exit: %v", err) + } + }() + + deadline := time.Now().Add(2 * time.Second) + for srv.listener == nil { + if time.Now().After(deadline) { + t.Fatalf("udp server listener did not start") + } + time.Sleep(10 * time.Millisecond) + } + + client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert) + if err != nil { + t.Fatalf("NewUDPClient: %v", err) + } + defer client.Close() + + _ = client.SetDeadline(time.Now().Add(2 * time.Second)) + msg := []byte("ping over udp") + if _, err := client.Write(msg); err != nil { + t.Fatalf("client.Write: %v", err) + } + + buf := make([]byte, 2048) + n, err := client.Read(buf) + if err != nil { + t.Fatalf("client.Read: %v", err) + } + if string(buf[:n]) != string(msg) { + t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg)) + } +} diff --git a/agent/pkg/agent/stream/tcp_client.go b/agent/pkg/agent/stream/tcp_client.go new file mode 100644 index 00000000..19419139 --- /dev/null +++ b/agent/pkg/agent/stream/tcp_client.go @@ -0,0 +1,91 @@ +package stream + +import ( + "crypto/tls" + "crypto/x509" + "net" + "time" +) + +type TCPClient struct { + conn net.Conn +} + +// NewTCPClient creates a new TCP client for the agent. +// +// It will establish a TLS connection and send a stream request header to the server. +// +// It returns an error if +// - the target address is invalid +// - the stream request header is invalid +// - the TLS configuration is invalid +// - the TLS connection fails +// - the stream request header is not sent +func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { + host, port, err := net.SplitHostPort(targetAddress) + if err != nil { + return nil, err + } + + header, err := NewStreamRequestHeader(host, port) + if err != nil { + return nil, err + } + + // Setup TLS configuration + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*clientCert}, + RootCAs: caCertPool, + MinVersion: tls.VersionTLS12, + } + + // Establish TLS connection + conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig) + if err != nil { + return nil, err + } + // Send the stream header once as a handshake. + if _, err := conn.Write(header.Bytes()); err != nil { + _ = conn.Close() + return nil, err + } + + return &TCPClient{ + conn: conn, + }, nil +} + +func (c *TCPClient) Read(p []byte) (n int, err error) { + return c.conn.Read(p) +} + +func (c *TCPClient) Write(p []byte) (n int, err error) { + return c.conn.Write(p) +} + +func (c *TCPClient) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *TCPClient) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *TCPClient) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *TCPClient) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *TCPClient) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *TCPClient) Close() error { + return c.conn.Close() +} diff --git a/agent/pkg/agent/stream/tcp_server.go b/agent/pkg/agent/stream/tcp_server.go new file mode 100644 index 00000000..f46c7291 --- /dev/null +++ b/agent/pkg/agent/stream/tcp_server.go @@ -0,0 +1,99 @@ +package stream + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io" + "net" + + ioutils "github.com/yusing/goutils/io" +) + +type TCPServer struct { + ctx context.Context + listener net.Listener + connMgr *ConnectionManager[net.Conn] +} + +func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer { + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*serverCert}, + ClientCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + MinVersion: tls.VersionTLS12, + } + + tcpListener := tls.NewListener(listener, tlsConfig) + s := &TCPServer{ + ctx: ctx, + listener: tcpListener, + } + s.connMgr = NewConnectionManager(s.createDestConnection) + return s +} + +func (s *TCPServer) Start() error { + for { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + conn, err := s.listener.Accept() + if err != nil { + return err + } + go s.handle(conn) + } + } +} + +func (s *TCPServer) Addr() net.Addr { + return s.listener.Addr() +} + +func (s *TCPServer) Close() error { + s.connMgr.CloseAllConnections() + return s.listener.Close() +} + +func (s *TCPServer) handle(conn net.Conn) { + defer conn.Close() + dst, err := s.redirect(conn) + if err != nil { + // TODO: log error + return + } + defer s.connMgr.DeleteDestConnection(conn) + pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst) + pipe.Start() +} + +func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) { + // Read the stream header once as a handshake. + var headerBuf [headerSize]byte + if _, err := io.ReadFull(conn, headerBuf[:]); err != nil { + return nil, err + } + + header := ToHeader(headerBuf) + if !header.Validate() { + return nil, ErrInvalidHeader + } + + // get destination connection + host, port := header.GetHostPort() + return s.connMgr.GetOrCreateDestConnection(conn, host, port) +} + +func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) { + addr := host + ":" + port + conn, err := net.DialTimeout("tcp", addr, dialTimeout) + if err != nil { + return nil, err + } + return conn, nil +} diff --git a/agent/pkg/agent/stream/udp_client.go b/agent/pkg/agent/stream/udp_client.go new file mode 100644 index 00000000..9365da0d --- /dev/null +++ b/agent/pkg/agent/stream/udp_client.go @@ -0,0 +1,99 @@ +package stream + +import ( + "crypto/tls" + "crypto/x509" + "net" + "time" + + "github.com/pion/dtls/v3" +) + +type UDPClient struct { + conn net.Conn +} + +// NewUDPClient creates a new UDP client for the agent. +// +// It will establish a DTLS connection and send a stream request header to the server. +// +// It returns an error if +// - the target address is invalid +// - the stream request header is invalid +// - the DTLS configuration is invalid +// - the DTLS connection fails +// - the stream request header is not sent +func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) { + host, port, err := net.SplitHostPort(targetAddress) + if err != nil { + return nil, err + } + + header, err := NewStreamRequestHeader(host, port) + if err != nil { + return nil, err + } + + // Setup DTLS configuration + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + dtlsConfig := &dtls.Config{ + Certificates: []tls.Certificate{*clientCert}, + RootCAs: caCertPool, + InsecureSkipVerify: false, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + } + + raddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + return nil, err + } + + // Establish DTLS connection + conn, err := dtls.Dial("udp", raddr, dtlsConfig) + if err != nil { + return nil, err + } + // Send the stream header once as a handshake. + if _, err := conn.Write(header.Bytes()); err != nil { + _ = conn.Close() + return nil, err + } + + return &UDPClient{ + conn: conn, + }, nil +} + +func (c *UDPClient) Read(p []byte) (n int, err error) { + return c.conn.Read(p) +} + +func (c *UDPClient) Write(p []byte) (n int, err error) { + return c.conn.Write(p) +} + +func (c *UDPClient) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *UDPClient) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *UDPClient) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *UDPClient) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *UDPClient) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func (c *UDPClient) Close() error { + return c.conn.Close() +} diff --git a/agent/pkg/agent/stream/udp_server.go b/agent/pkg/agent/stream/udp_server.go new file mode 100644 index 00000000..894496de --- /dev/null +++ b/agent/pkg/agent/stream/udp_server.go @@ -0,0 +1,164 @@ +package stream + +import ( + "context" + "crypto/tls" + "crypto/x509" + "io" + "net" + "time" + + "github.com/pion/dtls/v3" +) + +type UDPServer struct { + ctx context.Context + laddr *net.UDPAddr + listener net.Listener + + dtlsConfig *dtls.Config + connMgr *ConnectionManager[*net.UDPConn] +} + +func NewUDPServer(ctx context.Context, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer { + caCertPool := x509.NewCertPool() + caCertPool.AddCert(caCert) + + dtlsConfig := &dtls.Config{ + Certificates: []tls.Certificate{*serverCert}, + ClientCAs: caCertPool, + ClientAuth: dtls.RequireAndVerifyClientCert, + ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, + } + + s := &UDPServer{ + ctx: ctx, + laddr: laddr, + dtlsConfig: dtlsConfig, + } + s.connMgr = NewConnectionManager(s.createDestConnection) + return s +} + +func (s *UDPServer) Start() error { + listener, err := dtls.Listen("udp", s.laddr, s.dtlsConfig) + if err != nil { + return err + } + s.listener = listener + + for { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + conn, err := s.listener.Accept() + if err != nil { + return err + } + go s.handleDTLSConnection(conn) + } + } +} + +func (s *UDPServer) Addr() net.Addr { + if s.listener != nil { + return s.listener.Addr() + } + return s.laddr +} + +func (s *UDPServer) Close() error { + s.connMgr.CloseAllConnections() + if s.listener != nil { + return s.listener.Close() + } + return nil +} + +func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) { + defer clientConn.Close() + + // Read the stream header once as a handshake. + var headerBuf [headerSize]byte + if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil { + // TODO: log error + return + } + header := ToHeader(headerBuf) + if !header.Validate() { + // TODO: log error + return + } + + host, port := header.GetHostPort() + dstConn, err := s.connMgr.GetOrCreateDestConnection(clientConn, host, port) + if err != nil { + // TODO: log error + return + } + defer s.connMgr.DeleteDestConnection(clientConn) + + go s.forwardFromDestination(dstConn, clientConn) + + buf := sizedPool.GetSized(65535) + defer sizedPool.Put(buf) + + for { + select { + case <-s.ctx.Done(): + return + default: + n, err := clientConn.Read(buf) + if err != nil { + // TODO: log error + return + } + if _, err := dstConn.Write(buf[:n]); err != nil { + // TODO: log error + return + } + } + } +} + +func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) { + addr := host + ":" + port + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + dstConn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, err + } + + return dstConn, nil +} + +func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) { + buffer := sizedPool.GetSized(65535) + defer sizedPool.Put(buffer) + + for { + select { + case <-s.ctx.Done(): + return + default: + _ = dstConn.SetReadDeadline(time.Now().Add(readDeadline)) + n, err := dstConn.Read(buffer) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return + } + // TODO: log error + return + } + if _, err := clientConn.Write(buffer[:n]); err != nil { + // TODO: log error + return + } + } + } +} diff --git a/agent/pkg/agent/templates/agent.compose.yml b/agent/pkg/agent/templates/agent.compose.yml deleted file mode 100644 index bb10d1f7..00000000 --- a/agent/pkg/agent/templates/agent.compose.yml +++ /dev/null @@ -1,44 +0,0 @@ -services: - agent: - image: "{{.Image}}" - container_name: godoxy-agent - restart: always - network_mode: host # do not change this - environment: - AGENT_NAME: "{{.Name}}" - AGENT_PORT: "{{.Port}}" - AGENT_CA_CERT: "{{.CACert}}" - AGENT_SSL_CERT: "{{.SSLCert}}" - # use agent as a docker socket proxy: [host]:port - # set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375) - LISTEN_ADDR: - POST: false - ALLOW_RESTARTS: false - ALLOW_START: false - ALLOW_STOP: false - AUTH: false - BUILD: false - COMMIT: false - CONFIGS: false - CONTAINERS: false - DISTRIBUTION: false - EVENTS: true - EXEC: false - GRPC: false - IMAGES: false - INFO: false - NETWORKS: false - NODES: false - PING: true - PLUGINS: false - SECRETS: false - SERVICES: false - SESSION: false - SWARM: false - SYSTEM: false - TASKS: false - VERSION: true - VOLUMES: false - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - ./data:/app/data diff --git a/agent/pkg/agent/templates/agent.compose.yml.tmpl b/agent/pkg/agent/templates/agent.compose.yml.tmpl index cc0864c0..aa13335c 100644 --- a/agent/pkg/agent/templates/agent.compose.yml.tmpl +++ b/agent/pkg/agent/templates/agent.compose.yml.tmpl @@ -5,7 +5,9 @@ services: restart: always {{ if eq .ContainerRuntime "podman" -}} ports: - - "{{.Port}}:{{.Port}}" + - "{{.Port}}:{{.Port}}/tcp" + - "{{.StreamPort}}:{{.StreamPort}}/tcp" + - "{{.StreamPort}}:{{.StreamPort}}/udp" {{ else -}} network_mode: host # do not change this {{ end -}} @@ -22,6 +24,7 @@ services: {{ end -}} AGENT_NAME: "{{.Name}}" AGENT_PORT: "{{.Port}}" + AGENT_STREAM_PORT: "{{.StreamPort}}" AGENT_CA_CERT: "{{.CACert}}" AGENT_SSL_CERT: "{{.SSLCert}}" # use agent as a docker socket proxy: [host]:port diff --git a/agent/pkg/env/env.go b/agent/pkg/env/env.go index 7dae9c7b..c38c1784 100644 --- a/agent/pkg/env/env.go +++ b/agent/pkg/env/env.go @@ -20,6 +20,7 @@ func DefaultAgentName() string { var ( AgentName string AgentPort int + AgentStreamPort int AgentSkipClientCertCheck bool AgentCACert string AgentSSLCert string @@ -35,6 +36,7 @@ func Load() { DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock") AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName()) AgentPort = env.GetEnvInt("AGENT_PORT", 8890) + AgentStreamPort = env.GetEnvInt("AGENT_STREAM_PORT", AgentPort+1) AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false) AgentCACert = env.GetEnvString("AGENT_CA_CERT", "") diff --git a/agent/pkg/handler/handler.go b/agent/pkg/handler/handler.go index 31401cf7..836a804f 100644 --- a/agent/pkg/handler/handler.go +++ b/agent/pkg/handler/handler.go @@ -1,9 +1,9 @@ package handler import ( - "fmt" "net/http" + "github.com/bytedance/sonic" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/yusing/godoxy/agent/pkg/agent" @@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler { } mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP) - mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, version.Get()) - }) - mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, env.AgentName) - }) - mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, env.Runtime) + mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) { + agentInfo := agent.AgentInfo{ + Version: version.Get(), + Name: env.AgentName, + Runtime: env.Runtime, + StreamPort: env.AgentStreamPort, + } + sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo) }) mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth) mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP) diff --git a/go.mod b/go.mod index ec2fc7ad..fc9f6c99 100644 --- a/go.mod +++ b/go.mod @@ -172,6 +172,9 @@ require ( github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect + github.com/pion/dtls/v3 v3.0.9 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/transport/v3 v3.1.1 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/pquerna/otp v1.5.0 // indirect github.com/stretchr/objx v0.5.3 // indirect diff --git a/go.sum b/go.sum index 5b290d61..161903b4 100644 --- a/go.sum +++ b/go.sum @@ -247,6 +247,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0 github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM= +github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= diff --git a/internal/route/stream.go b/internal/route/stream.go index 0c38ef74..3b194b71 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -110,9 +110,9 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) { switch rurl.Scheme { case "tcp": - return stream.NewTCPTCPStream(laddr, rurl.Host) + return stream.NewTCPTCPStream(laddr, rurl.Host, r.GetAgent()) case "udp": - return stream.NewUDPUDPStream(laddr, rurl.Host) + return stream.NewUDPUDPStream(laddr, rurl.Host, r.GetAgent()) } return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme) } diff --git a/internal/route/stream/tcp_tcp.go b/internal/route/stream/tcp_tcp.go index 1033d519..c8f64256 100644 --- a/internal/route/stream/tcp_tcp.go +++ b/internal/route/stream/tcp_tcp.go @@ -6,6 +6,7 @@ import ( "github.com/pires/go-proxyproto" "github.com/rs/zerolog" + "github.com/yusing/godoxy/agent/pkg/agent" "github.com/yusing/godoxy/internal/acl" "github.com/yusing/godoxy/internal/entrypoint" nettypes "github.com/yusing/godoxy/internal/net/types" @@ -17,6 +18,7 @@ type TCPTCPStream struct { listener net.Listener laddr *net.TCPAddr dst *net.TCPAddr + agent *agent.AgentConfig preDial nettypes.HookFunc onRead nettypes.HookFunc @@ -24,7 +26,7 @@ type TCPTCPStream struct { closed atomic.Bool } -func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { +func NewTCPTCPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) { dst, err := net.ResolveTCPAddr("tcp", dstAddr) if err != nil { return nil, err @@ -33,7 +35,7 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { if err != nil { return nil, err } - return &TCPTCPStream{laddr: laddr, dst: dst}, nil + return &TCPTCPStream{laddr: laddr, dst: dst, agent: agentCfg}, nil } func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) { @@ -126,7 +128,15 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { return } - dstConn, err := net.DialTCP("tcp", nil, s.dst) + var ( + dstConn net.Conn + err error + ) + if s.agent != nil { + dstConn, err = s.agent.NewTCPClient(s.dst.String()) + } else { + dstConn, err = net.DialTCP("tcp", nil, s.dst) + } if err != nil { if !s.closed.Load() { logErr(s, err, "failed to dial destination") @@ -140,7 +150,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) { } src := conn - dst := net.Conn(dstConn) + dst := dstConn if s.onRead != nil { src = &wrapperConn{ Conn: conn, diff --git a/internal/route/stream/udp_udp.go b/internal/route/stream/udp_udp.go index 20b813a9..b8580b6e 100644 --- a/internal/route/stream/udp_udp.go +++ b/internal/route/stream/udp_udp.go @@ -10,6 +10,7 @@ import ( "time" "github.com/rs/zerolog" + "github.com/yusing/godoxy/agent/pkg/agent" "github.com/yusing/godoxy/internal/acl" nettypes "github.com/yusing/godoxy/internal/net/types" "github.com/yusing/goutils/synk" @@ -22,6 +23,7 @@ type UDPUDPStream struct { laddr *net.UDPAddr dst *net.UDPAddr + agent *agent.AgentConfig preDial nettypes.HookFunc onRead nettypes.HookFunc @@ -35,7 +37,7 @@ type UDPUDPStream struct { type udpUDPConn struct { srcAddr *net.UDPAddr - dstConn *net.UDPConn + dstConn net.Conn listener net.PacketConn lastUsed atomic.Time closed atomic.Bool @@ -51,7 +53,7 @@ const ( var bufPool = synk.GetSizedBytesPool() -func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { +func NewUDPUDPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) { dst, err := net.ResolveUDPAddr("udp", dstAddr) if err != nil { return nil, err @@ -63,6 +65,7 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) { return &UDPUDPStream{ laddr: laddr, dst: dst, + agent: agentCfg, conns: make(map[string]*udpUDPConn), }, nil } @@ -189,8 +192,16 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd } } - // Create UDP connection to destination - dstConn, err := net.DialUDP("udp", nil, s.dst) + // Create connection to destination (direct UDP or via agent stream tunnel) + var ( + dstConn net.Conn + err error + ) + if s.agent != nil { + dstConn, err = s.agent.NewUDPClient(s.dst.String()) + } else { + dstConn, err = net.DialUDP("udp", nil, s.dst) + } if err != nil { logErr(s, err, "failed to dial dst") return nil, false @@ -205,7 +216,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd // Send initial data before starting response handler if !conn.forwardToDestination(initialData) { - dstConn.Close() + _ = dstConn.Close() return nil, false } @@ -328,6 +339,6 @@ func (conn *udpUDPConn) Close() { conn.closed.Store(true) - conn.dstConn.Close() + _ = conn.dstConn.Close() conn.dstConn = nil }