feat(agent): agent stream tunneling with TLS and dTLS (UDP); combined agent APIs

- Add `StreamPort` configuration to agent configuration and environment variables
- Implement TCP and UDP stream client support in agent package
- Update agent verification to test stream connectivity (TCP/UDP)
- Add `/info` endpoint to agent HTTP handler for version, name, runtime, and stream port
- Remove /version, /name, /runtime APIs, replaced by /info
- Update agent compose template to expose stream port for TCP and UDP
- Update agent creation API to optionally specify stream port (defaults to port + 1)
- Modify `StreamRoute` to pass agent configuration to stream implementations
- Update `TCPTCPStream` and `UDPUDPStream` to use agent stream tunneling when agent is configured
- Add support for both direct connections and agent-tunneled connections in stream routes

This enables agents to handle TCP and UDP route tunneling, expanding the proxy capabilities beyond HTTP-only connections.
This commit is contained in:
yusing
2026-01-07 00:44:12 +08:00
parent a44b9e352c
commit fe619f1dd9
25 changed files with 1225 additions and 107 deletions

View File

@@ -1,15 +1,18 @@
package main
import (
"net"
"os"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/agent/pkg/agent/stream"
"github.com/yusing/godoxy/agent/pkg/env"
"github.com/yusing/godoxy/agent/pkg/server"
"github.com/yusing/godoxy/internal/metrics/systeminfo"
socketproxy "github.com/yusing/godoxy/socketproxy/pkg"
gperr "github.com/yusing/goutils/errs"
httpServer "github.com/yusing/goutils/server"
strutils "github.com/yusing/goutils/strings"
"github.com/yusing/goutils/task"
@@ -63,6 +66,16 @@ Tips:
server.StartAgentServer(t, opts)
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentStreamPort})
if err != nil {
gperr.LogFatal("failed to listen on port", err)
}
tcpServer := stream.NewTCPServer(t.Context(), tcpListener, caCert.Leaf, srvCert)
go tcpServer.Start()
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentStreamPort}, caCert.Leaf, srvCert)
go udpServer.Start()
if socketproxy.ListenAddr != "" {
runtime := strutils.Title(string(env.Runtime))

View File

@@ -18,6 +18,7 @@ require (
github.com/bytedance/sonic v1.14.2
github.com/gin-gonic/gin v1.11.0
github.com/gorilla/websocket v1.5.3
github.com/pion/dtls/v3 v3.0.9
github.com/puzpuzpuz/xsync/v4 v4.2.0
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.11.1
@@ -87,6 +88,8 @@ require (
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/oschwald/maxminddb-golang v1.13.1 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pion/logging v0.2.4 // indirect
github.com/pion/transport/v3 v3.1.1 // indirect
github.com/pires/go-proxyproto v0.8.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect

View File

@@ -217,6 +217,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=

View File

@@ -8,6 +8,7 @@ import (
var (
installScript = `AGENT_NAME="{{.Name}}" \
AGENT_PORT="{{.Port}}" \
AGENT_STREAM_PORT="{{.StreamPort}}" \
AGENT_CA_CERT="{{.CACert}}" \
AGENT_SSL_CERT="{{.SSLCert}}" \
{{ if eq .ContainerRuntime "nerdctl" -}}

View File

@@ -4,38 +4,61 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream"
"github.com/yusing/godoxy/agent/pkg/certs"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/version"
)
type AgentConfig struct {
Addr string `json:"addr"`
Name string `json:"name"`
Version version.Version `json:"version" swaggertype:"string"`
Runtime ContainerRuntime `json:"runtime"`
AgentInfo
Addr string `json:"addr"`
httpClient *http.Client
fasthttpClientHealthCheck *fasthttp.Client
tlsConfig tls.Config
l zerolog.Logger
// for stream
caCert *x509.Certificate
clientCert *tls.Certificate
isTCPStreamSupported bool
isUDPStreamSupported bool
streamServerAddr string
l zerolog.Logger
} // @name Agent
type AgentInfo struct {
Version version.Version `json:"version" swaggertype:"string"`
Name string `json:"name"`
Runtime ContainerRuntime `json:"runtime"`
StreamPort int `json:"stream_port"`
}
// Deprecated. Replaced by EndpointInfo
const (
EndpointVersion = "/version"
EndpointName = "/name"
EndpointRuntime = "/runtime"
EndpointVersion = "/version"
EndpointName = "/name"
EndpointRuntime = "/runtime"
)
const (
EndpointInfo = "/info"
EndpointProxyHTTP = "/proxy/http"
EndpointHealth = "/health"
EndpointLogs = "/logs"
@@ -90,6 +113,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
if err != nil {
return err
}
cfg.clientCert = &clientCert
// create tls config
caCertPool := x509.NewCertPool()
@@ -97,6 +121,14 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
if !ok {
return errors.New("invalid ca certificate")
}
// Keep the CA leaf for stream client dialing.
if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" {
return errors.New("invalid ca certificate")
} else if cert, err := x509.ParseCertificate(block.Bytes); err != nil {
return err
} else {
cfg.caCert = cert
}
cfg.tlsConfig = tls.Config{
Certificates: []tls.Certificate{clientCert},
@@ -113,48 +145,97 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// get agent name
name, _, err := cfg.fetchString(ctx, EndpointName)
status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo)
if err != nil {
return err
}
cfg.Name = name
var streamUnsupportedErrs gperr.Builder
if status == http.StatusOK {
if cfg.StreamPort <= 0 {
return fmt.Errorf("invalid agent stream port: %d", cfg.StreamPort)
}
host, _, err := net.SplitHostPort(cfg.Addr)
if err != nil {
return err
}
cfg.streamServerAddr = net.JoinHostPort(host, strconv.Itoa(cfg.StreamPort))
// test stream server connection
const fakeAddress = "localhost:8080" // it won't be used, just for testing
// test TCP stream support
conn, err := agentstream.NewTCPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
if err != nil {
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
} else {
conn.Close()
cfg.isTCPStreamSupported = true
}
// test UDP stream support
conn, err = agentstream.NewUDPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
if err != nil {
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
} else {
conn.Close()
cfg.isUDPStreamSupported = true
}
} else {
// old agent does not support EndpointInfo
// fallback with old logic
cfg.isTCPStreamSupported = false
cfg.isUDPStreamSupported = false
streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling")
// get agent name
name, _, err := cfg.fetchString(ctx, EndpointName)
if err != nil {
return err
}
cfg.Name = name
// check agent version
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
if err != nil {
return err
}
cfg.Version = version.Parse(agentVersion)
// check agent runtime
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
if err != nil {
return err
}
switch status {
case http.StatusOK:
switch runtime {
case "docker":
cfg.Runtime = ContainerRuntimeDocker
// case "nerdctl":
// cfg.Runtime = ContainerRuntimeNerdctl
case "podman":
cfg.Runtime = ContainerRuntimePodman
default:
return fmt.Errorf("invalid agent runtime: %s", runtime)
}
case http.StatusNotFound:
// backward compatibility, old agent does not have runtime endpoint
cfg.Runtime = ContainerRuntimeDocker
default:
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
}
}
cfg.l = log.With().Str("agent", cfg.Name).Logger()
// check agent version
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
if err != nil {
return err
if err := streamUnsupportedErrs.Error(); err != nil {
gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l)
}
// check agent runtime
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
if err != nil {
return err
}
switch status {
case http.StatusOK:
switch runtime {
case "docker":
cfg.Runtime = ContainerRuntimeDocker
// case "nerdctl":
// cfg.Runtime = ContainerRuntimeNerdctl
case "podman":
cfg.Runtime = ContainerRuntimePodman
default:
return fmt.Errorf("invalid agent runtime: %s", runtime)
}
case http.StatusNotFound:
// backward compatibility, old agent does not have runtime endpoint
cfg.Runtime = ContainerRuntimeDocker
default:
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
}
cfg.Version = version.Parse(agentVersion)
if serverVersion.IsNewerThanMajor(cfg.Version) {
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version)
}
@@ -163,6 +244,53 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
return nil
}
func (cfg *AgentConfig) getStreamServerAddr() (string, error) {
if cfg.streamServerAddr == "" {
return "", errors.New("agent stream server address is not initialized")
}
return cfg.streamServerAddr, nil
}
// NewTCPClient creates a new TCP client for the agent.
//
// It returns an error if
// - the agent is not initialized
// - the agent does not support TCP stream tunneling
// - the agent stream server address is not initialized
func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) {
if cfg.caCert == nil || cfg.clientCert == nil {
return nil, errors.New("agent is not initialized")
}
if !cfg.isTCPStreamSupported {
return nil, errors.New("agent does not support TCP stream tunneling")
}
serverAddr, err := cfg.getStreamServerAddr()
if err != nil {
return nil, err
}
return agentstream.NewTCPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
}
// NewUDPClient creates a new UDP client for the agent.
//
// It returns an error if
// - the agent is not initialized
// - the agent does not support UDP stream tunneling
// - the agent stream server address is not initialized
func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) {
if cfg.caCert == nil || cfg.clientCert == nil {
return nil, errors.New("agent is not initialized")
}
if !cfg.isUDPStreamSupported {
return nil, errors.New("agent does not support UDP stream tunneling")
}
serverAddr, err := cfg.getStreamServerAddr()
if err != nil {
return nil, err
}
return agentstream.NewUDPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
}
func (cfg *AgentConfig) Start(ctx context.Context) error {
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
if !ok {

View File

@@ -5,6 +5,7 @@ type (
AgentEnvConfig struct {
Name string
Port int
StreamPort int
CACert string
SSLCert string
ContainerRuntime ContainerRuntime

View File

@@ -87,6 +87,34 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin
return ret, resp.StatusCode, nil
}
// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct
//
// It will return the status code of the response, and error if any.
// If the status code is not http.StatusOK, out will be unchanged but error will still be nil.
func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) {
resp, err := cfg.Do(ctx, "GET", endpoint, nil)
if err != nil {
return 0, err
}
defer resp.Body.Close()
data, release, err := httputils.ReadAllBody(resp)
if err != nil {
return 0, err
}
defer release(data)
if resp.StatusCode != http.StatusOK {
return resp.StatusCode, nil
}
err = sonic.Unmarshal(data, out)
if err != nil {
return 0, err
}
return resp.StatusCode, nil
}
func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) {
transport := cfg.Transport()
dialer := websocket.Dialer{

View File

@@ -0,0 +1,40 @@
# Stream proxy protocol
This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination.
## Header
The on-wire header is a fixed-size binary blob:
- `Version` (8 bytes)
- `Host` (255 bytes, NUL padded)
- `Port` (5 bytes, NUL padded)
- `Checksum` (4 bytes, big-endian CRC32)
Total: `headerSize = 8 + 255 + 5 + 4 = 272` bytes.
Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`.
See [`StreamRequestHeader`](payload.go:26).
## TCP behavior
1. Client establishes a TLS connection to the stream server.
2. Client sends exactly one header as a handshake.
3. After the handshake, both sides proxy raw TCP bytes between client and destination.
Server reads the header using `io.ReadFull` to avoid dropping bytes.
See [`NewTCPClient()`](tcp_client.go:15) and [`(*TCPServer).redirect()`](tcp_server.go:77).
## UDP-over-DTLS behavior
1. Client establishes a DTLS connection to the stream server.
2. Client sends exactly one header as a handshake.
3. After the handshake, both sides proxy raw UDP datagrams:
- client → destination: DTLS payload is written to destination `UDPConn`
- destination → client: destination payload is written back to the DTLS connection
Responses do **not** include a header.
See [`NewUDPClient()`](udp_client.go:17) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:67).

View File

@@ -0,0 +1,57 @@
package stream
import (
"net"
"time"
"github.com/puzpuzpuz/xsync/v4"
"github.com/yusing/goutils/synk"
)
const (
dialTimeout = 10 * time.Second
readDeadline = 10 * time.Second
)
var sizedPool = synk.GetSizedBytesPool()
type CreateConnFunc[Conn net.Conn] func(host, port string) (Conn, error)
type ConnectionManager[Conn net.Conn] struct {
m *xsync.Map[string, Conn]
createConnection CreateConnFunc[Conn]
}
func NewConnectionManager[Conn net.Conn](createConnection CreateConnFunc[Conn]) *ConnectionManager[Conn] {
return &ConnectionManager[Conn]{
m: xsync.NewMap[string, Conn](),
createConnection: createConnection,
}
}
func (c *ConnectionManager[Conn]) GetOrCreateDestConnection(clientConn net.Conn, host, port string) (ret Conn, connErr error) {
clientKey := clientConn.RemoteAddr().String()
ret, _ = c.m.LoadOrCompute(clientKey, func() (conn Conn, cancel bool) {
conn, connErr = c.createConnection(host, port)
if connErr != nil {
cancel = true
}
return
})
return
}
func (c *ConnectionManager[Conn]) DeleteDestConnection(clientConn net.Conn) {
clientKey := clientConn.RemoteAddr().String()
conn, loaded := c.m.LoadAndDelete(clientKey)
if loaded {
conn.Close()
}
}
func (c *ConnectionManager[Conn]) CloseAllConnections() {
for _, conn := range c.m.Range {
conn.Close()
}
c.m.Clear()
}

View File

@@ -0,0 +1,109 @@
package stream
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"unsafe"
)
const (
versionSize = 8
hostSize = 255
portSize = 5
checksumSize = 4 // crc32 checksum
headerSize = versionSize + hostSize + portSize + checksumSize
)
var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0}
var ErrInvalidHeader = errors.New("invalid header")
type StreamRequestHeader struct {
Version [versionSize]byte
Host [hostSize]byte
Port [portSize]byte
Checksum [checksumSize]byte
}
type StreamRequestPayload struct {
StreamRequestHeader
Data []byte
}
func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) {
if len(host) > hostSize {
return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host))
}
if len(port) > portSize {
return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port))
}
header := &StreamRequestHeader{}
copy(header.Version[:], version[:])
copy(header.Host[:], host)
copy(header.Port[:], port)
header.updateChecksum()
return header, nil
}
func ToHeader(buf [headerSize]byte) *StreamRequestHeader {
return (*StreamRequestHeader)(unsafe.Pointer(&buf[0]))
}
// WriteTo implements the io.WriterTo interface.
func (p *StreamRequestPayload) WriteTo(w io.Writer) (n int64, err error) {
n1, err := w.Write(p.StreamRequestHeader.Bytes())
if err != nil {
return
}
if len(p.Data) == 0 {
return int64(n1), nil
}
n2, err := w.Write(p.Data)
if err != nil {
return
}
return int64(n1) + int64(n2), nil
}
func (h *StreamRequestHeader) GetHostPort() (string, string) {
hostEnd := bytes.IndexByte(h.Host[:], 0)
portEnd := bytes.IndexByte(h.Port[:], 0)
if hostEnd == -1 {
hostEnd = hostSize
}
if portEnd == -1 {
portEnd = portSize
}
return string(h.Host[:hostEnd]), string(h.Port[:portEnd])
}
func (h *StreamRequestHeader) Validate() bool {
if h.Version != version {
return false
}
return h.validateChecksum()
}
func (h *StreamRequestHeader) updateChecksum() {
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
binary.BigEndian.PutUint32(h.Checksum[:], checksum)
}
func (h *StreamRequestHeader) validateChecksum() bool {
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
return checksum == binary.BigEndian.Uint32(h.Checksum[:])
}
func (h *StreamRequestHeader) BytesWithoutChecksum() []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize-checksumSize)
}
func (h *StreamRequestHeader) Bytes() []byte {
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize)
}

