feat(agent): agent stream tunneling with TLS and dTLS (UDP); combined agent APIs

- Add `StreamPort` configuration to agent configuration and environment variables
- Implement TCP and UDP stream client support in agent package
- Update agent verification to test stream connectivity (TCP/UDP)
- Add `/info` endpoint to agent HTTP handler for version, name, runtime, and stream port
- Remove /version, /name, /runtime APIs, replaced by /info
- Update agent compose template to expose stream port for TCP and UDP
- Update agent creation API to optionally specify stream port (defaults to port + 1)
- Modify `StreamRoute` to pass agent configuration to stream implementations
- Update `TCPTCPStream` and `UDPUDPStream` to use agent stream tunneling when agent is configured
- Add support for both direct connections and agent-tunneled connections in stream routes

This enables agents to handle TCP and UDP route tunneling, expanding the proxy capabilities beyond HTTP-only connections.
This commit is contained in:
yusing
2026-01-07 00:44:12 +08:00
parent a44b9e352c
commit fe619f1dd9
25 changed files with 1225 additions and 107 deletions

View File

@@ -8,6 +8,7 @@ import (
var (
installScript = `AGENT_NAME="{{.Name}}" \
AGENT_PORT="{{.Port}}" \
AGENT_STREAM_PORT="{{.StreamPort}}" \
AGENT_CA_CERT="{{.CACert}}" \
AGENT_SSL_CERT="{{.SSLCert}}" \
{{ if eq .ContainerRuntime "nerdctl" -}}

View File

@@ -4,38 +4,61 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream"
"github.com/yusing/godoxy/agent/pkg/certs"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/version"
)
type AgentConfig struct {
Addr string `json:"addr"`
Name string `json:"name"`
Version version.Version `json:"version" swaggertype:"string"`
Runtime ContainerRuntime `json:"runtime"`
AgentInfo
Addr string `json:"addr"`
httpClient *http.Client
fasthttpClientHealthCheck *fasthttp.Client
tlsConfig tls.Config
l zerolog.Logger
// for stream
caCert *x509.Certificate
clientCert *tls.Certificate
isTCPStreamSupported bool
isUDPStreamSupported bool
streamServerAddr string
l zerolog.Logger
} // @name Agent
type AgentInfo struct {
Version version.Version `json:"version" swaggertype:"string"`
Name string `json:"name"`
Runtime ContainerRuntime `json:"runtime"`
StreamPort int `json:"stream_port"`
}
// Deprecated. Replaced by EndpointInfo
const (
EndpointVersion = "/version"
EndpointName = "/name"
EndpointRuntime = "/runtime"
EndpointVersion = "/version"
EndpointName = "/name"
EndpointRuntime = "/runtime"
)
const (
EndpointInfo = "/info"
EndpointProxyHTTP = "/proxy/http"
EndpointHealth = "/health"
EndpointLogs = "/logs"
@@ -90,6 +113,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
if err != nil {
return err
}
cfg.clientCert = &clientCert
// create tls config
caCertPool := x509.NewCertPool()
@@ -97,6 +121,14 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
if !ok {
return errors.New("invalid ca certificate")
}
// Keep the CA leaf for stream client dialing.
if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" {
return errors.New("invalid ca certificate")
} else if cert, err := x509.ParseCertificate(block.Bytes); err != nil {
return err
} else {
cfg.caCert = cert
}
cfg.tlsConfig = tls.Config{
Certificates: []tls.Certificate{clientCert},
@@ -113,48 +145,97 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// get agent name
name, _, err := cfg.fetchString(ctx, EndpointName)
status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo)
if err != nil {
return err
}
cfg.Name = name
var streamUnsupportedErrs gperr.Builder
if status == http.StatusOK {
if cfg.StreamPort <= 0 {
return fmt.Errorf("invalid agent stream port: %d", cfg.StreamPort)
}
host, _, err := net.SplitHostPort(cfg.Addr)
if err != nil {
return err
}
cfg.streamServerAddr = net.JoinHostPort(host, strconv.Itoa(cfg.StreamPort))
// test stream server connection
const fakeAddress = "localhost:8080" // it won't be used, just for testing
// test TCP stream support
conn, err := agentstream.NewTCPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
if err != nil {
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
} else {
conn.Close()
cfg.isTCPStreamSupported = true
}
// test UDP stream support
conn, err = agentstream.NewUDPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
if err != nil {
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
} else {
conn.Close()
cfg.isUDPStreamSupported = true
}
} else {
// old agent does not support EndpointInfo
// fallback with old logic
cfg.isTCPStreamSupported = false
cfg.isUDPStreamSupported = false
streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling")
// get agent name
name, _, err := cfg.fetchString(ctx, EndpointName)
if err != nil {
return err
}
cfg.Name = name
// check agent version
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
if err != nil {
return err
}
cfg.Version = version.Parse(agentVersion)
// check agent runtime
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
if err != nil {
return err
}
switch status {
case http.StatusOK:
switch runtime {
case "docker":
cfg.Runtime = ContainerRuntimeDocker
// case "nerdctl":
// cfg.Runtime = ContainerRuntimeNerdctl
case "podman":
cfg.Runtime = ContainerRuntimePodman
default:
return fmt.Errorf("invalid agent runtime: %s", runtime)
}
case http.StatusNotFound:
// backward compatibility, old agent does not have runtime endpoint
cfg.Runtime = ContainerRuntimeDocker
default:
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
}
}
cfg.l = log.With().Str("agent", cfg.Name).Logger()
// check agent version
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
if err != nil {
return err
if err := streamUnsupportedErrs.Error(); err != nil {
gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l)
}
// check agent runtime
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
if err != nil {
return err
}
switch status {
case http.StatusOK:
switch runtime {
case "docker":
cfg.Runtime = ContainerRuntimeDocker
// case "nerdctl":
// cfg.Runtime = ContainerRuntimeNerdctl
case "podman":
cfg.Runtime = ContainerRuntimePodman
default:
return fmt.Errorf("invalid agent runtime: %s", runtime)
}
case http.StatusNotFound:
// backward compatibility, old agent does not have runtime endpoint
cfg.Runtime = ContainerRuntimeDocker
default:
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
}
cfg.Version = version.Parse(agentVersion)
if serverVersion.IsNewerThanMajor(cfg.Version) {
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version)
}
@@ -163,6 +244,53 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
return nil
}
func (cfg *AgentConfig) getStreamServerAddr() (string, error) {
if cfg.streamServerAddr == "" {
return "", errors.New("agent stream server address is not initialized")
}
return cfg.streamServerAddr, nil
}
// NewTCPClient creates a new TCP client for the agent.
//
// It returns an error if
// - the agent is not initialized
// - the agent does not support TCP stream tunneling
// - the agent stream server address is not initialized
func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) {
if cfg.caCert == nil || cfg.clientCert == nil {
return nil, errors.New("agent is not initialized")
}
if !cfg.isTCPStreamSupported {
return nil, errors.New("agent does not support TCP stream tunneling")
}
serverAddr, err := cfg.getStreamServerAddr()
if err != nil {
return nil, err
}
return agentstream.NewTCPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
}
// NewUDPClient creates a new UDP client for the agent.
//
// It returns an error if
// - the agent is not initialized
// - the agent does not support UDP stream tunneling
// - the agent stream server address is not initialized
func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) {
if cfg.caCert == nil || cfg.clientCert == nil {
return nil, errors.New("agent is not initialized")
}
if !cfg.isUDPStreamSupported {
return nil, errors.New("agent does not support UDP stream tunneling")
}
serverAddr, err := cfg.getStreamServerAddr()
if err != nil {
return nil, err
}
return agentstream.NewUDPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
}
func (cfg *AgentConfig) Start(ctx context.Context) error {
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
if !ok {

View File

@@ -5,6 +5,7 @@ type (
AgentEnvConfig struct {
Name string
Port int
StreamPort int
CACert string
SSLCert string
ContainerRuntime ContainerRuntime

View File

@@ -87,6 +87,34 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin
return ret, resp.StatusCode, nil
}
// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct
//
// It will return the status code of the response, and error if any.
// If the status code is not http.StatusOK, out will be unchanged but error will still be nil.
func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) {
resp, err := cfg.Do(ctx, "GET", endpoint, nil)
if err != nil {
return 0, err
}
defer resp.Body.Close()
data, release, err := httputils.ReadAllBody(resp)
if err != nil {
return 0, err
}
defer release(data)
if resp.StatusCode != http.StatusOK {
return resp.StatusCode, nil
}
err = sonic.Unmarshal(data, out)
if err != nil {
return 0, err
}
return resp.StatusCode, nil
}
func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) {
transport := cfg.Transport()
dialer := websocket.Dialer{

View File

@@ -0,0 +1,40 @@
# Stream proxy protocol
This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination.
## Header
The on-wire header is a fixed-size binary blob:
- `Version` (8 bytes)
- `Host` (255 bytes, NUL padded)
- `Port` (5 bytes, NUL padded)
- `Checksum` (4 bytes, big-endian CRC32)
Total: `headerSize = 8 + 255 + 5 + 4 = 272` bytes.
Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`.
See [`StreamRequestHeader`](payload.go:26).
## TCP behavior
1. Client establishes a TLS connection to the stream server.
2. Client sends exactly one header as a handshake.
3. After the handshake, both sides proxy raw TCP bytes between client and destination.
Server reads the header using `io.ReadFull` to avoid dropping bytes.
See [`NewTCPClient()`](tcp_client.go:15) and [`(*TCPServer).redirect()`](tcp_server.go:77).
## UDP-over-DTLS behavior
1. Client establishes a DTLS connection to the stream server.
2. Client sends exactly one header as a handshake.
3. After the handshake, both sides proxy raw UDP datagrams:
- client → destination: DTLS payload is written to destination `UDPConn`
- destination → client: destination payload is written back to the DTLS connection
Responses do **not** include a header.
See [`NewUDPClient()`](udp_client.go:17) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:67).

View File

@@ -0,0 +1,57 @@
package stream
import (
"net"
"time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/goutils/synk"
)
const (
dialTimeout = 10 * time.Second
readDeadline = 10 * time.Second
)
var sizedPool = synk.GetSizedBytesPool()
type CreateConnFunc[Conn net.Conn] func(host, port string) (Conn, error)
type ConnectionManager[Conn net.Conn] struct {
m *xsync.Map[string, Conn]
createConnection CreateConnFunc[Conn]
}
func NewConnectionManager[Conn net.Conn](createConnection CreateConnFunc[Conn]) *ConnectionManager[Conn] {
return &ConnectionManager[Conn]{
m: xsync.NewMap[string, Conn](),
createConnection: createConnection,
}
}
func (c *ConnectionManager[Conn]) GetOrCreateDestConnection(clientConn net.Conn, host, port string) (ret Conn, connErr error) {
clientKey := clientConn.RemoteAddr().String()
ret, _ = c.m.LoadOrCompute(clientKey, func() (conn Conn, cancel bool) {
conn, connErr = c.createConnection(host, port)
if connErr != nil {
cancel = true
}
return
})
return
}
func (c *ConnectionManager[Conn]) DeleteDestConnection(clientConn net.Conn) {
clientKey := clientConn.RemoteAddr().String()
conn, loaded := c.m.LoadAndDelete(clientKey)
if loaded {
conn.Close()
}
}
func (c *ConnectionManager[Conn]) CloseAllConnections() {
for _, conn := range c.m.Range {
conn.Close()
}
c.m.Clear()
}

View File

@@ -0,0 +1,109 @@
package stream
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"unsafe"
)
const (
versionSize = 8
hostSize = 255
portSize = 5
checksumSize = 4 // crc32 checksum
headerSize = versionSize + hostSize + portSize + checksumSize
)
var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0}
var ErrInvalidHeader = errors.New("invalid header")
type StreamRequestHeader struct {
Version [versionSize]byte
Host [hostSize]byte
Port [portSize]byte
Checksum [checksumSize]byte
}
type StreamRequestPayload struct {
StreamRequestHeader
Data []byte
}
func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) {
if len(host) > hostSize {
return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host))
}
if len(port) > portSize {
return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port))
}
header := &StreamRequestHeader{}
copy(header.Version[:], version[:])
copy(header.Host[:], host)
copy(header.Port[:], port)
header.updateChecksum()
return header, nil
}
func ToHeader(buf [headerSize]byte) *StreamRequestHeader {
return (*StreamRequestHeader)(unsafe.Pointer(&buf[0]))
}
// WriteTo implements the io.WriterTo interface.
func (p *StreamRequestPayload) WriteTo(w io.Writer) (n int64, err error) {
n1, err := w.Write(p.StreamRequestHeader.Bytes())
if err != nil {
return
}
if len(p.Data) == 0 {
return int64(n1), nil
}
n2, err := w.Write(p.Data)
if err != nil {
return
}
return int64(n1) + int64(n2), nil
}
func (h *StreamRequestHeader) GetHostPort() (string, string) {
hostEnd := bytes.IndexByte(h.Host[:], 0)
portEnd := bytes.IndexByte(h.Port[:], 0)
if hostEnd == -1 {
hostEnd = hostSize
}
if portEnd == -1 {
portEnd = portSize
}
return string(h.Host[:hostEnd]), string(h.Port[:portEnd])
}
func (h *StreamRequestHeader) Validate() bool {
if h.Version != version {
return false
}
return h.validateChecksum()
}
func (h *StreamRequestHeader) updateChecksum() {
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
binary.BigEndian.PutUint32(h.Checksum[:], checksum)
}
func (h *StreamRequestHeader) validateChecksum() bool {
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
return checksum == binary.BigEndian.Uint32(h.Checksum[:])
}
func (h *StreamRequestHeader) BytesWithoutChecksum() []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize-checksumSize)
}
func (h *StreamRequestHeader) Bytes() []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize)
}

