mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-22 16:58:54 +02:00
feat(agent): agent stream tunneling with TLS and dTLS (UDP); combined agent APIs
- Add `StreamPort` configuration to agent configuration and environment variables - Implement TCP and UDP stream client support in agent package - Update agent verification to test stream connectivity (TCP/UDP) - Add `/info` endpoint to agent HTTP handler for version, name, runtime, and stream port - Remove /version, /name, /runtime APIs, replaced by /info - Update agent compose template to expose stream port for TCP and UDP - Update agent creation API to optionally specify stream port (defaults to port + 1) - Modify `StreamRoute` to pass agent configuration to stream implementations - Update `TCPTCPStream` and `UDPUDPStream` to use agent stream tunneling when agent is configured - Add support for both direct connections and agent-tunneled connections in stream routes This enables agents to handle TCP and UDP route tunneling, expanding the proxy capabilities beyond HTTP-only connections.
This commit is contained in:
@@ -1,15 +1,18 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||||
|
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||||
"github.com/yusing/godoxy/agent/pkg/env"
|
"github.com/yusing/godoxy/agent/pkg/env"
|
||||||
"github.com/yusing/godoxy/agent/pkg/server"
|
"github.com/yusing/godoxy/agent/pkg/server"
|
||||||
"github.com/yusing/godoxy/internal/metrics/systeminfo"
|
"github.com/yusing/godoxy/internal/metrics/systeminfo"
|
||||||
socketproxy "github.com/yusing/godoxy/socketproxy/pkg"
|
socketproxy "github.com/yusing/godoxy/socketproxy/pkg"
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
httpServer "github.com/yusing/goutils/server"
|
httpServer "github.com/yusing/goutils/server"
|
||||||
strutils "github.com/yusing/goutils/strings"
|
strutils "github.com/yusing/goutils/strings"
|
||||||
"github.com/yusing/goutils/task"
|
"github.com/yusing/goutils/task"
|
||||||
@@ -63,6 +66,16 @@ Tips:
|
|||||||
|
|
||||||
server.StartAgentServer(t, opts)
|
server.StartAgentServer(t, opts)
|
||||||
|
|
||||||
|
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentStreamPort})
|
||||||
|
if err != nil {
|
||||||
|
gperr.LogFatal("failed to listen on port", err)
|
||||||
|
}
|
||||||
|
tcpServer := stream.NewTCPServer(t.Context(), tcpListener, caCert.Leaf, srvCert)
|
||||||
|
go tcpServer.Start()
|
||||||
|
|
||||||
|
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentStreamPort}, caCert.Leaf, srvCert)
|
||||||
|
go udpServer.Start()
|
||||||
|
|
||||||
if socketproxy.ListenAddr != "" {
|
if socketproxy.ListenAddr != "" {
|
||||||
runtime := strutils.Title(string(env.Runtime))
|
runtime := strutils.Title(string(env.Runtime))
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ require (
|
|||||||
github.com/bytedance/sonic v1.14.2
|
github.com/bytedance/sonic v1.14.2
|
||||||
github.com/gin-gonic/gin v1.11.0
|
github.com/gin-gonic/gin v1.11.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
github.com/pion/dtls/v3 v3.0.9
|
||||||
github.com/puzpuzpuz/xsync/v4 v4.2.0
|
github.com/puzpuzpuz/xsync/v4 v4.2.0
|
||||||
github.com/rs/zerolog v1.34.0
|
github.com/rs/zerolog v1.34.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
@@ -87,6 +88,8 @@ require (
|
|||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/oschwald/maxminddb-golang v1.13.1 // indirect
|
github.com/oschwald/maxminddb-golang v1.13.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/pion/logging v0.2.4 // indirect
|
||||||
|
github.com/pion/transport/v3 v3.1.1 // indirect
|
||||||
github.com/pires/go-proxyproto v0.8.1 // indirect
|
github.com/pires/go-proxyproto v0.8.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||||
|
|||||||
@@ -217,6 +217,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||||
|
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||||
|
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||||
|
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||||
|
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||||
|
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||||
|
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
var (
|
var (
|
||||||
installScript = `AGENT_NAME="{{.Name}}" \
|
installScript = `AGENT_NAME="{{.Name}}" \
|
||||||
AGENT_PORT="{{.Port}}" \
|
AGENT_PORT="{{.Port}}" \
|
||||||
|
AGENT_STREAM_PORT="{{.StreamPort}}" \
|
||||||
AGENT_CA_CERT="{{.CACert}}" \
|
AGENT_CA_CERT="{{.CACert}}" \
|
||||||
AGENT_SSL_CERT="{{.SSLCert}}" \
|
AGENT_SSL_CERT="{{.SSLCert}}" \
|
||||||
{{ if eq .ContainerRuntime "nerdctl" -}}
|
{{ if eq .ContainerRuntime "nerdctl" -}}
|
||||||
|
|||||||
@@ -4,38 +4,61 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||||
"github.com/yusing/godoxy/agent/pkg/certs"
|
"github.com/yusing/godoxy/agent/pkg/certs"
|
||||||
|
gperr "github.com/yusing/goutils/errs"
|
||||||
"github.com/yusing/goutils/version"
|
"github.com/yusing/goutils/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
Addr string `json:"addr"`
|
AgentInfo
|
||||||
Name string `json:"name"`
|
|
||||||
Version version.Version `json:"version" swaggertype:"string"`
|
Addr string `json:"addr"`
|
||||||
Runtime ContainerRuntime `json:"runtime"`
|
|
||||||
|
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
fasthttpClientHealthCheck *fasthttp.Client
|
fasthttpClientHealthCheck *fasthttp.Client
|
||||||
tlsConfig tls.Config
|
tlsConfig tls.Config
|
||||||
l zerolog.Logger
|
|
||||||
|
// for stream
|
||||||
|
caCert *x509.Certificate
|
||||||
|
clientCert *tls.Certificate
|
||||||
|
isTCPStreamSupported bool
|
||||||
|
isUDPStreamSupported bool
|
||||||
|
streamServerAddr string
|
||||||
|
|
||||||
|
l zerolog.Logger
|
||||||
} // @name Agent
|
} // @name Agent
|
||||||
|
|
||||||
|
type AgentInfo struct {
|
||||||
|
Version version.Version `json:"version" swaggertype:"string"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Runtime ContainerRuntime `json:"runtime"`
|
||||||
|
StreamPort int `json:"stream_port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated. Replaced by EndpointInfo
|
||||||
const (
|
const (
|
||||||
EndpointVersion = "/version"
|
EndpointVersion = "/version"
|
||||||
EndpointName = "/name"
|
EndpointName = "/name"
|
||||||
EndpointRuntime = "/runtime"
|
EndpointRuntime = "/runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointInfo = "/info"
|
||||||
EndpointProxyHTTP = "/proxy/http"
|
EndpointProxyHTTP = "/proxy/http"
|
||||||
EndpointHealth = "/health"
|
EndpointHealth = "/health"
|
||||||
EndpointLogs = "/logs"
|
EndpointLogs = "/logs"
|
||||||
@@ -90,6 +113,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
cfg.clientCert = &clientCert
|
||||||
|
|
||||||
// create tls config
|
// create tls config
|
||||||
caCertPool := x509.NewCertPool()
|
caCertPool := x509.NewCertPool()
|
||||||
@@ -97,6 +121,14 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
|||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("invalid ca certificate")
|
return errors.New("invalid ca certificate")
|
||||||
}
|
}
|
||||||
|
// Keep the CA leaf for stream client dialing.
|
||||||
|
if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" {
|
||||||
|
return errors.New("invalid ca certificate")
|
||||||
|
} else if cert, err := x509.ParseCertificate(block.Bytes); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
cfg.caCert = cert
|
||||||
|
}
|
||||||
|
|
||||||
cfg.tlsConfig = tls.Config{
|
cfg.tlsConfig = tls.Config{
|
||||||
Certificates: []tls.Certificate{clientCert},
|
Certificates: []tls.Certificate{clientCert},
|
||||||
@@ -113,48 +145,97 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
|||||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// get agent name
|
status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo)
|
||||||
name, _, err := cfg.fetchString(ctx, EndpointName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg.Name = name
|
var streamUnsupportedErrs gperr.Builder
|
||||||
|
|
||||||
|
if status == http.StatusOK {
|
||||||
|
if cfg.StreamPort <= 0 {
|
||||||
|
return fmt.Errorf("invalid agent stream port: %d", cfg.StreamPort)
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(cfg.Addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cfg.streamServerAddr = net.JoinHostPort(host, strconv.Itoa(cfg.StreamPort))
|
||||||
|
|
||||||
|
// test stream server connection
|
||||||
|
const fakeAddress = "localhost:8080" // it won't be used, just for testing
|
||||||
|
// test TCP stream support
|
||||||
|
conn, err := agentstream.NewTCPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
|
||||||
|
if err != nil {
|
||||||
|
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
|
||||||
|
} else {
|
||||||
|
conn.Close()
|
||||||
|
cfg.isTCPStreamSupported = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// test UDP stream support
|
||||||
|
conn, err = agentstream.NewUDPClient(cfg.streamServerAddr, fakeAddress, cfg.caCert, cfg.clientCert)
|
||||||
|
if err != nil {
|
||||||
|
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
|
||||||
|
} else {
|
||||||
|
conn.Close()
|
||||||
|
cfg.isUDPStreamSupported = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// old agent does not support EndpointInfo
|
||||||
|
// fallback with old logic
|
||||||
|
cfg.isTCPStreamSupported = false
|
||||||
|
cfg.isUDPStreamSupported = false
|
||||||
|
streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling")
|
||||||
|
|
||||||
|
// get agent name
|
||||||
|
name, _, err := cfg.fetchString(ctx, EndpointName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Name = name
|
||||||
|
|
||||||
|
// check agent version
|
||||||
|
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Version = version.Parse(agentVersion)
|
||||||
|
|
||||||
|
// check agent runtime
|
||||||
|
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case http.StatusOK:
|
||||||
|
switch runtime {
|
||||||
|
case "docker":
|
||||||
|
cfg.Runtime = ContainerRuntimeDocker
|
||||||
|
// case "nerdctl":
|
||||||
|
// cfg.Runtime = ContainerRuntimeNerdctl
|
||||||
|
case "podman":
|
||||||
|
cfg.Runtime = ContainerRuntimePodman
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid agent runtime: %s", runtime)
|
||||||
|
}
|
||||||
|
case http.StatusNotFound:
|
||||||
|
// backward compatibility, old agent does not have runtime endpoint
|
||||||
|
cfg.Runtime = ContainerRuntimeDocker
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cfg.l = log.With().Str("agent", cfg.Name).Logger()
|
cfg.l = log.With().Str("agent", cfg.Name).Logger()
|
||||||
|
|
||||||
// check agent version
|
if err := streamUnsupportedErrs.Error(); err != nil {
|
||||||
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
|
gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check agent runtime
|
|
||||||
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch status {
|
|
||||||
case http.StatusOK:
|
|
||||||
switch runtime {
|
|
||||||
case "docker":
|
|
||||||
cfg.Runtime = ContainerRuntimeDocker
|
|
||||||
// case "nerdctl":
|
|
||||||
// cfg.Runtime = ContainerRuntimeNerdctl
|
|
||||||
case "podman":
|
|
||||||
cfg.Runtime = ContainerRuntimePodman
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("invalid agent runtime: %s", runtime)
|
|
||||||
}
|
|
||||||
case http.StatusNotFound:
|
|
||||||
// backward compatibility, old agent does not have runtime endpoint
|
|
||||||
cfg.Runtime = ContainerRuntimeDocker
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg.Version = version.Parse(agentVersion)
|
|
||||||
|
|
||||||
if serverVersion.IsNewerThanMajor(cfg.Version) {
|
if serverVersion.IsNewerThanMajor(cfg.Version) {
|
||||||
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version)
|
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version)
|
||||||
}
|
}
|
||||||
@@ -163,6 +244,53 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cfg *AgentConfig) getStreamServerAddr() (string, error) {
|
||||||
|
if cfg.streamServerAddr == "" {
|
||||||
|
return "", errors.New("agent stream server address is not initialized")
|
||||||
|
}
|
||||||
|
return cfg.streamServerAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPClient creates a new TCP client for the agent.
|
||||||
|
//
|
||||||
|
// It returns an error if
|
||||||
|
// - the agent is not initialized
|
||||||
|
// - the agent does not support TCP stream tunneling
|
||||||
|
// - the agent stream server address is not initialized
|
||||||
|
func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) {
|
||||||
|
if cfg.caCert == nil || cfg.clientCert == nil {
|
||||||
|
return nil, errors.New("agent is not initialized")
|
||||||
|
}
|
||||||
|
if !cfg.isTCPStreamSupported {
|
||||||
|
return nil, errors.New("agent does not support TCP stream tunneling")
|
||||||
|
}
|
||||||
|
serverAddr, err := cfg.getStreamServerAddr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return agentstream.NewTCPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDPClient creates a new UDP client for the agent.
|
||||||
|
//
|
||||||
|
// It returns an error if
|
||||||
|
// - the agent is not initialized
|
||||||
|
// - the agent does not support UDP stream tunneling
|
||||||
|
// - the agent stream server address is not initialized
|
||||||
|
func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) {
|
||||||
|
if cfg.caCert == nil || cfg.clientCert == nil {
|
||||||
|
return nil, errors.New("agent is not initialized")
|
||||||
|
}
|
||||||
|
if !cfg.isUDPStreamSupported {
|
||||||
|
return nil, errors.New("agent does not support UDP stream tunneling")
|
||||||
|
}
|
||||||
|
serverAddr, err := cfg.getStreamServerAddr()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return agentstream.NewUDPClient(serverAddr, targetAddress, cfg.caCert, cfg.clientCert)
|
||||||
|
}
|
||||||
|
|
||||||
func (cfg *AgentConfig) Start(ctx context.Context) error {
|
func (cfg *AgentConfig) Start(ctx context.Context) error {
|
||||||
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
|
filepath, ok := certs.AgentCertsFilepath(cfg.Addr)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type (
|
|||||||
AgentEnvConfig struct {
|
AgentEnvConfig struct {
|
||||||
Name string
|
Name string
|
||||||
Port int
|
Port int
|
||||||
|
StreamPort int
|
||||||
CACert string
|
CACert string
|
||||||
SSLCert string
|
SSLCert string
|
||||||
ContainerRuntime ContainerRuntime
|
ContainerRuntime ContainerRuntime
|
||||||
|
|||||||
@@ -87,6 +87,34 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin
|
|||||||
return ret, resp.StatusCode, nil
|
return ret, resp.StatusCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct
|
||||||
|
//
|
||||||
|
// It will return the status code of the response, and error if any.
|
||||||
|
// If the status code is not http.StatusOK, out will be unchanged but error will still be nil.
|
||||||
|
func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) {
|
||||||
|
resp, err := cfg.Do(ctx, "GET", endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
data, release, err := httputils.ReadAllBody(resp)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer release(data)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sonic.Unmarshal(data, out)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return resp.StatusCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) {
|
func (cfg *AgentConfig) Websocket(ctx context.Context, endpoint string) (*websocket.Conn, *http.Response, error) {
|
||||||
transport := cfg.Transport()
|
transport := cfg.Transport()
|
||||||
dialer := websocket.Dialer{
|
dialer := websocket.Dialer{
|
||||||
|
|||||||
40
agent/pkg/agent/stream/PROTOCOL.md
Normal file
40
agent/pkg/agent/stream/PROTOCOL.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Stream proxy protocol
|
||||||
|
|
||||||
|
This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination.
|
||||||
|
|
||||||
|
## Header
|
||||||
|
|
||||||
|
The on-wire header is a fixed-size binary blob:
|
||||||
|
|
||||||
|
- `Version` (8 bytes)
|
||||||
|
- `Host` (255 bytes, NUL padded)
|
||||||
|
- `Port` (5 bytes, NUL padded)
|
||||||
|
- `Checksum` (4 bytes, big-endian CRC32)
|
||||||
|
|
||||||
|
Total: `headerSize = 8 + 255 + 5 + 4 = 272` bytes.
|
||||||
|
|
||||||
|
Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`.
|
||||||
|
|
||||||
|
See [`StreamRequestHeader`](payload.go:26).
|
||||||
|
|
||||||
|
## TCP behavior
|
||||||
|
|
||||||
|
1. Client establishes a TLS connection to the stream server.
|
||||||
|
2. Client sends exactly one header as a handshake.
|
||||||
|
3. After the handshake, both sides proxy raw TCP bytes between client and destination.
|
||||||
|
|
||||||
|
Server reads the header using `io.ReadFull` to avoid dropping bytes.
|
||||||
|
|
||||||
|
See [`NewTCPClient()`](tcp_client.go:15) and [`(*TCPServer).redirect()`](tcp_server.go:77).
|
||||||
|
|
||||||
|
## UDP-over-DTLS behavior
|
||||||
|
|
||||||
|
1. Client establishes a DTLS connection to the stream server.
|
||||||
|
2. Client sends exactly one header as a handshake.
|
||||||
|
3. After the handshake, both sides proxy raw UDP datagrams:
|
||||||
|
- client → destination: DTLS payload is written to destination `UDPConn`
|
||||||
|
- destination → client: destination payload is written back to the DTLS connection
|
||||||
|
|
||||||
|
Responses do **not** include a header.
|
||||||
|
|
||||||
|
See [`NewUDPClient()`](udp_client.go:17) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:67).
|
||||||
57
agent/pkg/agent/stream/common.go
Normal file
57
agent/pkg/agent/stream/common.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
"github.com/yusing/goutils/synk"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dialTimeout = 10 * time.Second
|
||||||
|
readDeadline = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var sizedPool = synk.GetSizedBytesPool()
|
||||||
|
|
||||||
|
type CreateConnFunc[Conn net.Conn] func(host, port string) (Conn, error)
|
||||||
|
type ConnectionManager[Conn net.Conn] struct {
|
||||||
|
m *xsync.Map[string, Conn]
|
||||||
|
createConnection CreateConnFunc[Conn]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectionManager[Conn net.Conn](createConnection CreateConnFunc[Conn]) *ConnectionManager[Conn] {
|
||||||
|
return &ConnectionManager[Conn]{
|
||||||
|
m: xsync.NewMap[string, Conn](),
|
||||||
|
createConnection: createConnection,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnectionManager[Conn]) GetOrCreateDestConnection(clientConn net.Conn, host, port string) (ret Conn, connErr error) {
|
||||||
|
clientKey := clientConn.RemoteAddr().String()
|
||||||
|
ret, _ = c.m.LoadOrCompute(clientKey, func() (conn Conn, cancel bool) {
|
||||||
|
conn, connErr = c.createConnection(host, port)
|
||||||
|
if connErr != nil {
|
||||||
|
cancel = true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnectionManager[Conn]) DeleteDestConnection(clientConn net.Conn) {
|
||||||
|
clientKey := clientConn.RemoteAddr().String()
|
||||||
|
conn, loaded := c.m.LoadAndDelete(clientKey)
|
||||||
|
if loaded {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnectionManager[Conn]) CloseAllConnections() {
|
||||||
|
for _, conn := range c.m.Range {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
c.m.Clear()
|
||||||
|
}
|
||||||
109
agent/pkg/agent/stream/payload.go
Normal file
109
agent/pkg/agent/stream/payload.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/crc32"
|
||||||
|
"io"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
versionSize = 8
|
||||||
|
hostSize = 255
|
||||||
|
portSize = 5
|
||||||
|
checksumSize = 4 // crc32 checksum
|
||||||
|
|
||||||
|
headerSize = versionSize + hostSize + portSize + checksumSize
|
||||||
|
)
|
||||||
|
|
||||||
|
var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0}
|
||||||
|
|
||||||
|
var ErrInvalidHeader = errors.New("invalid header")
|
||||||
|
|
||||||
|
type StreamRequestHeader struct {
|
||||||
|
Version [versionSize]byte
|
||||||
|
Host [hostSize]byte
|
||||||
|
Port [portSize]byte
|
||||||
|
Checksum [checksumSize]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamRequestPayload struct {
|
||||||
|
StreamRequestHeader
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) {
|
||||||
|
if len(host) > hostSize {
|
||||||
|
return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host))
|
||||||
|
}
|
||||||
|
if len(port) > portSize {
|
||||||
|
return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port))
|
||||||
|
}
|
||||||
|
header := &StreamRequestHeader{}
|
||||||
|
copy(header.Version[:], version[:])
|
||||||
|
copy(header.Host[:], host)
|
||||||
|
copy(header.Port[:], port)
|
||||||
|
header.updateChecksum()
|
||||||
|
return header, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ToHeader(buf [headerSize]byte) *StreamRequestHeader {
|
||||||
|
return (*StreamRequestHeader)(unsafe.Pointer(&buf[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo implements the io.WriterTo interface.
|
||||||
|
func (p *StreamRequestPayload) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
n1, err := w.Write(p.StreamRequestHeader.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(p.Data) == 0 {
|
||||||
|
return int64(n1), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n2, err := w.Write(p.Data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return int64(n1) + int64(n2), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *StreamRequestHeader) GetHostPort() (string, string) {
|
||||||
|
hostEnd := bytes.IndexByte(h.Host[:], 0)
|
||||||
|
portEnd := bytes.IndexByte(h.Port[:], 0)
|
||||||
|
if hostEnd == -1 {
|
||||||
|
hostEnd = hostSize
|
||||||
|
}
|
||||||
|
if portEnd == -1 {
|
||||||
|
portEnd = portSize
|
||||||
|
}
|
||||||
|
return string(h.Host[:hostEnd]), string(h.Port[:portEnd])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *StreamRequestHeader) Validate() bool {
|
||||||
|
if h.Version != version {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return h.validateChecksum()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *StreamRequestHeader) updateChecksum() {
|
||||||
|
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
|
||||||
|
binary.BigEndian.PutUint32(h.Checksum[:], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *StreamRequestHeader) validateChecksum() bool {
|
||||||
|
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
|
||||||
|
return checksum == binary.BigEndian.Uint32(h.Checksum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *StreamRequestHeader) BytesWithoutChecksum() []byte {
|
||||||
|
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize-checksumSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *StreamRequestHeader) Bytes() []byte {
|
||||||
|
return unsafe.Slice((*byte)(unsafe.Pointer(h)), headerSize)
|
||||||
|
}
|
||||||
53
agent/pkg/agent/stream/payload_test.go
Normal file
53
agent/pkg/agent/stream/payload_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) {
|
||||||
|
h, err := NewStreamRequestHeader("example.com", "443")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamRequestHeader: %v", err)
|
||||||
|
}
|
||||||
|
if !h.Validate() {
|
||||||
|
t.Fatalf("expected header to validate")
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf [headerSize]byte
|
||||||
|
copy(buf[:], h.Bytes())
|
||||||
|
h2 := ToHeader(buf)
|
||||||
|
if !h2.Validate() {
|
||||||
|
t.Fatalf("expected round-tripped header to validate")
|
||||||
|
}
|
||||||
|
host, port := h2.GetHostPort()
|
||||||
|
if host != "example.com" || port != "443" {
|
||||||
|
t.Fatalf("unexpected host/port: %q:%q", host, port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamRequestPayload_WriteTo_WritesFullHeader(t *testing.T) {
|
||||||
|
h, err := NewStreamRequestHeader("127.0.0.1", "53")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamRequestHeader: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &StreamRequestPayload{StreamRequestHeader: *h, Data: []byte("hello")}
|
||||||
|
|
||||||
|
var out bytes.Buffer
|
||||||
|
n, err := p.WriteTo(&out)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WriteTo: %v", err)
|
||||||
|
}
|
||||||
|
if int(n) != headerSize+len(p.Data) {
|
||||||
|
t.Fatalf("unexpected bytes written: got %d want %d", n, headerSize+len(p.Data))
|
||||||
|
}
|
||||||
|
|
||||||
|
written := out.Bytes()
|
||||||
|
if len(written) != headerSize+len(p.Data) {
|
||||||
|
t.Fatalf("unexpected output size: got %d", len(written))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(written[:headerSize], h.Bytes()) {
|
||||||
|
t.Fatalf("expected full header (including checksum) to be written")
|
||||||
|
}
|
||||||
|
}
|
||||||
235
agent/pkg/agent/stream/server_flow_test.go
Normal file
235
agent/pkg/agent/stream/server_flow_test.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newSerial(t *testing.T) *big.Int {
|
||||||
|
t.Helper()
|
||||||
|
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rand serial: %v", err)
|
||||||
|
}
|
||||||
|
return sn
|
||||||
|
}
|
||||||
|
|
||||||
|
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: newSerial(t),
|
||||||
|
Subject: pkix.Name{CommonName: "stream-test-ca"},
|
||||||
|
NotBefore: time.Now().Add(-time.Minute),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateCertificate(CA): %v", err)
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(der)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseCertificate(CA): %v", err)
|
||||||
|
}
|
||||||
|
return cert, key
|
||||||
|
}
|
||||||
|
|
||||||
|
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: newSerial(t),
|
||||||
|
Subject: pkix.Name{CommonName: cn},
|
||||||
|
NotBefore: time.Now().Add(-time.Minute),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{eku},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateCertificate(%s): %v", cn, err)
|
||||||
|
}
|
||||||
|
leaf, err := x509.ParseCertificate(der)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseCertificate(%s): %v", cn, err)
|
||||||
|
}
|
||||||
|
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||||
|
t.Helper()
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Listen tcp: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
_, _ = io.Copy(conn, conn)
|
||||||
|
}(c)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ln.Addr().String(), func() {
|
||||||
|
_ = ln.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||||
|
t.Helper()
|
||||||
|
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Listen udp: %v", err)
|
||||||
|
}
|
||||||
|
uc := pc.(*net.UDPConn)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
buf := make([]byte, 65535)
|
||||||
|
for {
|
||||||
|
n, raddr, err := uc.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = uc.WriteToUDP(buf[:n], raddr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return uc.LocalAddr().String(), func() {
|
||||||
|
_ = uc.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPServer_FullFlow(t *testing.T) {
|
||||||
|
ca, caKey := genCA(t)
|
||||||
|
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
||||||
|
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
||||||
|
|
||||||
|
dstAddr, closeDst := startTCPEcho(t)
|
||||||
|
defer closeDst()
|
||||||
|
|
||||||
|
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenTCP: %v", err)
|
||||||
|
}
|
||||||
|
defer tcpLn.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() { errCh <- srv.Start() }()
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
_ = srv.Close()
|
||||||
|
_ = <-errCh
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewTCPClient: %v", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
msg := []byte("ping over tcp")
|
||||||
|
if _, err := client.Write(msg); err != nil {
|
||||||
|
t.Fatalf("client.Write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, len(msg))
|
||||||
|
if _, err := io.ReadFull(client, buf); err != nil {
|
||||||
|
t.Fatalf("client.ReadFull: %v", err)
|
||||||
|
}
|
||||||
|
if string(buf) != string(msg) {
|
||||||
|
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPServer_FullFlow(t *testing.T) {
|
||||||
|
ca, caKey := genCA(t)
|
||||||
|
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
||||||
|
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
||||||
|
|
||||||
|
dstAddr, closeDst := startUDPEcho(t)
|
||||||
|
defer closeDst()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() { errCh <- srv.Start() }()
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
_ = srv.Close()
|
||||||
|
err := <-errCh
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
|
||||||
|
t.Logf("udp server exit: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for srv.listener == nil {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
t.Fatalf("udp server listener did not start")
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewUDPClient: %v", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
msg := []byte("ping over udp")
|
||||||
|
if _, err := client.Write(msg); err != nil {
|
||||||
|
t.Fatalf("client.Write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
n, err := client.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client.Read: %v", err)
|
||||||
|
}
|
||||||
|
if string(buf[:n]) != string(msg) {
|
||||||
|
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
91
agent/pkg/agent/stream/tcp_client.go
Normal file
91
agent/pkg/agent/stream/tcp_client.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TCPClient struct {
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPClient creates a new TCP client for the agent.
|
||||||
|
//
|
||||||
|
// It will establish a TLS connection and send a stream request header to the server.
|
||||||
|
//
|
||||||
|
// It returns an error if
|
||||||
|
// - the target address is invalid
|
||||||
|
// - the stream request header is invalid
|
||||||
|
// - the TLS configuration is invalid
|
||||||
|
// - the TLS connection fails
|
||||||
|
// - the stream request header is not sent
|
||||||
|
func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||||
|
host, port, err := net.SplitHostPort(targetAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := NewStreamRequestHeader(host, port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup TLS configuration
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AddCert(caCert)
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*clientCert},
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Establish TLS connection
|
||||||
|
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Send the stream header once as a handshake.
|
||||||
|
if _, err := conn.Write(header.Bytes()); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TCPClient{
|
||||||
|
conn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) Read(p []byte) (n int, err error) {
|
||||||
|
return c.conn.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) Write(p []byte) (n int, err error) {
|
||||||
|
return c.conn.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) SetDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TCPClient) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
99
agent/pkg/agent/stream/tcp_server.go
Normal file
99
agent/pkg/agent/stream/tcp_server.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
ioutils "github.com/yusing/goutils/io"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TCPServer struct {
|
||||||
|
ctx context.Context
|
||||||
|
listener net.Listener
|
||||||
|
connMgr *ConnectionManager[net.Conn]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AddCert(caCert)
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*serverCert},
|
||||||
|
ClientCAs: caCertPool,
|
||||||
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpListener := tls.NewListener(listener, tlsConfig)
|
||||||
|
s := &TCPServer{
|
||||||
|
ctx: ctx,
|
||||||
|
listener: tcpListener,
|
||||||
|
}
|
||||||
|
s.connMgr = NewConnectionManager(s.createDestConnection)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPServer) Start() error {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return s.ctx.Err()
|
||||||
|
default:
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go s.handle(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPServer) Addr() net.Addr {
|
||||||
|
return s.listener.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPServer) Close() error {
|
||||||
|
s.connMgr.CloseAllConnections()
|
||||||
|
return s.listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPServer) handle(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
dst, err := s.redirect(conn)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer s.connMgr.DeleteDestConnection(conn)
|
||||||
|
pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst)
|
||||||
|
pipe.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) {
|
||||||
|
// Read the stream header once as a handshake.
|
||||||
|
var headerBuf [headerSize]byte
|
||||||
|
if _, err := io.ReadFull(conn, headerBuf[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
header := ToHeader(headerBuf)
|
||||||
|
if !header.Validate() {
|
||||||
|
return nil, ErrInvalidHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// get destination connection
|
||||||
|
host, port := header.GetHostPort()
|
||||||
|
return s.connMgr.GetOrCreateDestConnection(conn, host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) {
|
||||||
|
addr := host + ":" + port
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
99
agent/pkg/agent/stream/udp_client.go
Normal file
99
agent/pkg/agent/stream/udp_client.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UDPClient struct {
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDPClient creates a new UDP client for the agent.
|
||||||
|
//
|
||||||
|
// It will establish a DTLS connection and send a stream request header to the server.
|
||||||
|
//
|
||||||
|
// It returns an error if
|
||||||
|
// - the target address is invalid
|
||||||
|
// - the stream request header is invalid
|
||||||
|
// - the DTLS configuration is invalid
|
||||||
|
// - the DTLS connection fails
|
||||||
|
// - the stream request header is not sent
|
||||||
|
func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||||
|
host, port, err := net.SplitHostPort(targetAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := NewStreamRequestHeader(host, port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup DTLS configuration
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AddCert(caCert)
|
||||||
|
|
||||||
|
dtlsConfig := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{*clientCert},
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
InsecureSkipVerify: false,
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
raddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Establish DTLS connection
|
||||||
|
conn, err := dtls.Dial("udp", raddr, dtlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Send the stream header once as a handshake.
|
||||||
|
if _, err := conn.Write(header.Bytes()); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UDPClient{
|
||||||
|
conn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) Read(p []byte) (n int, err error) {
|
||||||
|
return c.conn.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) Write(p []byte) (n int, err error) {
|
||||||
|
return c.conn.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) SetDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UDPClient) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
164
agent/pkg/agent/stream/udp_server.go
Normal file
164
agent/pkg/agent/stream/udp_server.go
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UDPServer struct {
|
||||||
|
ctx context.Context
|
||||||
|
laddr *net.UDPAddr
|
||||||
|
listener net.Listener
|
||||||
|
|
||||||
|
dtlsConfig *dtls.Config
|
||||||
|
connMgr *ConnectionManager[*net.UDPConn]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUDPServer(ctx context.Context, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
caCertPool.AddCert(caCert)
|
||||||
|
|
||||||
|
dtlsConfig := &dtls.Config{
|
||||||
|
Certificates: []tls.Certificate{*serverCert},
|
||||||
|
ClientCAs: caCertPool,
|
||||||
|
ClientAuth: dtls.RequireAndVerifyClientCert,
|
||||||
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &UDPServer{
|
||||||
|
ctx: ctx,
|
||||||
|
laddr: laddr,
|
||||||
|
dtlsConfig: dtlsConfig,
|
||||||
|
}
|
||||||
|
s.connMgr = NewConnectionManager(s.createDestConnection)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPServer) Start() error {
|
||||||
|
listener, err := dtls.Listen("udp", s.laddr, s.dtlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listener = listener
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return s.ctx.Err()
|
||||||
|
default:
|
||||||
|
conn, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go s.handleDTLSConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPServer) Addr() net.Addr {
|
||||||
|
if s.listener != nil {
|
||||||
|
return s.listener.Addr()
|
||||||
|
}
|
||||||
|
return s.laddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPServer) Close() error {
|
||||||
|
s.connMgr.CloseAllConnections()
|
||||||
|
if s.listener != nil {
|
||||||
|
return s.listener.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) {
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// Read the stream header once as a handshake.
|
||||||
|
var headerBuf [headerSize]byte
|
||||||
|
if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
header := ToHeader(headerBuf)
|
||||||
|
if !header.Validate() {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host, port := header.GetHostPort()
|
||||||
|
dstConn, err := s.connMgr.GetOrCreateDestConnection(clientConn, host, port)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer s.connMgr.DeleteDestConnection(clientConn)
|
||||||
|
|
||||||
|
go s.forwardFromDestination(dstConn, clientConn)
|
||||||
|
|
||||||
|
buf := sizedPool.GetSized(65535)
|
||||||
|
defer sizedPool.Put(buf)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
n, err := clientConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := dstConn.Write(buf[:n]); err != nil {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) {
|
||||||
|
addr := host + ":" + port
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dstConn, err := net.DialUDP("udp", nil, udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dstConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) {
|
||||||
|
buffer := sizedPool.GetSized(65535)
|
||||||
|
defer sizedPool.Put(buffer)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
_ = dstConn.SetReadDeadline(time.Now().Add(readDeadline))
|
||||||
|
n, err := dstConn.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := clientConn.Write(buffer[:n]); err != nil {
|
||||||
|
// TODO: log error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
services:
|
|
||||||
agent:
|
|
||||||
image: "{{.Image}}"
|
|
||||||
container_name: godoxy-agent
|
|
||||||
restart: always
|
|
||||||
network_mode: host # do not change this
|
|
||||||
environment:
|
|
||||||
AGENT_NAME: "{{.Name}}"
|
|
||||||
AGENT_PORT: "{{.Port}}"
|
|
||||||
AGENT_CA_CERT: "{{.CACert}}"
|
|
||||||
AGENT_SSL_CERT: "{{.SSLCert}}"
|
|
||||||
# use agent as a docker socket proxy: [host]:port
|
|
||||||
# set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375)
|
|
||||||
LISTEN_ADDR:
|
|
||||||
POST: false
|
|
||||||
ALLOW_RESTARTS: false
|
|
||||||
ALLOW_START: false
|
|
||||||
ALLOW_STOP: false
|
|
||||||
AUTH: false
|
|
||||||
BUILD: false
|
|
||||||
COMMIT: false
|
|
||||||
CONFIGS: false
|
|
||||||
CONTAINERS: false
|
|
||||||
DISTRIBUTION: false
|
|
||||||
EVENTS: true
|
|
||||||
EXEC: false
|
|
||||||
GRPC: false
|
|
||||||
IMAGES: false
|
|
||||||
INFO: false
|
|
||||||
NETWORKS: false
|
|
||||||
NODES: false
|
|
||||||
PING: true
|
|
||||||
PLUGINS: false
|
|
||||||
SECRETS: false
|
|
||||||
SERVICES: false
|
|
||||||
SESSION: false
|
|
||||||
SWARM: false
|
|
||||||
SYSTEM: false
|
|
||||||
TASKS: false
|
|
||||||
VERSION: true
|
|
||||||
VOLUMES: false
|
|
||||||
volumes:
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock
|
|
||||||
- ./data:/app/data
|
|
||||||
@@ -5,7 +5,9 @@ services:
|
|||||||
restart: always
|
restart: always
|
||||||
{{ if eq .ContainerRuntime "podman" -}}
|
{{ if eq .ContainerRuntime "podman" -}}
|
||||||
ports:
|
ports:
|
||||||
- "{{.Port}}:{{.Port}}"
|
- "{{.Port}}:{{.Port}}/tcp"
|
||||||
|
- "{{.StreamPort}}:{{.StreamPort}}/tcp"
|
||||||
|
- "{{.StreamPort}}:{{.StreamPort}}/udp"
|
||||||
{{ else -}}
|
{{ else -}}
|
||||||
network_mode: host # do not change this
|
network_mode: host # do not change this
|
||||||
{{ end -}}
|
{{ end -}}
|
||||||
@@ -22,6 +24,7 @@ services:
|
|||||||
{{ end -}}
|
{{ end -}}
|
||||||
AGENT_NAME: "{{.Name}}"
|
AGENT_NAME: "{{.Name}}"
|
||||||
AGENT_PORT: "{{.Port}}"
|
AGENT_PORT: "{{.Port}}"
|
||||||
|
AGENT_STREAM_PORT: "{{.StreamPort}}"
|
||||||
AGENT_CA_CERT: "{{.CACert}}"
|
AGENT_CA_CERT: "{{.CACert}}"
|
||||||
AGENT_SSL_CERT: "{{.SSLCert}}"
|
AGENT_SSL_CERT: "{{.SSLCert}}"
|
||||||
# use agent as a docker socket proxy: [host]:port
|
# use agent as a docker socket proxy: [host]:port
|
||||||
|
|||||||
2
agent/pkg/env/env.go
vendored
2
agent/pkg/env/env.go
vendored
@@ -20,6 +20,7 @@ func DefaultAgentName() string {
|
|||||||
var (
|
var (
|
||||||
AgentName string
|
AgentName string
|
||||||
AgentPort int
|
AgentPort int
|
||||||
|
AgentStreamPort int
|
||||||
AgentSkipClientCertCheck bool
|
AgentSkipClientCertCheck bool
|
||||||
AgentCACert string
|
AgentCACert string
|
||||||
AgentSSLCert string
|
AgentSSLCert string
|
||||||
@@ -35,6 +36,7 @@ func Load() {
|
|||||||
DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock")
|
DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock")
|
||||||
AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName())
|
AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName())
|
||||||
AgentPort = env.GetEnvInt("AGENT_PORT", 8890)
|
AgentPort = env.GetEnvInt("AGENT_PORT", 8890)
|
||||||
|
AgentStreamPort = env.GetEnvInt("AGENT_STREAM_PORT", AgentPort+1)
|
||||||
AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
|
AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
|
||||||
|
|
||||||
AgentCACert = env.GetEnvString("AGENT_CA_CERT", "")
|
AgentCACert = env.GetEnvString("AGENT_CA_CERT", "")
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bytedance/sonic"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||||
@@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
|
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
|
||||||
mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, version.Get())
|
agentInfo := agent.AgentInfo{
|
||||||
})
|
Version: version.Get(),
|
||||||
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
|
Name: env.AgentName,
|
||||||
fmt.Fprint(w, env.AgentName)
|
Runtime: env.Runtime,
|
||||||
})
|
StreamPort: env.AgentStreamPort,
|
||||||
mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) {
|
}
|
||||||
fmt.Fprint(w, env.Runtime)
|
sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo)
|
||||||
})
|
})
|
||||||
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
||||||
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP)
|
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -172,6 +172,9 @@ require (
|
|||||||
github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect
|
github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect
|
||||||
github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect
|
github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect
|
||||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||||
|
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||||
|
github.com/pion/logging v0.2.4 // indirect
|
||||||
|
github.com/pion/transport/v3 v3.1.1 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||||
github.com/pquerna/otp v1.5.0 // indirect
|
github.com/pquerna/otp v1.5.0 // indirect
|
||||||
github.com/stretchr/objx v0.5.3 // indirect
|
github.com/stretchr/objx v0.5.3 // indirect
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -247,6 +247,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||||
|
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||||
|
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||||
|
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||||
|
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||||
|
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||||
|
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
|||||||
@@ -110,9 +110,9 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
|||||||
|
|
||||||
switch rurl.Scheme {
|
switch rurl.Scheme {
|
||||||
case "tcp":
|
case "tcp":
|
||||||
return stream.NewTCPTCPStream(laddr, rurl.Host)
|
return stream.NewTCPTCPStream(laddr, rurl.Host, r.GetAgent())
|
||||||
case "udp":
|
case "udp":
|
||||||
return stream.NewUDPUDPStream(laddr, rurl.Host)
|
return stream.NewUDPUDPStream(laddr, rurl.Host, r.GetAgent())
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pires/go-proxyproto"
|
"github.com/pires/go-proxyproto"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||||
"github.com/yusing/godoxy/internal/acl"
|
"github.com/yusing/godoxy/internal/acl"
|
||||||
"github.com/yusing/godoxy/internal/entrypoint"
|
"github.com/yusing/godoxy/internal/entrypoint"
|
||||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||||
@@ -17,6 +18,7 @@ type TCPTCPStream struct {
|
|||||||
listener net.Listener
|
listener net.Listener
|
||||||
laddr *net.TCPAddr
|
laddr *net.TCPAddr
|
||||||
dst *net.TCPAddr
|
dst *net.TCPAddr
|
||||||
|
agent *agent.AgentConfig
|
||||||
|
|
||||||
preDial nettypes.HookFunc
|
preDial nettypes.HookFunc
|
||||||
onRead nettypes.HookFunc
|
onRead nettypes.HookFunc
|
||||||
@@ -24,7 +26,7 @@ type TCPTCPStream struct {
|
|||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
func NewTCPTCPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) {
|
||||||
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
|
dst, err := net.ResolveTCPAddr("tcp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -33,7 +35,7 @@ func NewTCPTCPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &TCPTCPStream{laddr: laddr, dst: dst}, nil
|
return &TCPTCPStream{laddr: laddr, dst: dst, agent: agentCfg}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||||
@@ -126,7 +128,15 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dstConn, err := net.DialTCP("tcp", nil, s.dst)
|
var (
|
||||||
|
dstConn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if s.agent != nil {
|
||||||
|
dstConn, err = s.agent.NewTCPClient(s.dst.String())
|
||||||
|
} else {
|
||||||
|
dstConn, err = net.DialTCP("tcp", nil, s.dst)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !s.closed.Load() {
|
if !s.closed.Load() {
|
||||||
logErr(s, err, "failed to dial destination")
|
logErr(s, err, "failed to dial destination")
|
||||||
@@ -140,7 +150,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
src := conn
|
src := conn
|
||||||
dst := net.Conn(dstConn)
|
dst := dstConn
|
||||||
if s.onRead != nil {
|
if s.onRead != nil {
|
||||||
src = &wrapperConn{
|
src = &wrapperConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||||
"github.com/yusing/godoxy/internal/acl"
|
"github.com/yusing/godoxy/internal/acl"
|
||||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||||
"github.com/yusing/goutils/synk"
|
"github.com/yusing/goutils/synk"
|
||||||
@@ -22,6 +23,7 @@ type UDPUDPStream struct {
|
|||||||
|
|
||||||
laddr *net.UDPAddr
|
laddr *net.UDPAddr
|
||||||
dst *net.UDPAddr
|
dst *net.UDPAddr
|
||||||
|
agent *agent.AgentConfig
|
||||||
|
|
||||||
preDial nettypes.HookFunc
|
preDial nettypes.HookFunc
|
||||||
onRead nettypes.HookFunc
|
onRead nettypes.HookFunc
|
||||||
@@ -35,7 +37,7 @@ type UDPUDPStream struct {
|
|||||||
|
|
||||||
type udpUDPConn struct {
|
type udpUDPConn struct {
|
||||||
srcAddr *net.UDPAddr
|
srcAddr *net.UDPAddr
|
||||||
dstConn *net.UDPConn
|
dstConn net.Conn
|
||||||
listener net.PacketConn
|
listener net.PacketConn
|
||||||
lastUsed atomic.Time
|
lastUsed atomic.Time
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
@@ -51,7 +53,7 @@ const (
|
|||||||
|
|
||||||
var bufPool = synk.GetSizedBytesPool()
|
var bufPool = synk.GetSizedBytesPool()
|
||||||
|
|
||||||
func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
func NewUDPUDPStream(listenAddr, dstAddr string, agentCfg *agent.AgentConfig) (nettypes.Stream, error) {
|
||||||
dst, err := net.ResolveUDPAddr("udp", dstAddr)
|
dst, err := net.ResolveUDPAddr("udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -63,6 +65,7 @@ func NewUDPUDPStream(listenAddr, dstAddr string) (nettypes.Stream, error) {
|
|||||||
return &UDPUDPStream{
|
return &UDPUDPStream{
|
||||||
laddr: laddr,
|
laddr: laddr,
|
||||||
dst: dst,
|
dst: dst,
|
||||||
|
agent: agentCfg,
|
||||||
conns: make(map[string]*udpUDPConn),
|
conns: make(map[string]*udpUDPConn),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -189,8 +192,16 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create UDP connection to destination
|
// Create connection to destination (direct UDP or via agent stream tunnel)
|
||||||
dstConn, err := net.DialUDP("udp", nil, s.dst)
|
var (
|
||||||
|
dstConn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if s.agent != nil {
|
||||||
|
dstConn, err = s.agent.NewUDPClient(s.dst.String())
|
||||||
|
} else {
|
||||||
|
dstConn, err = net.DialUDP("udp", nil, s.dst)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logErr(s, err, "failed to dial dst")
|
logErr(s, err, "failed to dial dst")
|
||||||
return nil, false
|
return nil, false
|
||||||
@@ -205,7 +216,7 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
|||||||
|
|
||||||
// Send initial data before starting response handler
|
// Send initial data before starting response handler
|
||||||
if !conn.forwardToDestination(initialData) {
|
if !conn.forwardToDestination(initialData) {
|
||||||
dstConn.Close()
|
_ = dstConn.Close()
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,6 +339,6 @@ func (conn *udpUDPConn) Close() {
|
|||||||
|
|
||||||
conn.closed.Store(true)
|
conn.closed.Store(true)
|
||||||
|
|
||||||
conn.dstConn.Close()
|
_ = conn.dstConn.Close()
|
||||||
conn.dstConn = nil
|
conn.dstConn = nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user