View File

@@ -0,0 +1,53 @@
package stream
import (
"bytes"
"testing"
)
func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) {
h, err := NewStreamRequestHeader("example.com", "443")
if err != nil {
t.Fatalf("NewStreamRequestHeader: %v", err)
}
if !h.Validate() {
t.Fatalf("expected header to validate")
}
var buf [headerSize]byte
copy(buf[:], h.Bytes())
h2 := ToHeader(buf)
if !h2.Validate() {
t.Fatalf("expected round-tripped header to validate")
}
host, port := h2.GetHostPort()
if host != "example.com" || port != "443" {
t.Fatalf("unexpected host/port: %q:%q", host, port)
}
}
func TestStreamRequestPayload_WriteTo_WritesFullHeader(t *testing.T) {
h, err := NewStreamRequestHeader("127.0.0.1", "53")
if err != nil {
t.Fatalf("NewStreamRequestHeader: %v", err)
}
p := &StreamRequestPayload{StreamRequestHeader: *h, Data: []byte("hello")}
var out bytes.Buffer
n, err := p.WriteTo(&out)
if err != nil {
t.Fatalf("WriteTo: %v", err)
}
if int(n) != headerSize+len(p.Data) {
t.Fatalf("unexpected bytes written: got %d want %d", n, headerSize+len(p.Data))
}
written := out.Bytes()
if len(written) != headerSize+len(p.Data) {
t.Fatalf("unexpected output size: got %d", len(written))
}
if !bytes.Equal(written[:headerSize], h.Bytes()) {
t.Fatalf("expected full header (including checksum) to be written")
}
}