View File

@@ -0,0 +1,53 @@
package stream
import (
"bytes"
"testing"
)
func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) {
h, err := NewStreamRequestHeader("example.com", "443")
if err != nil {
t.Fatalf("NewStreamRequestHeader: %v", err)
}
if !h.Validate() {
t.Fatalf("expected header to validate")
}
var buf [headerSize]byte
copy(buf[:], h.Bytes())
h2 := ToHeader(buf)
if !h2.Validate() {
t.Fatalf("expected round-tripped header to validate")
}
host, port := h2.GetHostPort()
if host != "example.com" || port != "443" {
t.Fatalf("unexpected host/port: %q:%q", host, port)
}
}
func TestStreamRequestPayload_WriteTo_WritesFullHeader(t *testing.T) {
h, err := NewStreamRequestHeader("127.0.0.1", "53")
if err != nil {
t.Fatalf("NewStreamRequestHeader: %v", err)
}
p := &StreamRequestPayload{StreamRequestHeader: *h, Data: []byte("hello")}
var out bytes.Buffer
n, err := p.WriteTo(&out)
if err != nil {
t.Fatalf("WriteTo: %v", err)
}
if int(n) != headerSize+len(p.Data) {
t.Fatalf("unexpected bytes written: got %d want %d", n, headerSize+len(p.Data))
}
written := out.Bytes()
if len(written) != headerSize+len(p.Data) {
t.Fatalf("unexpected output size: got %d", len(written))
}
if !bytes.Equal(written[:headerSize], h.Bytes()) {
t.Fatalf("expected full header (including checksum) to be written")
}
}

