mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-18 07:13:50 +01:00
feat(agent): agent stream tunneling with TLS and dTLS (UDP); combined agent APIs
- Add `StreamPort` configuration to agent configuration and environment variables - Implement TCP and UDP stream client support in agent package - Update agent verification to test stream connectivity (TCP/UDP) - Add `/info` endpoint to agent HTTP handler for version, name, runtime, and stream port - Remove /version, /name, /runtime APIs, replaced by /info - Update agent compose template to expose stream port for TCP and UDP - Update agent creation API to optionally specify stream port (defaults to port + 1) - Modify `StreamRoute` to pass agent configuration to stream implementations - Update `TCPTCPStream` and `UDPUDPStream` to use agent stream tunneling when agent is configured - Add support for both direct connections and agent-tunneled connections in stream routes This enables agents to handle TCP and UDP route tunneling, expanding the proxy capabilities beyond HTTP-only connections.
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
"github.com/yusing/godoxy/agent/pkg/env"
|
||||
"github.com/yusing/godoxy/agent/pkg/server"
|
||||
"github.com/yusing/godoxy/internal/metrics/systeminfo"
|
||||
socketproxy "github.com/yusing/godoxy/socketproxy/pkg"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httpServer "github.com/yusing/goutils/server"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
"github.com/yusing/goutils/task"
|
||||
@@ -63,6 +66,16 @@ Tips:
|
||||
|
||||
server.StartAgentServer(t, opts)
|
||||
|
||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentStreamPort})
|
||||
if err != nil {
|
||||
gperr.LogFatal("failed to listen on port", err)
|
||||
}
|
||||
tcpServer := stream.NewTCPServer(t.Context(), tcpListener, caCert.Leaf, srvCert)
|
||||
go tcpServer.Start()
|
||||
|
||||
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentStreamPort}, caCert.Leaf, srvCert)
|
||||
go udpServer.Start()
|
||||
|
||||
if socketproxy.ListenAddr != "" {
|
||||
runtime := strutils.Title(string(env.Runtime))
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ require (
|
||||
github.com/bytedance/sonic v1.14.2
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/pion/dtls/v3 v3.0.9
|
||||
github.com/puzpuzpuz/xsync/v4 v4.2.0
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
@@ -87,6 +88,8 @@ require (
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/oschwald/maxminddb-golang v1.13.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pion/logging v0.2.4 // indirect
|
||||
github.com/pion/transport/v3 v3.1.1 // indirect
|
||||
github.com/pires/go-proxyproto v0.8.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
|
||||
@@ -217,6 +217,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
var (
|
||||
installScript = `AGENT_NAME="{{.Name}}" \
|
||||
AGENT_PORT="{{.Port}}" \
|
||||
AGENT_STREAM_PORT="{{.StreamPort}}" \
|
||||
AGENT_CA_CERT="{{.CACert}}" \
|
||||
AGENT_SSL_CERT="{{.SSLCert}}" \
|
||||
{{ if eq .ContainerRuntime "nerdctl" -}}
|
||||
|
||||
@@ -4,38 +4,61 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
"github.com/yusing/godoxy/agent/pkg/certs"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
"github.com/yusing/goutils/version"
|
||||
)
|
||||
|
||||
type AgentConfig struct {
|
||||
Addr string `json:"addr"`
|
||||
Name string `json:"name"`
|
||||
Version version.Version `json:"version" swaggertype:"string"`
|
||||
Runtime ContainerRuntime `json:"runtime"`
|
||||
AgentInfo
|
||||
|
||||
Addr string `json:"addr"`
|
||||
|
||||
httpClient *http.Client
|
||||
fasthttpClientHealthCheck *fasthttp.Client
|
||||
tlsConfig tls.Config
|
||||
l zerolog.Logger
|
||||
|
||||
// for stream
|
||||
caCert *x509.Certificate
|
||||
clientCert *tls.Certificate
|
||||
isTCPStreamSupported bool
|
||||
isUDPStreamSupported bool
|
||||
streamServerAddr string
|
||||
|
||||
l zerolog.Logger
|
||||
} // @name Agent
|
||||
|
||||
type AgentInfo struct {
|
||||
Version version.Version `json:"version" swaggertype:"string"`
|
||||
Name string `json:"name"`
|
||||
Runtime ContainerRuntime `json:"runtime"`
|
||||
StreamPort int `json:"stream_port"`
|
||||
}
|
||||
|
||||
// Deprecated. Replaced by EndpointInfo
|
||||
const (
|
||||
EndpointVersion = "/version"
|
||||
EndpointName = "/name"
|
||||
EndpointRuntime = "/runtime"
|
||||
EndpointVersion = "/version"
|
||||
EndpointName = "/name"
|
||||
EndpointRuntime = "/runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
EndpointInfo = "/info"
|
||||
EndpointProxyHTTP = "/proxy/http"
|
||||
EndpointHealth = "/health"
|
||||
EndpointLogs = "/logs"
|
||||
@@ -90,6 +113,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.clientCert = &clientCert
|
||||
|
||||
// create tls config
|
||||
caCertPool := x509.NewCertPool()
|
||||
@@ -97,6 +121,14 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||
if !ok {
|
||||
return errors.New("invalid ca certificate")
|
||||
}
|
||||
// Keep the CA leaf for stream client dialing.
|
||||
if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" {
|
||||
return errors.New("invalid ca certificate")
|
||||
} else if cert, err := x509.ParseCertificate(block.Bytes); err != nil {
|
||||
return err
|
||||
} else {
|
||||
cfg.caCert = cert
|
||||
}
|
||||
|
||||
cfg.tlsConfig = tls.Config{
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
@@ -113,48 +145,97 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// get agent name
|
||||
name, _, err := cfg.fetchString(ctx, EndpointName)
|
||||
status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Name = name
|
||||
var streamUnsupportedErrs gperr.Builder
|
||||
|
||||
if status == http.StatusOK {
|
||||
if cfg.StreamPort <= 0 {
|
||||
return fmt.Errorf("invalid agent stream port: %d", cfg.StreamPort)
|
||||
}
|
||||
host, _, err := net.SplitHostPort(cfg.Addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.streamServerAddr = net.JoinHostPort(host, strconv.Itoa(cfg.StreamPort))
|
||||
|
||||
// test stream server connection
|
||||
const fakeAddress = "localhost:8080" // it won't be used, just for testing
|
||||
// test TCP stream support
|
||||
conn, err := agentstream.NewTCPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
|
||||
if err != nil {
|
||||
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
|
||||
} else {
|
||||
conn.Close()
|
||||
cfg.isTCPStreamSupported = true
|
||||
}
|
||||
|
||||
// test UDP stream support
|
||||
conn, err = agentstream.NewUDPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
|
||||
if err != nil {
|
||||
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
|
||||
} else {
|
||||
conn.Close()
|
||||
cfg.isUDPStreamSupported = true
|
||||
}
|
||||
} else {
|
||||
// old agent does not support EndpointInfo
|
||||
// fallback with old logic
|
||||
cfg.isTCPStreamSupported = false
|
||||
cfg.isUDPStreamSupported = false
|
||||
streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling")
|
||||
|
||||
// get agent name
|
||||
name, _, err := cfg.fetchString(ctx, EndpointName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Name = name
|
||||
|
||||
// check agent version
|
||||
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Version = version.Parse(agentVersion)
|
||||
|
||||
// check agent runtime
|
||||
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch status {
|
||||
case http.StatusOK:
|
||||
switch runtime {
|
||||
case "docker":
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
// case "nerdctl":
|
||||
// cfg.Runtime = ContainerRuntimeNerdctl
|
||||
case "podman":
|
||||
cfg.Runtime = ContainerRuntimePodman
|
||||
default:
|
||||
return fmt.Errorf("invalid agent runtime: %s", runtime)
|
||||
}
|
||||
case http.StatusNotFound:
|
||||
// backward compatibility, old agent does not have runtime endpoint
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
default:
|
||||
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
|
||||
}
|
||||
}
|
||||
|
||||
cfg.l = log.With().Str("agent", cfg.Name).Logger()
|
||||
|
||||
// check agent version
|
||||
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := streamUnsupportedErrs.Error(); err != nil {
|
||||
gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l)
|
||||
}
|
||||
|
||||
// check agent runtime
|
||||
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch status {
|
||||
case http.StatusOK:
|
||||
switch runtime {
|
||||
case "docker":
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
// case "nerdctl":
|
||||
// cfg.Runtime = ContainerRuntimeNerdctl
|
||||
case "podman":
|
||||
cfg.Runtime = ContainerRuntimePodman
|
||||
default:
|
||||
return fmt.Errorf("invalid agent runtime: %s", runtime)
|
||||
}
|
||||
case http.StatusNotFound:
|
||||
// backward compatibility, old agent does not have runtime endpoint
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
default:
|
||||
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
|
||||
}
|
||||
|
||||
cfg.Version = version.Parse(agentVersion)
|
||||
|
||||
if serverVersion.IsNewerThanMajor(cfg.Version) {
|
||||
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version)
|
||||
}
|
||||
@@ -163,6 +244,53 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) getStreamServerAddr() (string, error) {
|
||||
if cfg.streamServerAddr == "" {
|
||||
return "", errors.New("agent stream server address is not initialized")
|
||||
}
|
||||
return cfg.streamServerAddr, nil
|
||||
}
|
||||
|
||||
// NewTCPClient creates a new TCP client for the agent.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the agent is not initialized
|
||||
// - the agent does not support TCP stream tunneling
|
||||
// - the agent stream server address is not initialized
|
||||
func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) {
|
||||
if cfg.caCert == nil || cfg.clientCert == nil {
|
||||
return nil, errors.New("agent is not initialized")
|
||||
}
|
||||
if !cfg.isTCPStreamSupported {
|
||||
return nil, errors.New("agent does not support TCP stream tunneling")
|
||||
}
|
||||
serverAddr, err := cfg.getStreamServerAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agentstream.NewTCPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
|
||||
}
|
||||
|
||||
// NewUDPClient creates a new UDP client for the agent.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the agent is not initialized
|
||||
// - the agent does not support UDP stream tunneling
|
||||
// - the agent stream server address is not initialized
|
||||
func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) {
|
||||
if cfg.caCert == nil || cfg.clientCert == nil {
|
||||
return nil, errors.New("agent is not initialized")
|
||||
}
|
||||
if !cfg.isUDPStreamSupported {
|
||||
return nil, errors.New("agent does not support UDP stream tunneling")
|
||||
}
|
||||
serverAddr, err := cfg.getStreamServerAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agentstream.NewUDPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) Start(ctx context.Context) error {
|
||||
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
|
||||
if !ok {
|
||||
|
||||
@@ -5,6 +5,7 @@ type (
|
||||
AgentEnvConfig struct {
|
||||
Name string
|
||||
Port int
|
||||
StreamPort int
|
||||
CACert string
|
||||
SSLCert string
|
||||
ContainerRuntime ContainerRuntime
|
||||
|
||||
@@ -87,6 +87,34 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin
|
||||
return ret, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct
|
||||
//
|
||||
// It will return the status code of the response, and error if any.
|
||||
// If the status code is not http.StatusOK, out will be unchanged but error will still be nil.
|
||||
func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) {
|
||||
resp, err := cfg.Do(ctx, "GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, release, err := httputils.ReadAllBody(resp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer release(data)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
err = sonic.Unmarshal(data, out)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) {
|
||||
transport := cfg.Transport()
|
||||
dialer := websocket.Dialer{
|
||||
|
||||
40
agent/pkg/agent/stream/PROTOCOL.md
Normal file
40
agent/pkg/agent/stream/PROTOCOL.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Stream proxy protocol
|
||||
|
||||
This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination.
|
||||
|
||||
## Header
|
||||
|
||||
The on-wire header is a fixed-size binary blob:
|
||||
|
||||
- `Version` (8 bytes)
|
||||
- `Host` (255 bytes, NUL padded)
|
||||
- `Port` (5 bytes, NUL padded)
|
||||
- `Checksum` (4 bytes, big-endian CRC32)
|
||||
|
||||
Total: `headerSize = 8 + 255 + 5 + 4 = 272` bytes.
|
||||
|
||||
Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`.
|
||||
|
||||
See [`StreamRequestHeader`](payload.go:26).
|
||||
|
||||
## TCP behavior
|
||||
|
||||
1. Client establishes a TLS connection to the stream server.
|
||||
2. Client sends exactly one header as a handshake.
|
||||
3. After the handshake, both sides proxy raw TCP bytes between client and destination.
|
||||
|
||||
Server reads the header using `io.ReadFull` to avoid dropping bytes.
|
||||
|
||||
See [`NewTCPClient()`](tcp_client.go:15) and [`(*TCPServer).redirect()`](tcp_server.go:77).
|
||||
|
||||
## UDP-over-DTLS behavior
|
||||
|
||||
1. Client establishes a DTLS connection to the stream server.
|
||||
2. Client sends exactly one header as a handshake.
|
||||
3. After the handshake, both sides proxy raw UDP datagrams:
|
||||
- client → destination: DTLS payload is written to destination `UDPConn`
|
||||
- destination → client: destination payload is written back to the DTLS connection
|
||||
|
||||
Responses do **not** include a header.
|
||||
|
||||
See [`NewUDPClient()`](udp_client.go:17) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:67).
|
||||
57
agent/pkg/agent/stream/common.go
Normal file
57
agent/pkg/agent/stream/common.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/yusing/goutils/synk"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 10 * time.Second
|
||||
readDeadline = 10 * time.Second
|
||||
)
|
||||
|
||||
var sizedPool = synk.GetSizedBytesPool()
|
||||
|
||||
type CreateConnFunc[Conn net.Conn] func(host, port string) (Conn, error)
|
||||
type ConnectionManager[Conn net.Conn] struct {
|
||||
m *xsync.Map[string, Conn]
|
||||
createConnection CreateConnFunc[Conn]
|
||||
}
|
||||
|
||||
func NewConnectionManager[Conn net.Conn](createConnection CreateConnFunc[Conn]) *ConnectionManager[Conn] {
|
||||
return &ConnectionManager[Conn]{
|
||||
m: xsync.NewMap[string, Conn](),
|
||||
createConnection: createConnection,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectionManager[Conn]) GetOrCreateDestConnection(clientConn net.Conn, host, port string) (ret Conn, connErr error) {
|
||||
clientKey := clientConn.RemoteAddr().String()
|
||||
ret, _ = c.m.LoadOrCompute(clientKey, func() (conn Conn, cancel bool) {
|
||||
conn, connErr = c.createConnection(host, port)
|
||||
if connErr != nil {
|
||||
cancel = true
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ConnectionManager[Conn]) DeleteDestConnection(clientConn net.Conn) {
|
||||
clientKey := clientConn.RemoteAddr().String()
|
||||
conn, loaded := c.m.LoadAndDelete(clientKey)
|
||||
if loaded {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectionManager[Conn]) CloseAllConnections() {
|
||||
for _, conn := range c.m.Range {
|
||||
conn.Close()
|
||||
}
|
||||
c.m.Clear()
|
||||
}
|
||||
109
agent/pkg/agent/stream/payload.go
Normal file
109
agent/pkg/agent/stream/payload.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
versionSize = 8
|
||||
hostSize = 255
|
||||
portSize = 5
|
||||
checksumSize = 4 // crc32 checksum
|
||||
|
||||
headerSize = versionSize + hostSize + portSize + checksumSize
|
||||
)
|
||||
|
||||
var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0}
|
||||
|
||||
var ErrInvalidHeader = errors.New("invalid header")
|
||||
|
||||
type StreamRequestHeader struct {
|
||||
Version [versionSize]byte
|
||||
Host [hostSize]byte
|
||||
Port [portSize]byte
|
||||
Checksum [checksumSize]byte
|
||||
}
|
||||
|
||||
type StreamRequestPayload struct {
|
||||
StreamRequestHeader
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) {
|
||||
if len(host) > hostSize {
|
||||
return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host))
|
||||
}
|
||||
if len(port) > portSize {
|
||||
return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port))
|
||||
}
|
||||
header := &StreamRequestHeader{}
|
||||
copy(header.Version[:], version[:])
|
||||
copy(header.Host[:], host)
|
||||
copy(header.Port[:], port)
|
||||
header.updateChecksum()
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func ToHeader(buf [headerSize]byte) *StreamRequestHeader {
|
||||
return (*StreamRequestHeader)(unsafe.Pointer(&buf[0]))
|
||||
}
|
||||
|
||||
// WriteTo implements the io.WriterTo interface.
|
||||
func (p *StreamRequestPayload) WriteTo(w io.Writer) (n int64, err error) {
|
||||
n1, err := w.Write(p.StreamRequestHeader.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(p.Data) == 0 {
|
||||
return int64(n1), nil
|
||||
}
|
||||
|
||||
n2, err := w.Write(p.Data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return int64(n1) + int64(n2), nil
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) GetHostPort() (string, string) {
|
||||
hostEnd := bytes.IndexByte(h.Host[:], 0)
|
||||
portEnd := bytes.IndexByte(h.Port[:], 0)
|
||||
if hostEnd == -1 {
|
||||
hostEnd = hostSize
|
||||
}
|
||||
if portEnd == -1 {
|
||||
portEnd = portSize
|
||||
}
|
||||
return string(h.Host[:hostEnd]), string(h.Port[:portEnd])
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) Validate() bool {
|
||||
if h.Version != version {
|
||||
return false
|
||||
}
|
||||
return h.validateChecksum()
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) updateChecksum() {
|
||||
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
|
||||
binary.BigEndian.PutUint32(h.Checksum[:], checksum)
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) validateChecksum() bool {
|
||||
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
|
||||
return checksum == binary.BigEndian.Uint32(h.Checksum[:])
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) BytesWithoutChecksum() []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize-checksumSize)
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) Bytes() []byte {
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize)
|
||||
}
|
||||
53
agent/pkg/agent/stream/payload_test.go
Normal file
53
agent/pkg/agent/stream/payload_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) {
|
||||
h, err := NewStreamRequestHeader("example.com", "443")
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamRequestHeader: %v", err)
|
||||
}
|
||||
if !h.Validate() {
|
||||
t.Fatalf("expected header to validate")
|
||||
}
|
||||
|
||||
var buf [headerSize]byte
|
||||
copy(buf[:], h.Bytes())
|
||||
h2 := ToHeader(buf)
|
||||
if !h2.Validate() {
|
||||
t.Fatalf("expected round-tripped header to validate")
|
||||
}
|
||||
host, port := h2.GetHostPort()
|
||||
if host != "example.com" || port != "443" {
|
||||
t.Fatalf("unexpected host/port: %q:%q", host, port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamRequestPayload_WriteTo_WritesFullHeader(t *testing.T) {
|
||||
h, err := NewStreamRequestHeader("127.0.0.1", "53")
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamRequestHeader: %v", err)
|
||||
}
|
||||
|
||||
p := &StreamRequestPayload{StreamRequestHeader: *h, Data: []byte("hello")}
|
||||
|
||||
var out bytes.Buffer
|
||||
n, err := p.WriteTo(&out)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteTo: %v", err)
|
||||
}
|
||||
if int(n) != headerSize+len(p.Data) {
|
||||
t.Fatalf("unexpected bytes written: got %d want %d", n, headerSize+len(p.Data))
|
||||
}
|
||||
|
||||
written := out.Bytes()
|
||||
if len(written) != headerSize+len(p.Data) {
|
||||
t.Fatalf("unexpected output size: got %d", len(written))
|
||||
}
|
||||
if !bytes.Equal(written[:headerSize], h.Bytes()) {
|
||||
t.Fatalf("expected full header (including checksum) to be written")
|
||||
}
|
||||
}
|
||||
235
agent/pkg/agent/stream/server_flow_test.go
Normal file
235
agent/pkg/agent/stream/server_flow_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newSerial(t *testing.T) *big.Int {
|
||||
t.Helper()
|
||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
t.Fatalf("rand serial: %v", err)
|
||||
}
|
||||
return sn
|
||||
}
|
||||
|
||||
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: newSerial(t),
|
||||
Subject: pkix.Name{CommonName: "stream-test-ca"},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate(CA): %v", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificate(CA): %v", err)
|
||||
}
|
||||
return cert, key
|
||||
}
|
||||
|
||||
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: newSerial(t),
|
||||
Subject: pkix.Name{CommonName: cn},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{eku},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate(%s): %v", cn, err)
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificate(%s): %v", cn, err)
|
||||
}
|
||||
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
|
||||
}
|
||||
|
||||
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen tcp: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
t.Helper()
|
||||
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen udp: %v", err)
|
||||
}
|
||||
uc := pc.(*net.UDPConn)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, raddr, err := uc.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = uc.WriteToUDP(buf[:n], raddr)
|
||||
}
|
||||
}()
|
||||
|
||||
return uc.LocalAddr().String(), func() {
|
||||
_ = uc.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPServer_FullFlow(t *testing.T) {
|
||||
ca, caKey := genCA(t)
|
||||
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
||||
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
||||
|
||||
dstAddr, closeDst := startTCPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTCP: %v", err)
|
||||
}
|
||||
defer tcpLn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
defer func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
_ = <-errCh
|
||||
}()
|
||||
|
||||
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPClient: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over tcp")
|
||||
if _, err := client.Write(msg); err != nil {
|
||||
t.Fatalf("client.Write: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, len(msg))
|
||||
if _, err := io.ReadFull(client, buf); err != nil {
|
||||
t.Fatalf("client.ReadFull: %v", err)
|
||||
}
|
||||
if string(buf) != string(msg) {
|
||||
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPServer_FullFlow(t *testing.T) {
|
||||
ca, caKey := genCA(t)
|
||||
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
||||
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
defer func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
|
||||
t.Logf("udp server exit: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for srv.listener == nil {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("udp server listener did not start")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
||||
if err != nil {
|
||||
t.Fatalf("NewUDPClient: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over udp")
|
||||
if _, err := client.Write(msg); err != nil {
|
||||
t.Fatalf("client.Write: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := client.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("client.Read: %v", err)
|
||||
}
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
|
||||
}
|
||||
}
|
||||
91
agent/pkg/agent/stream/tcp_client.go
Normal file
91
agent/pkg/agent/stream/tcp_client.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TCPClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// NewTCPClient creates a new TCP client for the agent.
|
||||
//
|
||||
// It will establish a TLS connection and send a stream request header to the server.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the target address is invalid
|
||||
// - the stream request header is invalid
|
||||
// - the TLS configuration is invalid
|
||||
// - the TLS connection fails
|
||||
// - the stream request header is not sent
|
||||
func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(targetAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := NewStreamRequestHeader(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup TLS configuration
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*clientCert},
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// Establish TLS connection
|
||||
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send the stream header once as a handshake.
|
||||
if _, err := conn.Write(header.Bytes()); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TCPClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *TCPClient) Read(p []byte) (n int, err error) {
|
||||
return c.conn.Read(p)
|
||||
}
|
||||
|
||||
func (c *TCPClient) Write(p []byte) (n int, err error) {
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *TCPClient) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *TCPClient) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *TCPClient) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *TCPClient) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *TCPClient) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *TCPClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
99
agent/pkg/agent/stream/tcp_server.go
Normal file
99
agent/pkg/agent/stream/tcp_server.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
|
||||
type TCPServer struct {
|
||||
ctx context.Context
|
||||
listener net.Listener
|
||||
connMgr *ConnectionManager[net.Conn]
|
||||
}
|
||||
|
||||
func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*serverCert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
tcpListener := tls.NewListener(listener, tlsConfig)
|
||||
s := &TCPServer{
|
||||
ctx: ctx,
|
||||
listener: tcpListener,
|
||||
}
|
||||
s.connMgr = NewConnectionManager(s.createDestConnection)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *TCPServer) Start() error {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handle(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TCPServer) Addr() net.Addr {
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
func (s *TCPServer) Close() error {
|
||||
s.connMgr.CloseAllConnections()
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *TCPServer) handle(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
dst, err := s.redirect(conn)
|
||||
if err != nil {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
defer s.connMgr.DeleteDestConnection(conn)
|
||||
pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst)
|
||||
pipe.Start()
|
||||
}
|
||||
|
||||
func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) {
|
||||
// Read the stream header once as a handshake.
|
||||
var headerBuf [headerSize]byte
|
||||
if _, err := io.ReadFull(conn, headerBuf[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header := ToHeader(headerBuf)
|
||||
if !header.Validate() {
|
||||
return nil, ErrInvalidHeader
|
||||
}
|
||||
|
||||
// get destination connection
|
||||
host, port := header.GetHostPort()
|
||||
return s.connMgr.GetOrCreateDestConnection(conn, host, port)
|
||||
}
|
||||
|
||||
func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) {
|
||||
addr := host + ":" + port
|
||||
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
99
agent/pkg/agent/stream/udp_client.go
Normal file
99
agent/pkg/agent/stream/udp_client.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
)
|
||||
|
||||
type UDPClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// NewUDPClient creates a new UDP client for the agent.
|
||||
//
|
||||
// It will establish a DTLS connection and send a stream request header to the server.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the target address is invalid
|
||||
// - the stream request header is invalid
|
||||
// - the DTLS configuration is invalid
|
||||
// - the DTLS connection fails
|
||||
// - the stream request header is not sent
|
||||
func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(targetAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := NewStreamRequestHeader(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Setup DTLS configuration
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
dtlsConfig := &dtls.Config{
|
||||
Certificates: []tls.Certificate{*clientCert},
|
||||
RootCAs: caCertPool,
|
||||
InsecureSkipVerify: false,
|
||||
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||
}
|
||||
|
||||
raddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Establish DTLS connection
|
||||
conn, err := dtls.Dial("udp", raddr, dtlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send the stream header once as a handshake.
|
||||
if _, err := conn.Write(header.Bytes()); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UDPClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *UDPClient) Read(p []byte) (n int, err error) {
|
||||
return c.conn.Read(p)
|
||||
}
|
||||
|
||||
func (c *UDPClient) Write(p []byte) (n int, err error) {
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *UDPClient) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *UDPClient) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *UDPClient) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *UDPClient) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *UDPClient) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *UDPClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
164
agent/pkg/agent/stream/udp_server.go
Normal file
164
agent/pkg/agent/stream/udp_server.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
)
|
||||
|
||||
type UDPServer struct {
|
||||
ctx context.Context
|
||||
laddr *net.UDPAddr
|
||||
listener net.Listener
|
||||
|
||||
dtlsConfig *dtls.Config
|
||||
connMgr *ConnectionManager[*net.UDPConn]
|
||||
}
|
||||
|
||||
func NewUDPServer(ctx context.Context, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
dtlsConfig := &dtls.Config{
|
||||
Certificates: []tls.Certificate{*serverCert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: dtls.RequireAndVerifyClientCert,
|
||||
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||
}
|
||||
|
||||
s := &UDPServer{
|
||||
ctx: ctx,
|
||||
laddr: laddr,
|
||||
dtlsConfig: dtlsConfig,
|
||||
}
|
||||
s.connMgr = NewConnectionManager(s.createDestConnection)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *UDPServer) Start() error {
|
||||
listener, err := dtls.Listen("udp", s.laddr, s.dtlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.handleDTLSConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPServer) Addr() net.Addr {
|
||||
if s.listener != nil {
|
||||
return s.listener.Addr()
|
||||
}
|
||||
return s.laddr
|
||||
}
|
||||
|
||||
func (s *UDPServer) Close() error {
|
||||
s.connMgr.CloseAllConnections()
|
||||
if s.listener != nil {
|
||||
return s.listener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) {
|
||||
defer clientConn.Close()
|
||||
|
||||
// Read the stream header once as a handshake.
|
||||
var headerBuf [headerSize]byte
|
||||
if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
header := ToHeader(headerBuf)
|
||||
if !header.Validate() {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
|
||||
host, port := header.GetHostPort()
|
||||
dstConn, err := s.connMgr.GetOrCreateDestConnection(clientConn, host, port)
|
||||
if err != nil {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
defer s.connMgr.DeleteDestConnection(clientConn)
|
||||
|
||||
go s.forwardFromDestination(dstConn, clientConn)
|
||||
|
||||
buf := sizedPool.GetSized(65535)
|
||||
defer sizedPool.Put(buf)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := clientConn.Read(buf)
|
||||
if err != nil {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
if _, err := dstConn.Write(buf[:n]); err != nil {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) {
|
||||
addr := host + ":" + port
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dstConn, err := net.DialUDP("udp", nil, udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dstConn, nil
|
||||
}
|
||||
|
||||
func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) {
|
||||
buffer := sizedPool.GetSized(65535)
|
||||
defer sizedPool.Put(buffer)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_ = dstConn.SetReadDeadline(time.Now().Add(readDeadline))
|
||||
n, err := dstConn.Read(buffer)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return
|
||||
}
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
if _, err := clientConn.Write(buffer[:n]); err != nil {
|
||||
// TODO: log error
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
services:
|
||||
agent:
|
||||
image: "{{.Image}}"
|
||||
container_name: godoxy-agent
|
||||
restart: always
|
||||
network_mode: host # do not change this
|
||||
environment:
|
||||
AGENT_NAME: "{{.Name}}"
|
||||
AGENT_PORT: "{{.Port}}"
|
||||
AGENT_CA_CERT: "{{.CACert}}"
|
||||
AGENT_SSL_CERT: "{{.SSLCert}}"
|
||||
# use agent as a docker socket proxy: [host]:port
|
||||
# set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375)
|
||||
LISTEN_ADDR:
|
||||
POST: false
|
||||
ALLOW_RESTARTS: false
|
||||
ALLOW_START: false
|
||||
ALLOW_STOP: false
|
||||
AUTH: false
|
||||
BUILD: false
|
||||
COMMIT: false
|
||||
CONFIGS: false
|
||||
CONTAINERS: false
|
||||
DISTRIBUTION: false
|
||||
EVENTS: true
|
||||
EXEC: false
|
||||
GRPC: false
|
||||
IMAGES: false
|
||||
INFO: false
|
||||
NETWORKS: false
|
||||
NODES: false
|
||||
PING: true
|
||||
PLUGINS: false
|
||||
SECRETS: false
|
||||
SERVICES: false
|
||||
SESSION: false
|
||||
SWARM: false
|
||||
SYSTEM: false
|
||||
TASKS: false
|
||||
VERSION: true
|
||||
VOLUMES: false
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./data:/app/data
|
||||
@@ -5,7 +5,9 @@ services:
|
||||
restart: always
|
||||
{{ if eq .ContainerRuntime "podman" -}}
|
||||
ports:
|
||||
- "{{.Port}}:{{.Port}}"
|
||||
- "{{.Port}}:{{.Port}}/tcp"
|
||||
- "{{.StreamPort}}:{{.StreamPort}}/tcp"
|
||||
- "{{.StreamPort}}:{{.StreamPort}}/udp"
|
||||
{{ else -}}
|
||||
network_mode: host # do not change this
|
||||
{{ end -}}
|
||||
@@ -22,6 +24,7 @@ services:
|
||||
{{ end -}}
|
||||
AGENT_NAME: "{{.Name}}"
|
||||
AGENT_PORT: "{{.Port}}"
|
||||
AGENT_STREAM_PORT: "{{.StreamPort}}"
|
||||
AGENT_CA_CERT: "{{.CACert}}"
|
||||
AGENT_SSL_CERT: "{{.SSLCert}}"
|
||||
# use agent as a docker socket proxy: [host]:port
|
||||
|
||||
2
agent/pkg/env/env.go
vendored
2
agent/pkg/env/env.go
vendored
@@ -20,6 +20,7 @@ func DefaultAgentName() string {
|
||||
var (
|
||||
AgentName string
|
||||
AgentPort int
|
||||
AgentStreamPort int
|
||||
AgentSkipClientCertCheck bool
|
||||
AgentCACert string
|
||||
AgentSSLCert string
|
||||
@@ -35,6 +36,7 @@ func Load() {
|
||||
DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock")
|
||||
AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName())
|
||||
AgentPort = env.GetEnvInt("AGENT_PORT", 8890)
|
||||
AgentStreamPort = env.GetEnvInt("AGENT_STREAM_PORT", AgentPort+1)
|
||||
AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
|
||||
|
||||
AgentCACert = env.GetEnvString("AGENT_CA_CERT", "")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
@@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler {
|
||||
}
|
||||
|
||||
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
|
||||
mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, version.Get())
|
||||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, env.AgentName)
|
||||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, env.Runtime)
|
||||
mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) {
|
||||
agentInfo := agent.AgentInfo{
|
||||
Version: version.Get(),
|
||||
Name: env.AgentName,
|
||||
Runtime: env.Runtime,
|
||||
StreamPort: env.AgentStreamPort,
|
||||
}
|
||||
sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo)
|
||||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
||||
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP)
|
||||
|
||||
3
go.mod
3
go.mod
@@ -172,6 +172,9 @@ require (
|
||||
github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect
|
||||
github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/logging v0.2.4 // indirect
|
||||
github.com/pion/transport/v3 v3.1.1 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/stretchr/objx v0.5.3 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -247,6 +247,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
|
||||
@@ -110,9 +110,9 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
||||
|
||||
switch rurl.Scheme {
|
||||
case "tcp":
|
||||
return stream.NewTCPTCPStream(laddr, rurl.Host)
|
||||
return stream.NewTCPTCPStream(laddr, rurl.Host, r.GetAgent())
|
||||
case "udp":
|
||||
return stream.NewUDPUDPStream(laddr, rurl.Host)
|
||||
return stream.NewUDPUDPStream(laddr, rurl.Host, r.GetAgent())
|
||||
}
|
||||
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/internal/acl"
|
||||
"github.com/yusing/godoxy/internal/entrypoint"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
@@ -17,6 +18,7 @@ type TCPTCPStream struct {
|
||||
listener net.Listener
|
||||
laddr *net.TCPAddr
|
||||
dst *net.TCPAddr
|
||||
agent *agent.AgentConfig
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
@@ -24,7 +26,7 @@ type TCPTCPStream struct {
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
func NewTCPTCPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -33,7 +35,7 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPTCPStream{laddr: laddr, dst: dst}, nil
|
||||
return &TCPTCPStream{laddr: laddr, dst: dst, agent: agentCfg}, nil
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
@@ -126,7 +128,15 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
dstConn, err := net.DialTCP("tcp", nil, s.dst)
|
||||
var (
|
||||
dstConn net.Conn
|
||||
err error
|
||||
)
|
||||
if s.agent != nil {
|
||||
dstConn, err = s.agent.NewTCPClient(s.dst.String())
|
||||
} else {
|
||||
dstConn, err = net.DialTCP("tcp", nil, s.dst)
|
||||
}
|
||||
if err != nil {
|
||||
if !s.closed.Load() {
|
||||
logErr(s, err, "failed to dial destination")
|
||||
@@ -140,7 +150,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
src := conn
|
||||
dst := net.Conn(dstConn)
|
||||
dst := dstConn
|
||||
if s.onRead != nil {
|
||||
src = &wrapperConn{
|
||||
Conn: conn,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/internal/acl"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/goutils/synk"
|
||||
@@ -22,6 +23,7 @@ type UDPUDPStream struct {
|
||||
|
||||
laddr *net.UDPAddr
|
||||
dst *net.UDPAddr
|
||||
agent *agent.AgentConfig
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
@@ -35,7 +37,7 @@ type UDPUDPStream struct {
|
||||
|
||||
type udpUDPConn struct {
|
||||
srcAddr *net.UDPAddr
|
||||
dstConn *net.UDPConn
|
||||
dstConn net.Conn
|
||||
listener net.PacketConn
|
||||
lastUsed atomic.Time
|
||||
closed atomic.Bool
|
||||
@@ -51,7 +53,7 @@ const (
|
||||
|
||||
var bufPool = synk.GetSizedBytesPool()
|
||||
|
||||
func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
func NewUDPUDPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveUDPAddr("udp", dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -63,6 +65,7 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
return &UDPUDPStream{
|
||||
laddr: laddr,
|
||||
dst: dst,
|
||||
agent: agentCfg,
|
||||
conns: make(map[string]*udpUDPConn),
|
||||
}, nil
|
||||
}
|
||||
@@ -189,8 +192,16 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
||||
}
|
||||
}
|
||||
|
||||
// Create UDP connection to destination
|
||||
dstConn, err := net.DialUDP("udp", nil, s.dst)
|
||||
// Create connection to destination (direct UDP or via agent stream tunnel)
|
||||
var (
|
||||
dstConn net.Conn
|
||||
err error
|
||||
)
|
||||
if s.agent != nil {
|
||||
dstConn, err = s.agent.NewUDPClient(s.dst.String())
|
||||
} else {
|
||||
dstConn, err = net.DialUDP("udp", nil, s.dst)
|
||||
}
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to dial dst")
|
||||
return nil, false
|
||||
@@ -205,7 +216,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
||||
|
||||
// Send initial data before starting response handler
|
||||
if !conn.forwardToDestination(initialData) {
|
||||
dstConn.Close()
|
||||
_ = dstConn.Close()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -328,6 +339,6 @@ func (conn *udpUDPConn) Close() {
|
||||
|
||||
conn.closed.Store(true)
|
||||
|
||||
conn.dstConn.Close()
|
||||
_ = conn.dstConn.Close()
|
||||
conn.dstConn = nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user