View File

@@ -0,0 +1,235 @@
package stream
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"io"
"math/big"
"net"
"testing"
"time"
)
func newSerial(t *testing.T) *big.Int {
t.Helper()
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
t.Fatalf("rand serial: %v", err)
}
return sn
}
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: newSerial(t),
Subject: pkix.Name{CommonName: "stream-test-ca"},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate(CA): %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate(CA): %v", err)
}
return cert, key
}
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: newSerial(t),
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{eku},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
if err != nil {
t.Fatalf("CreateCertificate(%s): %v", cn, err)
}
leaf, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate(%s): %v", cn, err)
}
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
}
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen tcp: %v", err)
}
done := make(chan struct{})
go func() {
defer close(done)
for {
c, err := ln.Accept()
if err != nil {
return
}
go func(conn net.Conn) {
defer conn.Close()
_, _ = io.Copy(conn, conn)
}(c)
}
}()
return ln.Addr().String(), func() {
_ = ln.Close()
<-done
}
}
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
t.Helper()
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen udp: %v", err)
}
uc := pc.(*net.UDPConn)
done := make(chan struct{})
go func() {
defer close(done)
buf := make([]byte, 65535)
for {
n, raddr, err := uc.ReadFromUDP(buf)
if err != nil {
return
}
_, _ = uc.WriteToUDP(buf[:n], raddr)
}
}()
return uc.LocalAddr().String(), func() {
_ = uc.Close()
<-done
}
}
func TestTCPServer_FullFlow(t *testing.T) {
ca, caKey := genCA(t)
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
dstAddr, closeDst := startTCPEcho(t)
defer closeDst()
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenTCP: %v", err)
}
defer tcpLn.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
_ = <-errCh
}()
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
if err != nil {
t.Fatalf("NewTCPClient: %v", err)
}
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over tcp")
if _, err := client.Write(msg); err != nil {
t.Fatalf("client.Write: %v", err)
}
buf := make([]byte, len(msg))
if _, err := io.ReadFull(client, buf); err != nil {
t.Fatalf("client.ReadFull: %v", err)
}
if string(buf) != string(msg) {
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
}
}
func TestUDPServer_FullFlow(t *testing.T) {
ca, caKey := genCA(t)
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
dstAddr, closeDst := startUDPEcho(t)
defer closeDst()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
err := <-errCh
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
t.Logf("udp server exit: %v", err)
}
}()
deadline := time.Now().Add(2 * time.Second)
for srv.listener == nil {
if time.Now().After(deadline) {
t.Fatalf("udp server listener did not start")
}
time.Sleep(10 * time.Millisecond)
}
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
if err != nil {
t.Fatalf("NewUDPClient: %v", err)
}
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over udp")
if _, err := client.Write(msg); err != nil {
t.Fatalf("client.Write: %v", err)
}
buf := make([]byte, 2048)
n, err := client.Read(buf)
if err != nil {
t.Fatalf("client.Read: %v", err)
}
if string(buf[:n]) != string(msg) {
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
}
}