View File

@@ -0,0 +1,235 @@
package stream
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"io"
"math/big"
"net"
"testing"
"time"
)
func newSerial(t *testing.T) *big.Int {
t.Helper()
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
t.Fatalf("rand serial: %v", err)
}
return sn
}
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: newSerial(t),
Subject: pkix.Name{CommonName: "stream-test-ca"},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate(CA): %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate(CA): %v", err)
}
return cert, key
}
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: newSerial(t),
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{eku},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
if err != nil {
t.Fatalf("CreateCertificate(%s): %v", cn, err)
}
leaf, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate(%s): %v", cn, err)
}
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
}
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen tcp: %v", err)
}
done := make(chan struct{})
go func() {
defer close(done)
for {
c, err := ln.Accept()
if err != nil {
return
}
go func(conn net.Conn) {
defer conn.Close()
_, _ = io.Copy(conn, conn)
}(c)
}
}()
return ln.Addr().String(), func() {
_ = ln.Close()
<-done
}
}
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
t.Helper()
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen udp: %v", err)
}
uc := pc.(*net.UDPConn)
done := make(chan struct{})
go func() {
defer close(done)
buf := make([]byte, 65535)
for {
n, raddr, err := uc.ReadFromUDP(buf)
if err != nil {
return
}
_, _ = uc.WriteToUDP(buf[:n], raddr)
}
}()
return uc.LocalAddr().String(), func() {
_ = uc.Close()
<-done
}
}
func TestTCPServer_FullFlow(t *testing.T) {
ca, caKey := genCA(t)
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
dstAddr, closeDst := startTCPEcho(t)
defer closeDst()
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenTCP: %v", err)
}
defer tcpLn.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
_ = <-errCh
}()
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
if err != nil {
t.Fatalf("NewTCPClient: %v", err)
}
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over tcp")
if _, err := client.Write(msg); err != nil {
t.Fatalf("client.Write: %v", err)
}
buf := make([]byte, len(msg))
if _, err := io.ReadFull(client, buf); err != nil {
t.Fatalf("client.ReadFull: %v", err)
}
if string(buf) != string(msg) {
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
}
}
func TestUDPServer_FullFlow(t *testing.T) {
ca, caKey := genCA(t)
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
dstAddr, closeDst := startUDPEcho(t)
defer closeDst()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
err := <-errCh
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
t.Logf("udp server exit: %v", err)
}
}()
deadline := time.Now().Add(2 * time.Second)
for srv.listener == nil {
if time.Now().After(deadline) {
t.Fatalf("udp server listener did not start")
}
time.Sleep(10 * time.Millisecond)
}
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
if err != nil {
t.Fatalf("NewUDPClient: %v", err)
}
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over udp")
if _, err := client.Write(msg); err != nil {
t.Fatalf("client.Write: %v", err)
}
buf := make([]byte, 2048)
n, err := client.Read(buf)
if err != nil {
t.Fatalf("client.Read: %v", err)
}
if string(buf[:n]) != string(msg) {
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
}
}

View File

@@ -0,0 +1,91 @@
package stream
import (
"crypto/tls"
"crypto/x509"
"net"
"time"
)
type TCPClient struct {
conn net.Conn
}
// NewTCPClient creates a new TCP client for the agent.
//
// It will establish a TLS connection and send a stream request header to the server.
//
// It returns an error if
// - the target address is invalid
// - the stream request header is invalid
// - the TLS configuration is invalid
// - the TLS connection fails
// - the stream request header is not sent
func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
host, port, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
header, err := NewStreamRequestHeader(host, port)
if err != nil {
return nil, err
}
// Setup TLS configuration
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*clientCert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}
// Establish TLS connection
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig)
if err != nil {
return nil, err
}
// Send the stream header once as a handshake.
if _, err := conn.Write(header.Bytes()); err != nil {
_ = conn.Close()
return nil, err
}
return &TCPClient{
conn: conn,
}, nil
}
func (c *TCPClient) Read(p []byte) (n int, err error) {
return c.conn.Read(p)
}
func (c *TCPClient) Write(p []byte) (n int, err error) {
return c.conn.Write(p)
}
func (c *TCPClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *TCPClient) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *TCPClient) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *TCPClient) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *TCPClient) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *TCPClient) Close() error {
return c.conn.Close()
}