View File

@@ -0,0 +1,91 @@
package stream
import (
"crypto/tls"
"crypto/x509"
"net"
"time"
)
type TCPClient struct {
conn net.Conn
}
// NewTCPClient creates a new TCP client for the agent.
//
// It will establish a TLS connection and send a stream request header to the server.
//
// It returns an error if
// - the target address is invalid
// - the stream request header is invalid
// - the TLS configuration is invalid
// - the TLS connection fails
// - the stream request header is not sent
func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
host, port, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
header, err := NewStreamRequestHeader(host, port)
if err != nil {
return nil, err
}
// Setup TLS configuration
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*clientCert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
}
// Establish TLS connection
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig)
if err != nil {
return nil, err
}
// Send the stream header once as a handshake.
if _, err := conn.Write(header.Bytes()); err != nil {
_ = conn.Close()
return nil, err
}
return &TCPClient{
conn: conn,
}, nil
}
func (c *TCPClient) Read(p []byte) (n int, err error) {
return c.conn.Read(p)
}
func (c *TCPClient) Write(p []byte) (n int, err error) {
return c.conn.Write(p)
}
func (c *TCPClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *TCPClient) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *TCPClient) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *TCPClient) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *TCPClient) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *TCPClient) Close() error {
return c.conn.Close()
}

View File

@@ -0,0 +1,99 @@
package stream
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
ioutils "github.com/yusing/goutils/io"
)
type TCPServer struct {
ctx context.Context
listener net.Listener
connMgr *ConnectionManager[net.Conn]
}
func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer {
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*serverCert},
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
}
tcpListener := tls.NewListener(listener, tlsConfig)
s := &TCPServer{
ctx: ctx,
listener: tcpListener,
}
s.connMgr = NewConnectionManager(s.createDestConnection)
return s
}
func (s *TCPServer) Start() error {
for {
select {
case <-s.ctx.Done():
return s.ctx.Err()
default:
conn, err := s.listener.Accept()
if err != nil {
return err
}
go s.handle(conn)
}
}
}
func (s *TCPServer) Addr() net.Addr {
return s.listener.Addr()
}
func (s *TCPServer) Close() error {
s.connMgr.CloseAllConnections()
return s.listener.Close()
}
func (s *TCPServer) handle(conn net.Conn) {
defer conn.Close()
dst, err := s.redirect(conn)
if err != nil {
// TODO: log error
return
}
defer s.connMgr.DeleteDestConnection(conn)
pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst)
pipe.Start()
}
func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) {
// Read the stream header once as a handshake.
var headerBuf [headerSize]byte
if _, err := io.ReadFull(conn, headerBuf[:]); err != nil {
return nil, err
}
header := ToHeader(headerBuf)
if !header.Validate() {
return nil, ErrInvalidHeader
}
// get destination connection
host, port := header.GetHostPort()
return s.connMgr.GetOrCreateDestConnection(conn, host, port)
}
func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) {
addr := host + ":" + port
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err != nil {
return nil, err
}
return conn, nil
}

View File

@@ -0,0 +1,99 @@
package stream
import (
"crypto/tls"
"crypto/x509"
"net"
"time"
"github.com/pion/dtls/v3"
)
type UDPClient struct {
conn net.Conn
}
// NewUDPClient creates a new UDP client for the agent.
//
// It will establish a DTLS connection and send a stream request header to the server.
//
// It returns an error if
// - the target address is invalid
// - the stream request header is invalid
// - the DTLS configuration is invalid
// - the DTLS connection fails
// - the stream request header is not sent
func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
host, port, err := net.SplitHostPort(targetAddress)
if err != nil {
return nil, err
}
header, err := NewStreamRequestHeader(host, port)
if err != nil {
return nil, err
}
// Setup DTLS configuration
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
dtlsConfig := &dtls.Config{
Certificates: []tls.Certificate{*clientCert},
RootCAs: caCertPool,
InsecureSkipVerify: false,
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
}
raddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
return nil, err
}
// Establish DTLS connection
conn, err := dtls.Dial("udp", raddr, dtlsConfig)
if err != nil {
return nil, err
}
// Send the stream header once as a handshake.
if _, err := conn.Write(header.Bytes()); err != nil {
_ = conn.Close()
return nil, err
}
return &UDPClient{
conn: conn,
}, nil
}
func (c *UDPClient) Read(p []byte) (n int, err error) {
return c.conn.Read(p)
}
func (c *UDPClient) Write(p []byte) (n int, err error) {
return c.conn.Write(p)
}
func (c *UDPClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *UDPClient) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
func (c *UDPClient) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *UDPClient) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *UDPClient) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func (c *UDPClient) Close() error {
return c.conn.Close()
}

View File