View File

@@ -0,0 +1,99 @@
package stream
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
ioutils "github.com/yusing/goutils/io"
)
type TCPServer struct {
ctx context.Context
listener net.Listener
connMgr *ConnectionManager[net.Conn]
}
func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer {
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*serverCert},
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
}
tcpListener := tls.NewListener(listener, tlsConfig)
s := &TCPServer{
ctx: ctx,
listener: tcpListener,
}
s.connMgr = NewConnectionManager(s.createDestConnection)
return s
}
func (s *TCPServer) Start() error {
for {
select {
case <-s.ctx.Done():
return s.ctx.Err()
default:
conn, err := s.listener.Accept()
if err != nil {
return err
}
go s.handle(conn)
}
}
}
func (s *TCPServer) Addr() net.Addr {
return s.listener.Addr()
}
func (s *TCPServer) Close() error {
s.connMgr.CloseAllConnections()
return s.listener.Close()
}
func (s *TCPServer) handle(conn net.Conn) {
defer conn.Close()
dst, err := s.redirect(conn)
if err != nil {
// TODO: log error
return
}
defer s.connMgr.DeleteDestConnection(conn)
pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst)
pipe.Start()
}
func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) {
// Read the stream header once as a handshake.
var headerBuf [headerSize]byte
if _, err := io.ReadFull(conn, headerBuf[:]); err != nil {
return nil, err
}
header := ToHeader(headerBuf)
if !header.Validate() {
return nil, ErrInvalidHeader
}
// get destination connection
host, port := header.GetHostPort()
return s.connMgr.GetOrCreateDestConnection(conn, host, port)
}
func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) {
addr := host + ":" + port
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err != nil {
return nil, err
}
return conn, nil
}