@@ -0,0 +1,164 @@
package stream
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"
"github.com/pion/dtls/v3"
)
type UDPServer struct {
ctx context.Context
laddr *net.UDPAddr
listener net.Listener
dtlsConfig *dtls.Config
connMgr *ConnectionManager[*net.UDPConn]
}
func NewUDPServer(ctx context.Context, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer {
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
dtlsConfig := &dtls.Config{
Certificates: []tls.Certificate{*serverCert},
ClientCAs: caCertPool,
ClientAuth: dtls.RequireAndVerifyClientCert,
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
}
s := &UDPServer{
ctx: ctx,
laddr: laddr,
dtlsConfig: dtlsConfig,
}
s.connMgr = NewConnectionManager(s.createDestConnection)
return s
}
func (s *UDPServer) Start() error {
listener, err := dtls.Listen("udp", s.laddr, s.dtlsConfig)
if err != nil {
return err
}
s.listener = listener
for {
select {
case <-s.ctx.Done():
return s.ctx.Err()
default:
conn, err := s.listener.Accept()
if err != nil {
return err
}
go s.handleDTLSConnection(conn)
}
}
}
func (s *UDPServer) Addr() net.Addr {
if s.listener != nil {
return s.listener.Addr()
}
return s.laddr
}
func (s *UDPServer) Close() error {
s.connMgr.CloseAllConnections()
if s.listener != nil {
return s.listener.Close()
}
return nil
}
func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) {
defer clientConn.Close()
// Read the stream header once as a handshake.
var headerBuf [headerSize]byte
if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil {
// TODO: log error
return
}
header := ToHeader(headerBuf)
if !header.Validate() {
// TODO: log error
return
}
host, port := header.GetHostPort()
dstConn, err := s.connMgr.GetOrCreateDestConnection(clientConn, host, port)
if err != nil {
// TODO: log error
return
}
defer s.connMgr.DeleteDestConnection(clientConn)
go s.forwardFromDestination(dstConn, clientConn)
buf := sizedPool.GetSized(65535)
defer sizedPool.Put(buf)
for {
select {
case <-s.ctx.Done():
return
default:
n, err := clientConn.Read(buf)
if err != nil {
// TODO: log error
return
}
if _, err := dstConn.Write(buf[:n]); err != nil {
// TODO: log error
return
}
}
}
}
func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) {
addr := host + ":" + port
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
dstConn, err := net.DialUDP("udp", nil, udpAddr)
if err != nil {
return nil, err
}
return dstConn, nil
}
func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) {
buffer := sizedPool.GetSized(65535)
defer sizedPool.Put(buffer)
for {
select {
case <-s.ctx.Done():
return
default:
_ = dstConn.SetReadDeadline(time.Now().Add(readDeadline))
n, err := dstConn.Read(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return
}
// TODO: log error
return
}
if _, err := clientConn.Write(buffer[:n]); err != nil {
// TODO: log error
return
}
}
}
}

View File

@@ -1,44 +0,0 @@
services:
agent:
image: "{{.Image}}"
container_name: godoxy-agent
restart: always
network_mode: host # do not change this
environment:
AGENT_NAME: "{{.Name}}"
AGENT_PORT: "{{.Port}}"
AGENT_CA_CERT: "{{.CACert}}"
AGENT_SSL_CERT: "{{.SSLCert}}"
# use agent as a docker socket proxy: [host]:port
# set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375)
LISTEN_ADDR:
POST: false
ALLOW_RESTARTS: false
ALLOW_START: false
ALLOW_STOP: false
AUTH: false
BUILD: false
COMMIT: false
CONFIGS: false
CONTAINERS: false
DISTRIBUTION: false
EVENTS: true
EXEC: false
GRPC: false
IMAGES: false
INFO: false
NETWORKS: false
NODES: false
PING: true
PLUGINS: false
SECRETS: false
SERVICES: false
SESSION: false
SWARM: false
SYSTEM: false
TASKS: false
VERSION: true
VOLUMES: false
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ./data:/app/data

View File

@@ -5,7 +5,9 @@ services:
restart: always
{{ if eq .ContainerRuntime "podman" -}}
ports:
- "{{.Port}}:{{.Port}}"
- "{{.Port}}:{{.Port}}/tcp"
- "{{.StreamPort}}:{{.StreamPort}}/tcp"
- "{{.StreamPort}}:{{.StreamPort}}/udp"
{{ else -}}
network_mode: host # do not change this
{{ end -}}
@@ -22,6 +24,7 @@ services:
{{ end -}}
AGENT_NAME: "{{.Name}}"
AGENT_PORT: "{{.Port}}"
AGENT_STREAM_PORT: "{{.StreamPort}}"
AGENT_CA_CERT: "{{.CACert}}"
AGENT_SSL_CERT: "{{.SSLCert}}"
# use agent as a docker socket proxy: [host]:port

View File

@@ -20,6 +20,7 @@ func DefaultAgentName() string {
var (
AgentName string
AgentPort int
AgentStreamPort int
AgentSkipClientCertCheck bool
AgentCACert string
AgentSSLCert string
@@ -35,6 +36,7 @@ func Load() {
DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock")
AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName())
AgentPort = env.GetEnvInt("AGENT_PORT", 8890)
AgentStreamPort = env.GetEnvInt("AGENT_STREAM_PORT", AgentPort+1)
AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
AgentCACert = env.GetEnvString("AGENT_CA_CERT", "")

View File

@@ -1,9 +1,9 @@
package handler
import (
"fmt"
"net/http"
"github.com/bytedance/sonic"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/yusing/godoxy/agent/pkg/agent"
@@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler {
}
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, version.Get())
})
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, env.AgentName)
})
mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, env.Runtime)
mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) {
agentInfo := agent.AgentInfo{
Version: version.Get(),
Name: env.AgentName,
Runtime: env.Runtime,
StreamPort: env.AgentStreamPort,
}
sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo)
})
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP)