View File

@@ -0,0 +1,99 @@
package stream
import (
"crypto/tls"
"crypto/x509"
"net"
"time"
"github.com/pion/dtls/v3"
)
type UDPClient struct {
conn net.Conn
}
// NewUDPClient creates a new UDP client for the agent.
//
// It will establish a DTLS connection and send a stream request header to the server.
//
// It returns an error if
// - the target address is invalid
// - the stream request header is invalid
// - the DTLS configuration is invalid
// - the DTLS connection fails
// - the stream request header is not sent
func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
host, port, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
header, err := NewStreamRequestHeader(host, port)
if err != nil {
return nil, err
}
// Setup DTLS configuration
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
dtlsConfig := &dtls.Config{
Certificates: []tls.Certificate{*clientCert},
RootCAs: caCertPool,
InsecureSkipVerify: false,
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
}
raddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
return nil, err
}
// Establish DTLS connection
conn, err := dtls.Dial("udp", raddr, dtlsConfig)
if err != nil {
return nil, err
}
// Send the stream header once as a handshake.
if _, err := conn.Write(header.Bytes()); err != nil {
_ = conn.Close()
return nil, err
}
return &UDPClient{
conn: conn,
}, nil
}
func (c *UDPClient) Read(p []byte) (n int, err error) {
return c.conn.Read(p)
}
func (c *UDPClient) Write(p []byte) (n int, err error) {
return c.conn.Write(p)
}
func (c *UDPClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *UDPClient) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *UDPClient) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *UDPClient) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *UDPClient) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *UDPClient) Close() error {
return c.conn.Close()
}

View File

@@ -0,0 +1,164 @@
package stream
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"
"github.com/pion/dtls/v3"
)
type UDPServer struct {
ctx context.Context
laddr *net.UDPAddr
listener net.Listener
dtlsConfig *dtls.Config
connMgr *ConnectionManager[*net.UDPConn]
}
func NewUDPServer(ctx context.Context, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer {
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
dtlsConfig := &dtls.Config{
Certificates: []tls.Certificate{*serverCert},
ClientCAs: caCertPool,
ClientAuth: dtls.RequireAndVerifyClientCert,
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
}
s := &UDPServer{
ctx: ctx,
laddr: laddr,
dtlsConfig: dtlsConfig,
}
s.connMgr = NewConnectionManager(s.createDestConnection)
return s
}
func (s *UDPServer) Start() error {
listener, err := dtls.Listen("udp", s.laddr, s.dtlsConfig)
if err != nil {
return err
}
s.listener = listener
for {
select {
case <-s.ctx.Done():
return s.ctx.Err()
default:
conn, err := s.listener.Accept()
if err != nil {
return err
}
go s.handleDTLSConnection(conn)
}
}
}
func (s *UDPServer) Addr() net.Addr {
if s.listener != nil {
return s.listener.Addr()
}
return s.laddr
}
func (s *UDPServer) Close() error {
s.connMgr.CloseAllConnections()
if s.listener != nil {
return s.listener.Close()
}
return nil
}
func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) {
defer clientConn.Close()
// Read the stream header once as a handshake.
var headerBuf [headerSize]byte
if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil {
// TODO: log error
return
}
header := ToHeader(headerBuf)
if !header.Validate() {
// TODO: log error
return
}
host, port := header.GetHostPort()
dstConn, err := s.connMgr.GetOrCreateDestConnection(clientConn, host, port)
if err != nil {
// TODO: log error
return
}
defer s.connMgr.DeleteDestConnection(clientConn)
go s.forwardFromDestination(dstConn, clientConn)
buf := sizedPool.GetSized(65535)
defer sizedPool.Put(buf)
for {
select {
case <-s.ctx.Done():
return
default:
n, err := clientConn.Read(buf)
if err != nil {
// TODO: log error
return
}
if _, err := dstConn.Write(buf[:n]); err != nil {
// TODO: log error
return
}
}
}
}
func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) {
addr := host + ":" + port
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dstConn, err := net.DialUDP("udp", nil, udpAddr)
if err != nil {
return nil, err
}
return dstConn, nil
}
func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) {
buffer := sizedPool.GetSized(65535)
defer sizedPool.Put(buffer)
for {
select {
case <-s.ctx.Done():
return
default:
_ = dstConn.SetReadDeadline(time.Now().Add(readDeadline))
n, err := dstConn.Read(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return
}
// TODO: log error
return
}
if _, err := clientConn.Write(buffer[:n]); err != nil {
// TODO: log error
return
}
}
}
}