3
go.mod
View File

@@ -172,6 +172,9 @@ require (
github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect
github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/logging v0.2.4 // indirect
github.com/pion/transport/v3 v3.1.1 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/stretchr/objx v0.5.3 // indirect

6
go.sum
View File

@@ -247,6 +247,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=

View File

@@ -110,9 +110,9 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) {
switch rurl.Scheme {
case "tcp":
return stream.NewTCPTCPStream(laddr, rurl.Host)
return stream.NewTCPTCPStream(laddr, rurl.Host, r.GetAgent())
case "udp":
return stream.NewUDPUDPStream(laddr, rurl.Host)
return stream.NewUDPUDPStream(laddr, rurl.Host, r.GetAgent())
}
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/pires/go-proxyproto"
"github.com/rs/zerolog"
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/internal/acl"
"github.com/yusing/godoxy/internal/entrypoint"
nettypes "github.com/yusing/godoxy/internal/net/types"
@@ -17,6 +18,7 @@ type TCPTCPStream struct {
listener net.Listener
laddr *net.TCPAddr
dst *net.TCPAddr
agent *agent.AgentConfig
preDial nettypes.HookFunc
onRead nettypes.HookFunc
@@ -24,7 +26,7 @@ type TCPTCPStream struct {
closed atomic.Bool
}
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
func NewTCPTCPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) {
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
if err != nil {
return nil, err
@@ -33,7 +35,7 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
if err != nil {
return nil, err
}
return &TCPTCPStream{laddr: laddr, dst: dst}, nil
return &TCPTCPStream{laddr: laddr, dst: dst, agent: agentCfg}, nil
}
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
@@ -126,7 +128,15 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
return
}
dstConn, err := net.DialTCP("tcp", nil, s.dst)
var (
dstConn net.Conn
err error
)
if s.agent != nil {
dstConn, err = s.agent.NewTCPClient(s.dst.String())
} else {
dstConn, err = net.DialTCP("tcp", nil, s.dst)
}
if err != nil {
if !s.closed.Load() {
logErr(s, err, "failed to dial destination")
@@ -140,7 +150,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
}
src := conn
dst := net.Conn(dstConn)
dst := dstConn
if s.onRead != nil {
src = &wrapperConn{
Conn: conn,

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/internal/acl"
nettypes "github.com/yusing/godoxy/internal/net/types"
"github.com/yusing/goutils/synk"
@@ -22,6 +23,7 @@ type UDPUDPStream struct {
laddr *net.UDPAddr
dst *net.UDPAddr
agent *agent.AgentConfig
preDial nettypes.HookFunc
onRead nettypes.HookFunc
@@ -35,7 +37,7 @@ type UDPUDPStream struct {
type udpUDPConn struct {
srcAddr *net.UDPAddr
dstConn *net.UDPConn
dstConn net.Conn
listener net.PacketConn
lastUsed atomic.Time
closed atomic.Bool
@@ -51,7 +53,7 @@ const (
var bufPool = synk.GetSizedBytesPool()
func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
func NewUDPUDPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) {
dst, err := net.ResolveUDPAddr("udp", dstAddr)
if err != nil {
return nil, err
@@ -63,6 +65,7 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
return &UDPUDPStream{
laddr: laddr,
dst: dst,
agent: agentCfg,
conns: make(map[string]*udpUDPConn),
}, nil
}
@@ -189,8 +192,16 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
}
}
// Create UDP connection to destination
dstConn, err := net.DialUDP("udp", nil, s.dst)
// Create connection to destination (direct UDP or via agent stream tunnel)
var (
dstConn net.Conn
err error
)
if s.agent != nil {
dstConn, err = s.agent.NewUDPClient(s.dst.String())
} else {
dstConn, err = net.DialUDP("udp", nil, s.dst)
}
if err != nil {
logErr(s, err, "failed to dial dst")
return nil, false
@@ -205,7 +216,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
// Send initial data before starting response handler
if !conn.forwardToDestination(initialData) {
dstConn.Close()
_ = dstConn.Close()
return nil, false
}
@@ -328,6 +339,6 @@ func (conn *udpUDPConn) Close() {
conn.closed.Store(true)
conn.dstConn.Close()
_ = conn.dstConn.Close()
conn.dstConn = nil
}