View File

@@ -1,44 +0,0 @@
services:
agent:
image: "{{.Image}}"
container_name: godoxy-agent
restart: always
network_mode: host # do not change this
environment:
AGENT_NAME: "{{.Name}}"
AGENT_PORT: "{{.Port}}"
AGENT_CA_CERT: "{{.CACert}}"
AGENT_SSL_CERT: "{{.SSLCert}}"
# use agent as a docker socket proxy: [host]:port
# set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375)
LISTEN_ADDR:
POST: false
ALLOW_RESTARTS: false
ALLOW_START: false
ALLOW_STOP: false
AUTH: false
BUILD: false
COMMIT: false
CONFIGS: false
CONTAINERS: false
DISTRIBUTION: false
EVENTS: true
EXEC: false
GRPC: false
IMAGES: false
INFO: false
NETWORKS: false
NODES: false
PING: true
PLUGINS: false
SECRETS: false
SERVICES: false
SESSION: false
SWARM: false
SYSTEM: false
TASKS: false
VERSION: true
VOLUMES: false
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ./data:/app/data

View File

@@ -5,7 +5,9 @@ services:
restart: always
{{ if eq .ContainerRuntime "podman" -}}
ports:
- "{{.Port}}:{{.Port}}"
- "{{.Port}}:{{.Port}}/tcp"
- "{{.StreamPort}}:{{.StreamPort}}/tcp"
- "{{.StreamPort}}:{{.StreamPort}}/udp"
{{ else -}}
network_mode: host # do not change this
{{ end -}}
@@ -22,6 +24,7 @@ services:
{{ end -}}
AGENT_NAME: "{{.Name}}"
AGENT_PORT: "{{.Port}}"
AGENT_STREAM_PORT: "{{.StreamPort}}"
AGENT_CA_CERT: "{{.CACert}}"
AGENT_SSL_CERT: "{{.SSLCert}}"
# use agent as a docker socket proxy: [host]:port

View File

@@ -20,6 +20,7 @@ func DefaultAgentName() string {
var (
AgentName string
AgentPort int
AgentStreamPort int
AgentSkipClientCertCheck bool
AgentCACert string
AgentSSLCert string
@@ -35,6 +36,7 @@ func Load() {
DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock")
AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName())
AgentPort = env.GetEnvInt("AGENT_PORT", 8890)
AgentStreamPort = env.GetEnvInt("AGENT_STREAM_PORT", AgentPort+1)
AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
AgentCACert = env.GetEnvString("AGENT_CA_CERT", "")

View File

@@ -1,9 +1,9 @@
package handler
import (
"fmt"
"net/http"
"github.com/bytedance/sonic"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/yusing/godoxy/agent/pkg/agent"
@@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler {
}
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, version.Get())
})
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, env.AgentName)
})
mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, env.Runtime)
mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) {
agentInfo := agent.AgentInfo{
Version: version.Get(),
Name: env.AgentName,
Runtime: env.Runtime,
StreamPort: env.AgentStreamPort,
}
sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo)
})
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP)