mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
feat(agent): agent stream tunneling with TLS and dTLS (UDP)
This commit is contained in:
@@ -1,21 +1,31 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
"github.com/yusing/godoxy/agent/pkg/env"
|
||||
"github.com/yusing/godoxy/agent/pkg/server"
|
||||
"github.com/yusing/godoxy/agent/pkg/handler"
|
||||
"github.com/yusing/godoxy/internal/metrics/systeminfo"
|
||||
socketproxy "github.com/yusing/godoxy/socketproxy/pkg"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httpServer "github.com/yusing/goutils/server"
|
||||
strutils "github.com/yusing/goutils/strings"
|
||||
"github.com/yusing/goutils/task"
|
||||
"github.com/yusing/goutils/version"
|
||||
)
|
||||
|
||||
// TODO: support IPv6
|
||||
|
||||
func main() {
|
||||
writer := zerolog.ConsoleWriter{
|
||||
Out: os.Stderr,
|
||||
@@ -52,16 +62,84 @@ func main() {
|
||||
Tips:
|
||||
1. To change the agent name, you can set the AGENT_NAME environment variable.
|
||||
2. To change the agent port, you can set the AGENT_PORT environment variable.
|
||||
`)
|
||||
`)
|
||||
|
||||
t := task.RootTask("agent", false)
|
||||
opts := server.Options{
|
||||
CACert: caCert,
|
||||
ServerCert: srvCert,
|
||||
Port: env.AgentPort,
|
||||
|
||||
// One TCP listener on AGENT_PORT, then multiplex by TLS ALPN:
|
||||
// - Stream ALPN: route to TCP stream tunnel handler (via http.Server.TLSNextProto)
|
||||
// - Otherwise: route to HTTPS API handler
|
||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort})
|
||||
if err != nil {
|
||||
gperr.LogFatal("failed to listen on port", err)
|
||||
}
|
||||
|
||||
server.StartAgentServer(t, opts)
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert.Leaf)
|
||||
|
||||
muxTLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*srvCert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
// Keep HTTP limited to HTTP/1.1 (matching current agent server behavior)
|
||||
// and add the stream tunnel ALPN for multiplexing.
|
||||
NextProtos: []string{"http/1.1", stream.StreamALPN},
|
||||
}
|
||||
if env.AgentSkipClientCertCheck {
|
||||
muxTLSConfig.ClientAuth = tls.NoClientCert
|
||||
}
|
||||
|
||||
// TLS listener feeds the HTTP server. ALPN stream connections are intercepted
|
||||
// using http.Server.TLSNextProto.
|
||||
tlsLn := tls.NewListener(tcpListener, muxTLSConfig)
|
||||
|
||||
streamSrv := stream.NewTCPServerHandler(t.Context())
|
||||
|
||||
httpSrv := &http.Server{
|
||||
Handler: handler.NewAgentHandler(),
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return t.Context()
|
||||
},
|
||||
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
|
||||
// When a client negotiates StreamALPN, net/http will call this hook instead
|
||||
// of treating the connection as HTTP.
|
||||
stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
|
||||
// ServeConn blocks until the tunnel finishes.
|
||||
streamSrv.ServeConn(conn)
|
||||
},
|
||||
},
|
||||
}
|
||||
{
|
||||
subtask := t.Subtask("agent-http", true)
|
||||
t.OnCancel("stop_http", func() {
|
||||
_ = streamSrv.Close()
|
||||
_ = httpSrv.Close()
|
||||
_ = tlsLn.Close()
|
||||
})
|
||||
go func() {
|
||||
err := httpSrv.Serve(tlsLn)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Error().Err(err).Msg("agent HTTP server stopped with error")
|
||||
}
|
||||
subtask.Finish(err)
|
||||
}()
|
||||
log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started (ALPN mux enabled)")
|
||||
}
|
||||
log.Info().Int("port", env.AgentPort).Msg("TCP stream handler started (via TLSNextProto)")
|
||||
|
||||
{
|
||||
udpServer := stream.NewUDPServer(t.Context(), "udp", &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert)
|
||||
subtask := t.Subtask("agent-stream-udp", true)
|
||||
t.OnCancel("stop_stream_udp", func() {
|
||||
_ = udpServer.Close()
|
||||
})
|
||||
go func() {
|
||||
err := udpServer.Start()
|
||||
subtask.Finish(err)
|
||||
}()
|
||||
log.Info().Int("port", env.AgentPort).Msg("UDP stream server started")
|
||||
}
|
||||
|
||||
if socketproxy.ListenAddr != "" {
|
||||
runtime := strutils.Title(string(env.Runtime))
|
||||
|
||||
@@ -18,6 +18,8 @@ require (
|
||||
github.com/bytedance/sonic v1.14.2
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/pion/dtls/v3 v3.0.9
|
||||
github.com/pion/transport/v3 v3.1.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/yusing/godoxy v0.0.0-00010101000000-000000000000
|
||||
@@ -72,6 +74,7 @@ require (
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pion/logging v0.2.4 // indirect
|
||||
github.com/pires/go-proxyproto v0.8.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
|
||||
@@ -146,6 +146,12 @@ github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5
|
||||
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
3
agent/pkg/agent/common/common.go
Normal file
3
agent/pkg/agent/common/common.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package common
|
||||
|
||||
const CertsDNSName = "godoxy.agent"
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -16,31 +18,51 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/common"
|
||||
agentstream "github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
"github.com/yusing/godoxy/agent/pkg/certs"
|
||||
gperr "github.com/yusing/goutils/errs"
|
||||
httputils "github.com/yusing/goutils/http"
|
||||
"github.com/yusing/goutils/version"
|
||||
)
|
||||
|
||||
type AgentConfig struct {
|
||||
Addr string `json:"addr"`
|
||||
Name string `json:"name"`
|
||||
Version version.Version `json:"version" swaggertype:"string"`
|
||||
Runtime ContainerRuntime `json:"runtime"`
|
||||
AgentInfo
|
||||
|
||||
Addr string `json:"addr"`
|
||||
IsTCPStreamSupported bool `json:"supports_tcp_stream"`
|
||||
IsUDPStreamSupported bool `json:"supports_udp_stream"`
|
||||
|
||||
// for stream
|
||||
caCert *x509.Certificate
|
||||
clientCert *tls.Certificate
|
||||
|
||||
tlsConfig tls.Config
|
||||
l zerolog.Logger
|
||||
|
||||
l zerolog.Logger
|
||||
} // @name Agent
|
||||
|
||||
type AgentInfo struct {
|
||||
Version version.Version `json:"version" swaggertype:"string"`
|
||||
Name string `json:"name"`
|
||||
Runtime ContainerRuntime `json:"runtime"`
|
||||
}
|
||||
|
||||
// Deprecated. Replaced by EndpointInfo
|
||||
const (
|
||||
EndpointVersion = "/version"
|
||||
EndpointName = "/name"
|
||||
EndpointRuntime = "/runtime"
|
||||
EndpointVersion = "/version"
|
||||
EndpointName = "/name"
|
||||
EndpointRuntime = "/runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
EndpointInfo = "/info"
|
||||
EndpointProxyHTTP = "/proxy/http"
|
||||
EndpointHealth = "/health"
|
||||
EndpointLogs = "/logs"
|
||||
EndpointSystemInfo = "/system_info"
|
||||
|
||||
AgentHost = CertsDNSName
|
||||
AgentHost = common.CertsDNSName
|
||||
|
||||
APIEndpointBase = "/godoxy/agent"
|
||||
APIBaseURL = "https://" + AgentHost + APIEndpointBase
|
||||
@@ -90,6 +112,7 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.clientCert = &clientCert
|
||||
|
||||
// create tls config
|
||||
caCertPool := x509.NewCertPool()
|
||||
@@ -97,58 +120,105 @@ func (cfg *AgentConfig) InitWithCerts(ctx context.Context, ca, crt, key []byte)
|
||||
if !ok {
|
||||
return errors.New("invalid ca certificate")
|
||||
}
|
||||
// Keep the CA leaf for stream client dialing.
|
||||
if block, _ := pem.Decode(ca); block == nil || block.Type != "CERTIFICATE" {
|
||||
return errors.New("invalid ca certificate")
|
||||
} else if cert, err := x509.ParseCertificate(block.Bytes); err != nil {
|
||||
return err
|
||||
} else {
|
||||
cfg.caCert = cert
|
||||
}
|
||||
|
||||
cfg.tlsConfig = tls.Config{
|
||||
Certificates: []tls.Certificate{clientCert},
|
||||
RootCAs: caCertPool,
|
||||
ServerName: CertsDNSName,
|
||||
ServerName: common.CertsDNSName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// get agent name
|
||||
name, _, err := cfg.fetchString(ctx, EndpointName)
|
||||
status, err := cfg.fetchJSON(ctx, EndpointInfo, &cfg.AgentInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Name = name
|
||||
var streamUnsupportedErrs gperr.Builder
|
||||
|
||||
if status == http.StatusOK {
|
||||
// test stream server connection
|
||||
const fakeAddress = "localhost:8080" // it won't be used, just for testing
|
||||
// test TCP stream support
|
||||
err := agentstream.TCPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert)
|
||||
if err != nil {
|
||||
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
|
||||
} else {
|
||||
cfg.IsTCPStreamSupported = true
|
||||
}
|
||||
|
||||
// test UDP stream support
|
||||
err = agentstream.UDPHealthCheck(cfg.Addr, cfg.caCert, cfg.clientCert)
|
||||
if err != nil {
|
||||
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
|
||||
} else {
|
||||
cfg.IsUDPStreamSupported = true
|
||||
}
|
||||
} else {
|
||||
// old agent does not support EndpointInfo
|
||||
// fallback with old logic
|
||||
cfg.IsTCPStreamSupported = false
|
||||
cfg.IsUDPStreamSupported = false
|
||||
streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling")
|
||||
|
||||
// get agent name
|
||||
name, _, err := cfg.fetchString(ctx, EndpointName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Name = name
|
||||
|
||||
// check agent version
|
||||
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Version = version.Parse(agentVersion)
|
||||
|
||||
// check agent runtime
|
||||
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch status {
|
||||
case http.StatusOK:
|
||||
switch runtime {
|
||||
case "docker":
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
// case "nerdctl":
|
||||
// cfg.Runtime = ContainerRuntimeNerdctl
|
||||
case "podman":
|
||||
cfg.Runtime = ContainerRuntimePodman
|
||||
default:
|
||||
return fmt.Errorf("invalid agent runtime: %s", runtime)
|
||||
}
|
||||
case http.StatusNotFound:
|
||||
// backward compatibility, old agent does not have runtime endpoint
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
default:
|
||||
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
|
||||
}
|
||||
}
|
||||
|
||||
cfg.l = log.With().Str("agent", cfg.Name).Logger()
|
||||
|
||||
// check agent version
|
||||
agentVersion, _, err := cfg.fetchString(ctx, EndpointVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := streamUnsupportedErrs.Error(); err != nil {
|
||||
gperr.LogWarn("agent has limited/no stream tunneling support, TCP and UDP routes via agent will not work", err, &cfg.l)
|
||||
}
|
||||
|
||||
// check agent runtime
|
||||
runtime, status, err := cfg.fetchString(ctx, EndpointRuntime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch status {
|
||||
case http.StatusOK:
|
||||
switch runtime {
|
||||
case "docker":
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
// case "nerdctl":
|
||||
// cfg.Runtime = ContainerRuntimeNerdctl
|
||||
case "podman":
|
||||
cfg.Runtime = ContainerRuntimePodman
|
||||
default:
|
||||
return fmt.Errorf("invalid agent runtime: %s", runtime)
|
||||
}
|
||||
case http.StatusNotFound:
|
||||
// backward compatibility, old agent does not have runtime endpoint
|
||||
cfg.Runtime = ContainerRuntimeDocker
|
||||
default:
|
||||
return fmt.Errorf("failed to get agent runtime: HTTP %d %s", status, runtime)
|
||||
}
|
||||
|
||||
cfg.Version = version.Parse(agentVersion)
|
||||
|
||||
if serverVersion.IsNewerThanMajor(cfg.Version) {
|
||||
log.Warn().Msgf("agent %s major version mismatch: server: %s, agent: %s", cfg.Name, serverVersion, cfg.Version)
|
||||
}
|
||||
@@ -177,6 +247,38 @@ func (cfg *AgentConfig) Init(ctx context.Context) error {
|
||||
return cfg.InitWithCerts(ctx, ca, crt, key)
|
||||
}
|
||||
|
||||
// NewTCPClient creates a new TCP client for the agent.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the agent is not initialized
|
||||
// - the agent does not support TCP stream tunneling
|
||||
// - the agent stream server address is not initialized
|
||||
func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) {
|
||||
if cfg.caCert == nil || cfg.clientCert == nil {
|
||||
return nil, errors.New("agent is not initialized")
|
||||
}
|
||||
if !cfg.IsTCPStreamSupported {
|
||||
return nil, errors.New("agent does not support TCP stream tunneling")
|
||||
}
|
||||
return agentstream.NewTCPClient(cfg.Addr, targetAddress, cfg.caCert, cfg.clientCert)
|
||||
}
|
||||
|
||||
// NewUDPClient creates a new UDP client for the agent.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the agent is not initialized
|
||||
// - the agent does not support UDP stream tunneling
|
||||
// - the agent stream server address is not initialized
|
||||
func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) {
|
||||
if cfg.caCert == nil || cfg.clientCert == nil {
|
||||
return nil, errors.New("agent is not initialized")
|
||||
}
|
||||
if !cfg.IsUDPStreamSupported {
|
||||
return nil, errors.New("agent does not support UDP stream tunneling")
|
||||
}
|
||||
return agentstream.NewUDPClient(cfg.Addr, targetAddress, cfg.caCert, cfg.clientCert)
|
||||
}
|
||||
|
||||
func (cfg *AgentConfig) Transport() *http.Transport {
|
||||
return &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
@@ -232,3 +334,31 @@ func (cfg *AgentConfig) fetchString(ctx context.Context, endpoint string) (strin
|
||||
release(data)
|
||||
return ret, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// fetchJSON fetches a JSON response from the agent and unmarshals it into the provided struct
|
||||
//
|
||||
// It will return the status code of the response, and error if any.
|
||||
// If the status code is not http.StatusOK, out will be unchanged but error will still be nil.
|
||||
func (cfg *AgentConfig) fetchJSON(ctx context.Context, endpoint string, out any) (int, error) {
|
||||
resp, err := cfg.do(ctx, "GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, release, err := httputils.ReadAllBody(resp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer release(data)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, out)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
@@ -17,10 +17,8 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CertsDNSName = "godoxy.agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/common"
|
||||
)
|
||||
|
||||
func toPEMPair(certDER []byte, key *ecdsa.PrivateKey) *PEMPair {
|
||||
@@ -156,7 +154,7 @@ func NewAgent() (ca, srv, client *PEMPair, err error) {
|
||||
SerialNumber: caSerialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"GoDoxy"},
|
||||
CommonName: CertsDNSName,
|
||||
CommonName: common.CertsDNSName,
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years
|
||||
@@ -196,9 +194,9 @@ func NewAgent() (ca, srv, client *PEMPair, err error) {
|
||||
Subject: pkix.Name{
|
||||
Organization: caTemplate.Subject.Organization,
|
||||
OrganizationalUnit: []string{"Server"},
|
||||
CommonName: CertsDNSName,
|
||||
CommonName: common.CertsDNSName,
|
||||
},
|
||||
DNSNames: []string{CertsDNSName},
|
||||
DNSNames: []string{common.CertsDNSName},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1000, 0, 0), // Add validity period
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
@@ -228,9 +226,9 @@ func NewAgent() (ca, srv, client *PEMPair, err error) {
|
||||
Subject: pkix.Name{
|
||||
Organization: caTemplate.Subject.Organization,
|
||||
OrganizationalUnit: []string{"Client"},
|
||||
CommonName: CertsDNSName,
|
||||
CommonName: common.CertsDNSName,
|
||||
},
|
||||
DNSNames: []string{CertsDNSName},
|
||||
DNSNames: []string{common.CertsDNSName},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1000, 0, 0),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/common"
|
||||
)
|
||||
|
||||
func TestNewAgent(t *testing.T) {
|
||||
@@ -72,7 +73,7 @@ func TestServerClient(t *testing.T) {
|
||||
clientTLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*clientTLS},
|
||||
RootCAs: caPool,
|
||||
ServerName: CertsDNSName,
|
||||
ServerName: common.CertsDNSName,
|
||||
}
|
||||
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
197
agent/pkg/agent/stream/README.md
Normal file
197
agent/pkg/agent/stream/README.md
Normal file
@@ -0,0 +1,197 @@
|
||||
# Stream proxy protocol
|
||||
|
||||
This package implements a small header-based handshake that allows an authenticated client to request forwarding to a `(host, port)` destination. It supports both TCP-over-TLS and UDP-over-DTLS transports.
|
||||
|
||||
## Overview
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph Client
|
||||
TC[TCPClient] -->|TLS| TSS[TCPServer]
|
||||
UC[UDPClient] -->|DTLS| USS[UDPServer]
|
||||
end
|
||||
|
||||
subgraph Stream Protocol
|
||||
H[StreamRequestHeader]
|
||||
end
|
||||
|
||||
TSS -->|Redirect| DST1[Destination TCP]
|
||||
USS -->|Forward UDP| DST2[Destination UDP]
|
||||
```
|
||||
|
||||
## Header
|
||||
|
||||
The on-wire header is a fixed-size binary blob:
|
||||
|
||||
- `Version` (8 bytes)
|
||||
- `HostLength` (1 byte)
|
||||
- `Host` (255 bytes, NUL padded)
|
||||
- `PortLength` (1 byte)
|
||||
- `Port` (5 bytes, NUL padded)
|
||||
- `Flag` (1 byte, protocol flags)
|
||||
- `Checksum` (4 bytes, big-endian CRC32)
|
||||
|
||||
Total: `headerSize = 8 + 1 + 255 + 1 + 5 + 1 + 4 = 275` bytes.
|
||||
|
||||
Checksum is `crc32.ChecksumIEEE(header[0:headerSize-4])`.
|
||||
|
||||
### Flags
|
||||
|
||||
The `Flag` field is a bitmask of protocol flags defined by `FlagType`:
|
||||
|
||||
| Flag | Value | Purpose |
|
||||
| ---------------------- | ----- | ---------------------------------------------------------------------- |
|
||||
| `FlagCloseImmediately` | `1` | Health check probe - server closes immediately after validating header |
|
||||
|
||||
See [`FlagType`](header.go:26) and [`FlagCloseImmediately`](header.go:28).
|
||||
|
||||
See [`StreamRequestHeader`](header.go:30).
|
||||
|
||||
## File Structure
|
||||
|
||||
| File | Purpose |
|
||||
| ----------------------------------- | ------------------------------------------------------------ |
|
||||
| [`header.go`](header.go) | Stream request header structure and validation. |
|
||||
| [`tcp_client.go`](tcp_client.go:12) | TCP client implementation with TLS transport. |
|
||||
| [`tcp_server.go`](tcp_server.go:13) | TCP server implementation for handling stream requests. |
|
||||
| [`udp_client.go`](udp_client.go:13) | UDP client implementation with DTLS transport. |
|
||||
| [`udp_server.go`](udp_server.go:17) | UDP server implementation for handling DTLS stream requests. |
|
||||
| [`common.go`](common.go:11) | Connection manager and shared constants. |
|
||||
|
||||
## Constants
|
||||
|
||||
| Constant | Value | Purpose |
|
||||
| ---------------------- | ------------------------- | ------------------------------------------------------- |
|
||||
| `StreamALPN` | `"godoxy-agent-stream/1"` | TLS ALPN protocol for stream multiplexing. |
|
||||
| `headerSize` | `275` bytes | Total size of the stream request header. |
|
||||
| `dialTimeout` | `10s` | Timeout for establishing destination connections. |
|
||||
| `readDeadline` | `10s` | Read timeout for UDP destination sockets. |
|
||||
| `FlagCloseImmediately` | `1` | Flag for health check probe - server closes immediately |
|
||||
|
||||
See [`common.go`](common.go:11).
|
||||
|
||||
## Public API
|
||||
|
||||
### Types
|
||||
|
||||
#### `StreamRequestHeader`
|
||||
|
||||
Represents the on-wire protocol header used to negotiate a stream tunnel.
|
||||
|
||||
```go
|
||||
type StreamRequestHeader struct {
|
||||
Version [8]byte // Fixed to "0.1.0" with NUL padding
|
||||
HostLength byte // Actual host name length (0-255)
|
||||
Host [255]byte // NUL-padded host name
|
||||
PortLength byte // Actual port string length (0-5)
|
||||
Port [5]byte // NUL-padded port string
|
||||
Flag FlagType // Protocol flags (e.g., FlagCloseImmediately)
|
||||
Checksum [4]byte // CRC32 checksum of header without checksum
|
||||
}
|
||||
```
|
||||
|
||||
**Methods:**
|
||||
|
||||
- `NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error)` - Creates a header for the given host and port. Returns error if host exceeds 255 bytes or port exceeds 5 bytes.
|
||||
- `NewStreamHealthCheckHeader() *StreamRequestHeader` - Creates a header with `FlagCloseImmediately` set for health check probes.
|
||||
- `Validate() bool` - Validates the version and checksum.
|
||||
- `GetHostPort() (string, string)` - Extracts the host and port from the header.
|
||||
- `ShouldCloseImmediately() bool` - Returns true if `FlagCloseImmediately` is set.
|
||||
|
||||
### TCP Functions
|
||||
|
||||
- [`NewTCPClient()`](tcp_client.go:26) - Creates a TLS client connection and sends the stream header.
|
||||
- [`NewTCPServerHandler()`](tcp_server.go:24) - Creates a handler for ALPN-multiplexed connections (no listener).
|
||||
- [`NewTCPServerFromListener()`](tcp_server.go:36) - Wraps an existing TLS listener.
|
||||
- [`NewTCPServer()`](tcp_server.go:45) - Creates a fully-configured TCP server with TLS listener.
|
||||
|
||||
### UDP Functions
|
||||
|
||||
- [`NewUDPClient()`](udp_client.go:27) - Creates a DTLS client connection and sends the stream header.
|
||||
- [`NewUDPServer()`](udp_server.go:26) - Creates a DTLS server listening on the given UDP address.
|
||||
|
||||
## Health Check Probes
|
||||
|
||||
The protocol supports health check probes using the `FlagCloseImmediately` flag. When a client sends a header with this flag set, the server validates the header and immediately closes the connection without establishing a destination tunnel.
|
||||
|
||||
This is useful for:
|
||||
|
||||
- Connectivity testing between agent and server
|
||||
- Verifying TLS/DTLS handshake and mTLS authentication
|
||||
- Monitoring stream protocol availability
|
||||
|
||||
**Usage:**
|
||||
|
||||
```go
|
||||
header := stream.NewStreamHealthCheckHeader()
|
||||
// Send header over TLS/DTLS connection
|
||||
// Server will validate and close immediately
|
||||
```
|
||||
|
||||
Both TCP and UDP servers silently handle health check probes without logging errors.
|
||||
|
||||
See [`NewStreamHealthCheckHeader()`](header.go:66) and [`FlagCloseImmediately`](header.go:28).
|
||||
|
||||
## TCP behavior
|
||||
|
||||
1. Client establishes a TLS connection to the stream server.
|
||||
2. Client sends exactly one header as a handshake.
|
||||
3. After the handshake, both sides proxy raw TCP bytes between client and destination.
|
||||
|
||||
Server reads the header using `io.ReadFull` to avoid dropping bytes.
|
||||
|
||||
See [`NewTCPClient()`](tcp_client.go:26) and [`(*TCPServer).redirect()`](tcp_server.go:116).
|
||||
|
||||
## UDP-over-DTLS behavior
|
||||
|
||||
1. Client establishes a DTLS connection to the stream server.
|
||||
2. Client sends exactly one header as a handshake.
|
||||
3. After the handshake, both sides proxy raw UDP datagrams:
|
||||
- client -> destination: DTLS payload is written to destination `UDPConn`
|
||||
- destination -> client: destination payload is written back to the DTLS connection
|
||||
|
||||
Responses do **not** include a header.
|
||||
|
||||
The UDP server uses a bidirectional forwarding model:
|
||||
|
||||
- One goroutine forwards from client to destination
|
||||
- Another goroutine forwards from destination to client
|
||||
|
||||
The destination reader uses `readDeadline` to periodically wake up and check for context cancellation. Timeouts do not terminate the session.
|
||||
|
||||
See [`NewUDPClient()`](udp_client.go:27) and [`(*UDPServer).handleDTLSConnection()`](udp_server.go:89).
|
||||
|
||||
## Connection Management
|
||||
|
||||
Both `TCPServer` and `UDPServer` create a dedicated destination connection per incoming stream session and close it when the session ends (no destination connection reuse).
|
||||
|
||||
## Error Handling
|
||||
|
||||
| Error | Description |
|
||||
| --------------------- | ----------------------------------------------- |
|
||||
| `ErrInvalidHeader` | Header validation failed (version or checksum). |
|
||||
| `ErrCloseImmediately` | Health check probe - server closed immediately. |
|
||||
|
||||
Errors from connection creation are propagated to the caller.
|
||||
|
||||
See [`header.go`](header.go:23).
|
||||
|
||||
## Integration
|
||||
|
||||
This package is used by the agent to provide stream tunneling capabilities. See the parent [`agent`](../README.md) package for integration details with the GoDoxy server.
|
||||
|
||||
### Certificate Requirements
|
||||
|
||||
Both TCP and UDP servers require:
|
||||
|
||||
- CA certificate for client verification
|
||||
- Server certificate for TLS/DTLS termination
|
||||
|
||||
Both clients require:
|
||||
|
||||
- CA certificate for server verification
|
||||
- Client certificate for mTLS authentication
|
||||
|
||||
### ALPN Protocol
|
||||
|
||||
The `StreamALPN` constant (`"godoxy-agent-stream/1"`) is used to multiplex stream tunnel traffic and HTTPS API traffic on the same port. Connections negotiating this ALPN are routed to the stream handler.
|
||||
24
agent/pkg/agent/stream/common.go
Normal file
24
agent/pkg/agent/stream/common.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
"github.com/yusing/goutils/synk"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 10 * time.Second
|
||||
readDeadline = 10 * time.Second
|
||||
)
|
||||
|
||||
// StreamALPN is the TLS ALPN protocol id used to multiplex the TCP stream tunnel
|
||||
// and the HTTPS API on the same TCP port.
|
||||
//
|
||||
// When a client negotiates this ALPN, the agent will route the connection to the
|
||||
// stream tunnel handler instead of the HTTP handler.
|
||||
const StreamALPN = "godoxy-agent-stream/1"
|
||||
|
||||
var dTLSCipherSuites = []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}
|
||||
|
||||
var sizedPool = synk.GetSizedBytesPool()
|
||||
116
agent/pkg/agent/stream/header.go
Normal file
116
agent/pkg/agent/stream/header.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
versionSize = 8
|
||||
hostSize = 255
|
||||
portSize = 5
|
||||
flagSize = 1
|
||||
checksumSize = 4 // crc32 checksum
|
||||
|
||||
headerSize = versionSize + 1 + hostSize + 1 + portSize + flagSize + checksumSize
|
||||
)
|
||||
|
||||
var version = [versionSize]byte{'0', '.', '1', '.', '0', 0, 0, 0}
|
||||
|
||||
var ErrInvalidHeader = errors.New("invalid header")
|
||||
var ErrCloseImmediately = errors.New("close immediately")
|
||||
|
||||
type FlagType uint8
|
||||
|
||||
const FlagCloseImmediately FlagType = 1 << iota
|
||||
|
||||
type StreamRequestHeader struct {
|
||||
Version [versionSize]byte
|
||||
|
||||
HostLength byte
|
||||
Host [hostSize]byte
|
||||
|
||||
PortLength byte
|
||||
Port [portSize]byte
|
||||
|
||||
Flag FlagType
|
||||
Checksum [checksumSize]byte
|
||||
}
|
||||
|
||||
func init() {
|
||||
if headerSize != reflect.TypeFor[StreamRequestHeader]().Size() {
|
||||
panic("headerSize does not match the size of StreamRequestHeader")
|
||||
}
|
||||
}
|
||||
|
||||
func NewStreamRequestHeader(host, port string) (*StreamRequestHeader, error) {
|
||||
if len(host) > hostSize {
|
||||
return nil, fmt.Errorf("host is too long: max %d characters, got %d", hostSize, len(host))
|
||||
}
|
||||
if len(port) > portSize {
|
||||
return nil, fmt.Errorf("port is too long: max %d characters, got %d", portSize, len(port))
|
||||
}
|
||||
header := &StreamRequestHeader{}
|
||||
copy(header.Version[:], version[:])
|
||||
header.HostLength = byte(len(host))
|
||||
copy(header.Host[:], host)
|
||||
header.PortLength = byte(len(port))
|
||||
copy(header.Port[:], port)
|
||||
header.updateChecksum()
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func NewStreamHealthCheckHeader() *StreamRequestHeader {
|
||||
header := &StreamRequestHeader{}
|
||||
copy(header.Version[:], version[:])
|
||||
header.Flag |= FlagCloseImmediately
|
||||
header.updateChecksum()
|
||||
return header
|
||||
}
|
||||
|
||||
func ToHeader(buf [headerSize]byte) *StreamRequestHeader {
|
||||
return (*StreamRequestHeader)(unsafe.Pointer(&buf[0]))
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) GetHostPort() (string, string) {
|
||||
return string(h.Host[:h.HostLength]), string(h.Port[:h.PortLength])
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) Validate() bool {
|
||||
if h.Version != version {
|
||||
return false
|
||||
}
|
||||
if h.HostLength > hostSize {
|
||||
return false
|
||||
}
|
||||
if h.PortLength > portSize {
|
||||
return false
|
||||
}
|
||||
return h.validateChecksum()
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) ShouldCloseImmediately() bool {
|
||||
return h.Flag&FlagCloseImmediately != 0
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) updateChecksum() {
|
||||
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
|
||||
binary.BigEndian.PutUint32(h.Checksum[:], checksum)
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) validateChecksum() bool {
|
||||
checksum := crc32.ChecksumIEEE(h.BytesWithoutChecksum())
|
||||
return checksum == binary.BigEndian.Uint32(h.Checksum[:])
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) BytesWithoutChecksum() []byte {
|
||||
return (*[headerSize - checksumSize]byte)(unsafe.Pointer(h))[:]
|
||||
}
|
||||
|
||||
func (h *StreamRequestHeader) Bytes() []byte {
|
||||
return (*[headerSize]byte)(unsafe.Pointer(h))[:]
|
||||
}
|
||||
26
agent/pkg/agent/stream/payload_test.go
Normal file
26
agent/pkg/agent/stream/payload_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStreamRequestHeader_RoundTripAndChecksum(t *testing.T) {
|
||||
h, err := NewStreamRequestHeader("example.com", "443")
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamRequestHeader: %v", err)
|
||||
}
|
||||
if !h.Validate() {
|
||||
t.Fatalf("expected header to validate")
|
||||
}
|
||||
|
||||
var buf [headerSize]byte
|
||||
copy(buf[:], h.Bytes())
|
||||
h2 := ToHeader(buf)
|
||||
if !h2.Validate() {
|
||||
t.Fatalf("expected round-tripped header to validate")
|
||||
}
|
||||
host, port := h2.GetHostPort()
|
||||
if host != "example.com" || port != "443" {
|
||||
t.Fatalf("unexpected host/port: %q:%q", host, port)
|
||||
}
|
||||
}
|
||||
122
agent/pkg/agent/stream/tcp_client.go
Normal file
122
agent/pkg/agent/stream/tcp_client.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/common"
|
||||
)
|
||||
|
||||
type TCPClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// NewTCPClient creates a new TCP client for the agent.
|
||||
//
|
||||
// It will establish a TLS connection and send a stream request header to the server.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the target address is invalid
|
||||
// - the stream request header is invalid
|
||||
// - the TLS configuration is invalid
|
||||
// - the TLS connection fails
|
||||
// - the stream request header is not sent
|
||||
func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(targetAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := NewStreamRequestHeader(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newTCPClientWIthHeader(serverAddr, header, caCert, clientCert)
|
||||
}
|
||||
|
||||
func TCPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error {
|
||||
header := NewStreamHealthCheckHeader()
|
||||
|
||||
conn, err := newTCPClientWIthHeader(serverAddr, header, caCert, clientCert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTCPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||
// Setup TLS configuration
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*clientCert},
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
NextProtos: []string{StreamALPN},
|
||||
ServerName: common.CertsDNSName,
|
||||
}
|
||||
|
||||
// Establish TLS connection
|
||||
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: dialTimeout}, "tcp", serverAddr, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send the stream header once as a handshake.
|
||||
if _, err := conn.Write(header.Bytes()); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TCPClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *TCPClient) Read(p []byte) (n int, err error) {
|
||||
return c.conn.Read(p)
|
||||
}
|
||||
|
||||
func (c *TCPClient) Write(p []byte) (n int, err error) {
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *TCPClient) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *TCPClient) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *TCPClient) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *TCPClient) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *TCPClient) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *TCPClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// ConnectionState exposes the underlying TLS connection state when the client is
|
||||
// backed by *tls.Conn.
|
||||
//
|
||||
// This is primarily used by tests and diagnostics.
|
||||
func (c *TCPClient) ConnectionState() tls.ConnectionState {
|
||||
if tc, ok := c.conn.(*tls.Conn); ok {
|
||||
return tc.ConnectionState()
|
||||
}
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
176
agent/pkg/agent/stream/tcp_server.go
Normal file
176
agent/pkg/agent/stream/tcp_server.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
)
|
||||
|
||||
type TCPServer struct {
|
||||
ctx context.Context
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// NewTCPServerHandler creates a TCP stream server that can serve already-accepted
|
||||
// connections (e.g. handed off by an ALPN multiplexer).
|
||||
//
|
||||
// This variant does not require a listener. Use TCPServer.ServeConn to handle
|
||||
// each incoming stream connection.
|
||||
func NewTCPServerHandler(ctx context.Context) *TCPServer {
|
||||
s := &TCPServer{ctx: ctx}
|
||||
return s
|
||||
}
|
||||
|
||||
// NewTCPServerFromListener creates a TCP stream server from an already-prepared
|
||||
// listener.
|
||||
//
|
||||
// The listener is expected to yield connections that are already secured (e.g.
|
||||
// a TLS/mTLS listener, or pre-handshaked *tls.Conn). This is used when the agent
|
||||
// multiplexes HTTPS and stream-tunnel traffic on the same port.
|
||||
func NewTCPServerFromListener(ctx context.Context, listener net.Listener) *TCPServer {
|
||||
s := &TCPServer{
|
||||
ctx: ctx,
|
||||
listener: listener,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{*serverCert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
NextProtos: []string{StreamALPN},
|
||||
}
|
||||
|
||||
tcpListener := tls.NewListener(listener, tlsConfig)
|
||||
return NewTCPServerFromListener(ctx, tcpListener)
|
||||
}
|
||||
|
||||
func (s *TCPServer) Start() error {
|
||||
if s.listener == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
context.AfterFunc(s.ctx, func() {
|
||||
_ = s.listener.Close()
|
||||
})
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) && s.ctx.Err() != nil {
|
||||
return s.ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
go s.handle(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeConn serves a single stream connection.
|
||||
//
|
||||
// The provided connection is expected to be already secured (TLS/mTLS) and to
|
||||
// speak the stream protocol (i.e. the client will send the stream header first).
|
||||
//
|
||||
// This method blocks until the stream finishes.
|
||||
func (s *TCPServer) ServeConn(conn net.Conn) {
|
||||
s.handle(conn)
|
||||
}
|
||||
|
||||
func (s *TCPServer) Addr() net.Addr {
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
func (s *TCPServer) Close() error {
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *TCPServer) logger(clientConn net.Conn) *zerolog.Logger {
|
||||
ev := log.With().Str("protocol", "tcp").
|
||||
Str("remote", clientConn.RemoteAddr().String())
|
||||
if s.listener != nil {
|
||||
ev = ev.Str("addr", s.listener.Addr().String())
|
||||
}
|
||||
l := ev.Logger()
|
||||
return &l
|
||||
}
|
||||
|
||||
func (s *TCPServer) loggerWithDst(dstConn net.Conn, clientConn net.Conn) *zerolog.Logger {
|
||||
ev := log.With().Str("protocol", "tcp").
|
||||
Str("remote", clientConn.RemoteAddr().String()).
|
||||
Str("dst", dstConn.RemoteAddr().String())
|
||||
if s.listener != nil {
|
||||
ev = ev.Str("addr", s.listener.Addr().String())
|
||||
}
|
||||
l := ev.Logger()
|
||||
return &l
|
||||
}
|
||||
|
||||
func (s *TCPServer) handle(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
dst, err := s.redirect(conn)
|
||||
if err != nil {
|
||||
// Health check probe: close connection
|
||||
if errors.Is(err, ErrCloseImmediately) {
|
||||
s.logger(conn).Info().Msg("Health check received")
|
||||
return
|
||||
}
|
||||
s.logger(conn).Err(err).Msg("failed to redirect connection")
|
||||
return
|
||||
}
|
||||
|
||||
defer dst.Close()
|
||||
pipe := ioutils.NewBidirectionalPipe(s.ctx, conn, dst)
|
||||
err = pipe.Start()
|
||||
if err != nil {
|
||||
s.loggerWithDst(dst, conn).Err(err).Msg("failed to start bidirectional pipe")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TCPServer) redirect(conn net.Conn) (net.Conn, error) {
|
||||
// Read the stream header once as a handshake.
|
||||
var headerBuf [headerSize]byte
|
||||
if _, err := io.ReadFull(conn, headerBuf[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header := ToHeader(headerBuf)
|
||||
if !header.Validate() {
|
||||
return nil, ErrInvalidHeader
|
||||
}
|
||||
|
||||
// Health check: close immediately if FlagCloseImmediately is set
|
||||
if header.ShouldCloseImmediately() {
|
||||
return nil, ErrCloseImmediately
|
||||
}
|
||||
|
||||
// get destination connection
|
||||
host, port := header.GetHostPort()
|
||||
return s.createDestConnection(host, port)
|
||||
}
|
||||
|
||||
func (s *TCPServer) createDestConnection(host, port string) (net.Conn, error) {
|
||||
addr := net.JoinHostPort(host, port)
|
||||
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
26
agent/pkg/agent/stream/tests/healthcheck_test.go
Normal file
26
agent/pkg/agent/stream/tests/healthcheck_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package stream_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
)
|
||||
|
||||
func TestTCPHealthCheck(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
srv := startTCPServer(t, certs)
|
||||
|
||||
err := stream.TCPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert)
|
||||
require.NoError(t, err, "health check")
|
||||
}
|
||||
|
||||
func TestUDPHealthCheck(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
srv := startUDPServer(t, certs)
|
||||
|
||||
err := stream.UDPHealthCheck(srv.Addr.String(), certs.CaCert, certs.ClientCert)
|
||||
require.NoError(t, err, "health check")
|
||||
}
|
||||
94
agent/pkg/agent/stream/tests/mux_test.go
Normal file
94
agent/pkg/agent/stream/tests/mux_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package stream_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/common"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
)
|
||||
|
||||
func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
baseLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||
require.NoError(t, err, "listen tcp")
|
||||
defer baseLn.Close()
|
||||
baseAddr := baseLn.Addr().String()
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(certs.CaCert)
|
||||
|
||||
serverTLS := &tls.Config{
|
||||
Certificates: []tls.Certificate{*certs.SrvCert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
NextProtos: []string{"http/1.1", stream.StreamALPN},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
streamSrv := stream.NewTCPServerHandler(ctx)
|
||||
defer func() { _ = streamSrv.Close() }()
|
||||
|
||||
tlsLn := tls.NewListener(baseLn, serverTLS)
|
||||
defer func() { _ = tlsLn.Close() }()
|
||||
|
||||
// HTTP server
|
||||
httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}),
|
||||
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
|
||||
stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
|
||||
streamSrv.ServeConn(conn)
|
||||
},
|
||||
},
|
||||
}
|
||||
go func() { _ = httpSrv.Serve(tlsLn) }()
|
||||
defer func() { _ = httpSrv.Close() }()
|
||||
|
||||
// Stream destination
|
||||
dstAddr, closeDst := startTCPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
// HTTP client over the same port
|
||||
clientTLS := &tls.Config{
|
||||
Certificates: []tls.Certificate{*certs.ClientCert},
|
||||
RootCAs: caCertPool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
ServerName: common.CertsDNSName,
|
||||
}
|
||||
hc, err := tls.Dial("tcp", baseAddr, clientTLS)
|
||||
require.NoError(t, err, "dial https")
|
||||
defer hc.Close()
|
||||
_ = hc.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
_, err = hc.Write([]byte("GET / HTTP/1.1\r\nHost: godoxy-agent\r\n\r\n"))
|
||||
require.NoError(t, err, "write http request")
|
||||
r := bufio.NewReader(hc)
|
||||
statusLine, err := r.ReadString('\n')
|
||||
require.NoError(t, err, "read status line")
|
||||
require.Contains(t, statusLine, "200", "expected 200")
|
||||
|
||||
// Stream client over the same port
|
||||
client := NewTCPClient(t, baseAddr, dstAddr, certs)
|
||||
defer client.Close()
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over mux")
|
||||
_, err = client.Write(msg)
|
||||
require.NoError(t, err, "write stream payload")
|
||||
buf := make([]byte, len(msg))
|
||||
_, err = io.ReadFull(client, buf)
|
||||
require.NoError(t, err, "read stream payload")
|
||||
require.Equal(t, msg, buf)
|
||||
}
|
||||
201
agent/pkg/agent/stream/tests/server_flow_test.go
Normal file
201
agent/pkg/agent/stream/tests/server_flow_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package stream_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
)
|
||||
|
||||
func TestTCPServer_FullFlow(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
dstAddr, closeDst := startTCPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
srv := startTCPServer(t, certs)
|
||||
|
||||
client := NewTCPClient(t, srv.Addr.String(), dstAddr, certs)
|
||||
defer client.Close()
|
||||
|
||||
// Ensure ALPN is negotiated as expected (required for multiplexing).
|
||||
withState, ok := client.(interface{ ConnectionState() tls.ConnectionState })
|
||||
require.True(t, ok, "tcp client should expose TLS connection state")
|
||||
require.Equal(t, stream.StreamALPN, withState.ConnectionState().NegotiatedProtocol)
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over tcp")
|
||||
_, err := client.Write(msg)
|
||||
require.NoError(t, err, "write to client")
|
||||
|
||||
buf := make([]byte, len(msg))
|
||||
_, err = io.ReadFull(client, buf)
|
||||
require.NoError(t, err, "read from client")
|
||||
require.Equal(t, string(msg), string(buf), "unexpected echo")
|
||||
}
|
||||
|
||||
func TestTCPServer_ConcurrentConnections(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
dstAddr, closeDst := startTCPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
srv := startTCPServer(t, certs)
|
||||
|
||||
const nClients = 25
|
||||
|
||||
errs := make(chan error, nClients)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(nClients)
|
||||
|
||||
for i := range nClients {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
client := NewTCPClient(t, srv.Addr.String(), dstAddr, certs)
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := fmt.Appendf(nil, "ping over tcp %d", i)
|
||||
if _, err := client.Write(msg); err != nil {
|
||||
errs <- fmt.Errorf("write to client: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, len(msg))
|
||||
if _, err := io.ReadFull(client, buf); err != nil {
|
||||
errs <- fmt.Errorf("read from client: %w", err)
|
||||
return
|
||||
}
|
||||
if string(msg) != string(buf) {
|
||||
errs <- fmt.Errorf("unexpected echo: got=%q want=%q", string(buf), string(msg))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPServer_RejectInvalidClient(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
// Generate a self-signed client cert that is NOT signed by the CA
|
||||
_, _, invalidClientPEM, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate invalid client certs")
|
||||
invalidClientCert, err := invalidClientPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse invalid client cert")
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
srv := startUDPServer(t, certs)
|
||||
|
||||
|
||||
// Try to connect with a client cert from a different CA
|
||||
_, err = stream.NewUDPClient(srv.Addr.String(), dstAddr, certs.CaCert, invalidClientCert)
|
||||
require.Error(t, err, "expected error when connecting with client cert from different CA")
|
||||
|
||||
var handshakeErr *dtls.HandshakeError
|
||||
require.ErrorAs(t, err, &handshakeErr, "expected handshake error")
|
||||
}
|
||||
|
||||
func TestUDPServer_RejectClientWithoutCert(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
srv := startUDPServer(t, certs)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Try to connect without any client certificate
|
||||
// Create a TLS cert without a private key to simulate no client cert
|
||||
emptyCert := &tls.Certificate{}
|
||||
_, err := stream.NewUDPClient(srv.Addr.String(), dstAddr, certs.CaCert, emptyCert)
|
||||
require.Error(t, err, "expected error when connecting without client cert")
|
||||
|
||||
require.ErrorContains(t, err, "no certificate provided", "expected no cert error")
|
||||
}
|
||||
|
||||
func TestUDPServer_FullFlow(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
srv := startUDPServer(t, certs)
|
||||
|
||||
client := NewUDPClient(t, srv.Addr.String(), dstAddr, certs)
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over udp")
|
||||
_, err := client.Write(msg)
|
||||
require.NoError(t, err, "write to client")
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := client.Read(buf)
|
||||
require.NoError(t, err, "read from client")
|
||||
require.Equal(t, string(msg), string(buf[:n]), "unexpected echo")
|
||||
}
|
||||
|
||||
func TestUDPServer_ConcurrentConnections(t *testing.T) {
|
||||
certs := genTestCerts(t)
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
srv := startUDPServer(t, certs)
|
||||
|
||||
const nClients = 25
|
||||
|
||||
errs := make(chan error, nClients)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(nClients)
|
||||
|
||||
for i := range nClients {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
client := NewUDPClient(t, srv.Addr.String(), dstAddr, certs)
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
msg := fmt.Appendf(nil, "ping over udp %d", i)
|
||||
if _, err := client.Write(msg); err != nil {
|
||||
errs <- fmt.Errorf("write to client: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := client.Read(buf)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("read from client: %w", err)
|
||||
return
|
||||
}
|
||||
if string(msg) != string(buf[:n]) {
|
||||
errs <- fmt.Errorf("unexpected echo: got=%q want=%q", string(buf[:n]), string(msg))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
for err := range errs {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
177
agent/pkg/agent/stream/tests/testutils_test.go
Normal file
177
agent/pkg/agent/stream/tests/testutils_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package stream_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/transport/v3/udp"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
)
|
||||
|
||||
// CertBundle holds all certificates needed for testing.
|
||||
type CertBundle struct {
|
||||
CaCert *x509.Certificate
|
||||
SrvCert *tls.Certificate
|
||||
ClientCert *tls.Certificate
|
||||
}
|
||||
|
||||
// genTestCerts generates certificates for testing and returns them as a CertBundle.
|
||||
func genTestCerts(t *testing.T) CertBundle {
|
||||
t.Helper()
|
||||
|
||||
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate agent certs")
|
||||
|
||||
caCert, err := caPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse CA cert")
|
||||
srvCert, err := srvPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse server cert")
|
||||
clientCert, err := clientPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse client cert")
|
||||
|
||||
return CertBundle{
|
||||
CaCert: caCert.Leaf,
|
||||
SrvCert: srvCert,
|
||||
ClientCert: clientCert,
|
||||
}
|
||||
}
|
||||
|
||||
// startTCPEcho starts a TCP echo server and returns its address and close function.
|
||||
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "listen tcp")
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
_, _ = io.Copy(conn, conn)
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
|
||||
return ln.Addr().String(), func() {
|
||||
_ = ln.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// startUDPEcho starts a UDP echo server and returns its address and close function.
|
||||
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
t.Helper()
|
||||
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err, "listen udp")
|
||||
uc := pc.(*net.UDPConn)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
buf := make([]byte, 65535)
|
||||
for {
|
||||
n, raddr, err := uc.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = uc.WriteToUDP(buf[:n], raddr)
|
||||
}
|
||||
}()
|
||||
|
||||
return uc.LocalAddr().String(), func() {
|
||||
_ = uc.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer wraps a server with its startup goroutine for cleanup.
|
||||
type TestServer struct {
|
||||
Server interface{ Close() error }
|
||||
Addr net.Addr
|
||||
}
|
||||
|
||||
// startTCPServer starts a TCP server and returns a TestServer for cleanup.
|
||||
func startTCPServer(t *testing.T, certs CertBundle) TestServer {
|
||||
t.Helper()
|
||||
|
||||
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||
require.NoError(t, err, "listen tcp")
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
srv := stream.NewTCPServer(ctx, tcpLn, certs.CaCert, certs.SrvCert)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
|
||||
t.Logf("tcp server exit: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
return TestServer{
|
||||
Server: srv,
|
||||
Addr: srv.Addr(),
|
||||
}
|
||||
}
|
||||
|
||||
// startUDPServer starts a UDP server and returns a TestServer for cleanup.
|
||||
func startUDPServer(t *testing.T, certs CertBundle) TestServer {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
|
||||
srv := stream.NewUDPServer(ctx, "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, certs.CaCert, certs.SrvCert)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) && !errors.Is(err, udp.ErrClosedListener) {
|
||||
t.Logf("udp server exit: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
return TestServer{
|
||||
Server: srv,
|
||||
Addr: srv.Addr(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTCPClient creates a TCP client connected to the server with test certificates.
|
||||
func NewTCPClient(t *testing.T, serverAddr, targetAddress string, certs CertBundle) net.Conn {
|
||||
t.Helper()
|
||||
client, err := stream.NewTCPClient(serverAddr, targetAddress, certs.CaCert, certs.ClientCert)
|
||||
require.NoError(t, err, "create tcp client")
|
||||
return client
|
||||
}
|
||||
|
||||
// NewUDPClient creates a UDP client connected to the server with test certificates.
|
||||
func NewUDPClient(t *testing.T, serverAddr, targetAddress string, certs CertBundle) net.Conn {
|
||||
t.Helper()
|
||||
client, err := stream.NewUDPClient(serverAddr, targetAddress, certs.CaCert, certs.ClientCert)
|
||||
require.NoError(t, err, "create udp client")
|
||||
return client
|
||||
}
|
||||
118
agent/pkg/agent/stream/udp_client.go
Normal file
118
agent/pkg/agent/stream/udp_client.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/common"
|
||||
)
|
||||
|
||||
type UDPClient struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// NewUDPClient creates a new UDP client for the agent.
|
||||
//
|
||||
// It will establish a DTLS connection and send a stream request header to the server.
|
||||
//
|
||||
// It returns an error if
|
||||
// - the target address is invalid
|
||||
// - the stream request header is invalid
|
||||
// - the DTLS configuration is invalid
|
||||
// - the DTLS connection fails
|
||||
// - the stream request header is not sent
|
||||
func NewUDPClient(serverAddr, targetAddress string, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(targetAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := NewStreamRequestHeader(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newUDPClientWIthHeader(serverAddr, header, caCert, clientCert)
|
||||
}
|
||||
|
||||
func newUDPClientWIthHeader(serverAddr string, header *StreamRequestHeader, caCert *x509.Certificate, clientCert *tls.Certificate) (net.Conn, error) {
|
||||
// Setup DTLS configuration
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
dtlsConfig := &dtls.Config{
|
||||
Certificates: []tls.Certificate{*clientCert},
|
||||
RootCAs: caCertPool,
|
||||
InsecureSkipVerify: false,
|
||||
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||
ServerName: common.CertsDNSName,
|
||||
CipherSuites: dTLSCipherSuites,
|
||||
}
|
||||
|
||||
raddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Establish DTLS connection
|
||||
conn, err := dtls.Dial("udp", raddr, dtlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send the stream header once as a handshake.
|
||||
if _, err := conn.Write(header.Bytes()); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &UDPClient{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func UDPHealthCheck(serverAddr string, caCert *x509.Certificate, clientCert *tls.Certificate) error {
|
||||
header := NewStreamHealthCheckHeader()
|
||||
|
||||
conn, err := newUDPClientWIthHeader(serverAddr, header, caCert, clientCert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UDPClient) Read(p []byte) (n int, err error) {
|
||||
return c.conn.Read(p)
|
||||
}
|
||||
|
||||
func (c *UDPClient) Write(p []byte) (n int, err error) {
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *UDPClient) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *UDPClient) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *UDPClient) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *UDPClient) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *UDPClient) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *UDPClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
205
agent/pkg/agent/stream/udp_server.go
Normal file
205
agent/pkg/agent/stream/udp_server.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type UDPServer struct {
|
||||
ctx context.Context
|
||||
network string
|
||||
laddr *net.UDPAddr
|
||||
listener net.Listener
|
||||
|
||||
dtlsConfig *dtls.Config
|
||||
}
|
||||
|
||||
func NewUDPServer(ctx context.Context, network string, laddr *net.UDPAddr, caCert *x509.Certificate, serverCert *tls.Certificate) *UDPServer {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AddCert(caCert)
|
||||
|
||||
dtlsConfig := &dtls.Config{
|
||||
Certificates: []tls.Certificate{*serverCert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: dtls.RequireAndVerifyClientCert,
|
||||
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
||||
CipherSuites: dTLSCipherSuites,
|
||||
}
|
||||
|
||||
s := &UDPServer{
|
||||
ctx: ctx,
|
||||
network: network,
|
||||
laddr: laddr,
|
||||
dtlsConfig: dtlsConfig,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *UDPServer) Start() error {
|
||||
listener, err := dtls.Listen(s.network, s.laddr, s.dtlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
context.AfterFunc(s.ctx, func() {
|
||||
_ = s.listener.Close()
|
||||
})
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
// Expected error when context cancelled
|
||||
if errors.Is(err, net.ErrClosed) && s.ctx.Err() != nil {
|
||||
return s.ctx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
go s.handleDTLSConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPServer) Addr() net.Addr {
|
||||
if s.listener != nil {
|
||||
return s.listener.Addr()
|
||||
}
|
||||
return s.laddr
|
||||
}
|
||||
|
||||
func (s *UDPServer) Close() error {
|
||||
if s.listener != nil {
|
||||
return s.listener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UDPServer) logger(clientConn net.Conn) *zerolog.Logger {
|
||||
l := log.With().Str("protocol", "udp").
|
||||
Str("addr", s.Addr().String()).
|
||||
Str("remote", clientConn.RemoteAddr().String()).Logger()
|
||||
return &l
|
||||
}
|
||||
|
||||
func (s *UDPServer) loggerWithDst(clientConn net.Conn, dstConn *net.UDPConn) *zerolog.Logger {
|
||||
l := log.With().Str("protocol", "udp").
|
||||
Str("addr", s.Addr().String()).
|
||||
Str("remote", clientConn.RemoteAddr().String()).
|
||||
Str("dst", dstConn.RemoteAddr().String()).Logger()
|
||||
return &l
|
||||
}
|
||||
|
||||
func (s *UDPServer) handleDTLSConnection(clientConn net.Conn) {
|
||||
defer clientConn.Close()
|
||||
|
||||
// Read the stream header once as a handshake.
|
||||
var headerBuf [headerSize]byte
|
||||
if _, err := io.ReadFull(clientConn, headerBuf[:]); err != nil {
|
||||
s.logger(clientConn).Err(err).Msg("failed to read stream header")
|
||||
return
|
||||
}
|
||||
header := ToHeader(headerBuf)
|
||||
if !header.Validate() {
|
||||
s.logger(clientConn).Error().Bytes("header", headerBuf[:]).Msg("invalid stream header received")
|
||||
return
|
||||
}
|
||||
|
||||
// Health check probe: close connection
|
||||
if header.ShouldCloseImmediately() {
|
||||
s.logger(clientConn).Info().Msg("Health check received")
|
||||
return
|
||||
}
|
||||
|
||||
host, port := header.GetHostPort()
|
||||
dstConn, err := s.createDestConnection(host, port)
|
||||
if err != nil {
|
||||
s.logger(clientConn).Err(err).Msg("failed to get or create destination connection")
|
||||
return
|
||||
}
|
||||
defer dstConn.Close()
|
||||
|
||||
go s.forwardFromDestination(dstConn, clientConn)
|
||||
|
||||
buf := sizedPool.GetSized(65535)
|
||||
defer sizedPool.Put(buf)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := clientConn.Read(buf)
|
||||
// Per net.Conn contract, Read may return (n > 0, err == io.EOF).
|
||||
// Always forward any bytes we got before acting on the error.
|
||||
if n > 0 {
|
||||
if _, werr := dstConn.Write(buf[:n]); werr != nil {
|
||||
s.logger(clientConn).Err(werr).Msgf("failed to write %d bytes to destination", n)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Expected shutdown paths.
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
s.logger(clientConn).Err(err).Msg("failed to read from client")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UDPServer) createDestConnection(host, port string) (*net.UDPConn, error) {
|
||||
addr := net.JoinHostPort(host, port)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dstConn, err := net.DialUDP("udp", nil, udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dstConn, nil
|
||||
}
|
||||
|
||||
func (s *UDPServer) forwardFromDestination(dstConn *net.UDPConn, clientConn net.Conn) {
|
||||
buffer := sizedPool.GetSized(65535)
|
||||
defer sizedPool.Put(buffer)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
_ = dstConn.SetReadDeadline(time.Now().Add(readDeadline))
|
||||
n, err := dstConn.Read(buffer)
|
||||
if err != nil {
|
||||
// The destination socket can be closed when the client disconnects (e.g. during
|
||||
// the stream support probe in AgentConfig.StartWithCerts). Treat that as a
|
||||
// normal exit and avoid noisy logs.
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
s.loggerWithDst(clientConn, dstConn).Err(err).Msg("failed to read from destination")
|
||||
return
|
||||
}
|
||||
if _, err := clientConn.Write(buffer[:n]); err != nil {
|
||||
s.loggerWithDst(clientConn, dstConn).Err(err).Msgf("failed to write %d bytes to client", n)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
services:
|
||||
agent:
|
||||
image: "{{.Image}}"
|
||||
container_name: godoxy-agent
|
||||
restart: always
|
||||
network_mode: host # do not change this
|
||||
environment:
|
||||
AGENT_NAME: "{{.Name}}"
|
||||
AGENT_PORT: "{{.Port}}"
|
||||
AGENT_CA_CERT: "{{.CACert}}"
|
||||
AGENT_SSL_CERT: "{{.SSLCert}}"
|
||||
# use agent as a docker socket proxy: [host]:port
|
||||
# set LISTEN_ADDR to enable (e.g. 127.0.0.1:2375)
|
||||
LISTEN_ADDR:
|
||||
POST: false
|
||||
ALLOW_RESTARTS: false
|
||||
ALLOW_START: false
|
||||
ALLOW_STOP: false
|
||||
AUTH: false
|
||||
BUILD: false
|
||||
COMMIT: false
|
||||
CONFIGS: false
|
||||
CONTAINERS: false
|
||||
DISTRIBUTION: false
|
||||
EVENTS: true
|
||||
EXEC: false
|
||||
GRPC: false
|
||||
IMAGES: false
|
||||
INFO: false
|
||||
NETWORKS: false
|
||||
NODES: false
|
||||
PING: true
|
||||
PLUGINS: false
|
||||
SECRETS: false
|
||||
SERVICES: false
|
||||
SESSION: false
|
||||
SWARM: false
|
||||
SYSTEM: false
|
||||
TASKS: false
|
||||
VERSION: true
|
||||
VOLUMES: false
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./data:/app/data
|
||||
@@ -5,7 +5,8 @@ services:
|
||||
restart: always
|
||||
{{ if eq .ContainerRuntime "podman" -}}
|
||||
ports:
|
||||
- "{{.Port}}:{{.Port}}"
|
||||
- "{{.Port}}:{{.Port}}/tcp"
|
||||
- "{{.Port}}:{{.Port}}/udp"
|
||||
{{ else -}}
|
||||
network_mode: host # do not change this
|
||||
{{ end -}}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
@@ -44,14 +44,14 @@ func NewAgentHandler() http.Handler {
|
||||
}
|
||||
|
||||
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
|
||||
mux.HandleEndpoint("GET", agent.EndpointVersion, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, version.Get())
|
||||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointName, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, env.AgentName)
|
||||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointRuntime, func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, env.Runtime)
|
||||
mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) {
|
||||
agentInfo := agent.AgentInfo{
|
||||
Version: version.Get(),
|
||||
Name: env.AgentName,
|
||||
Runtime: env.Runtime,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo)
|
||||
})
|
||||
mux.HandleEndpoint("GET", agent.EndpointHealth, CheckHealth)
|
||||
mux.HandleEndpoint("GET", agent.EndpointSystemInfo, metricsHandler.ServeHTTP)
|
||||
|
||||
3
go.mod
3
go.mod
@@ -172,6 +172,9 @@ require (
|
||||
github.com/nrdcg/oci-go-sdk/common/v1065 v1065.105.2 // indirect
|
||||
github.com/nrdcg/oci-go-sdk/dns/v1065 v1065.105.2 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.9 // indirect
|
||||
github.com/pion/logging v0.2.4 // indirect
|
||||
github.com/pion/transport/v3 v3.1.1 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/stretchr/objx v0.5.3 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -247,6 +247,12 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
|
||||
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
|
||||
@@ -2356,6 +2356,16 @@
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"supports_tcp_stream": {
|
||||
"type": "boolean",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"supports_udp_stream": {
|
||||
"type": "boolean",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
@@ -2439,7 +2449,7 @@
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent": {
|
||||
"$ref": "#/definitions/Agent",
|
||||
"$ref": "#/definitions/agentpool.Agent",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
@@ -4909,6 +4919,43 @@
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"agentpool.Agent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"addr": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"runtime": {
|
||||
"$ref": "#/definitions/agent.ContainerRuntime",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"supports_tcp_stream": {
|
||||
"type": "boolean",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"supports_udp_stream": {
|
||||
"type": "boolean",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
}
|
||||
},
|
||||
"x-nullable": false,
|
||||
"x-omitempty": false
|
||||
},
|
||||
"auth.UserPassAuthCallbackRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -8,6 +8,10 @@ definitions:
|
||||
type: string
|
||||
runtime:
|
||||
$ref: '#/definitions/agent.ContainerRuntime'
|
||||
supports_tcp_stream:
|
||||
type: boolean
|
||||
supports_udp_stream:
|
||||
type: boolean
|
||||
version:
|
||||
type: string
|
||||
type: object
|
||||
@@ -48,7 +52,7 @@ definitions:
|
||||
Container:
|
||||
properties:
|
||||
agent:
|
||||
$ref: '#/definitions/Agent'
|
||||
$ref: '#/definitions/agentpool.Agent'
|
||||
aliases:
|
||||
items:
|
||||
type: string
|
||||
@@ -1236,6 +1240,21 @@ definitions:
|
||||
x-enum-varnames:
|
||||
- ContainerRuntimeDocker
|
||||
- ContainerRuntimePodman
|
||||
agentpool.Agent:
|
||||
properties:
|
||||
addr:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
runtime:
|
||||
$ref: '#/definitions/agent.ContainerRuntime'
|
||||
supports_tcp_stream:
|
||||
type: boolean
|
||||
supports_udp_stream:
|
||||
type: boolean
|
||||
version:
|
||||
type: string
|
||||
type: object
|
||||
auth.UserPassAuthCallbackRequest:
|
||||
properties:
|
||||
password:
|
||||
|
||||
@@ -116,9 +116,9 @@ func (r *StreamRoute) initStream() (nettypes.Stream, error) {
|
||||
|
||||
switch rScheme {
|
||||
case "tcp":
|
||||
return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host)
|
||||
return stream.NewTCPTCPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host, r.GetAgent())
|
||||
case "udp":
|
||||
return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host)
|
||||
return stream.NewUDPUDPStream(lurl.Scheme, rurl.Scheme, laddr, rurl.Host, r.GetAgent())
|
||||
}
|
||||
return nil, fmt.Errorf("unknown scheme: %s", rurl.Scheme)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/godoxy/internal/acl"
|
||||
"github.com/yusing/godoxy/internal/agentpool"
|
||||
"github.com/yusing/godoxy/internal/entrypoint"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
ioutils "github.com/yusing/goutils/io"
|
||||
@@ -21,6 +22,7 @@ type TCPTCPStream struct {
|
||||
|
||||
laddr *net.TCPAddr
|
||||
dst *net.TCPAddr
|
||||
agent *agentpool.Agent
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
@@ -28,7 +30,7 @@ type TCPTCPStream struct {
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveTCPAddr(dstNetwork, dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -37,7 +39,7 @@ func NewTCPTCPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst}, nil
|
||||
return &TCPTCPStream{network: network, dstNetwork: dstNetwork, laddr: laddr, dst: dst, agent: agent}, nil
|
||||
}
|
||||
|
||||
func (s *TCPTCPStream) ListenAndServe(ctx context.Context, preDial, onRead nettypes.HookFunc) {
|
||||
@@ -130,7 +132,15 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
dstConn, err := net.DialTCP(s.dstNetwork, nil, s.dst)
|
||||
var (
|
||||
dstConn net.Conn
|
||||
err error
|
||||
)
|
||||
if s.agent != nil {
|
||||
dstConn, err = s.agent.NewTCPClient(s.dst.String())
|
||||
} else {
|
||||
dstConn, err = net.DialTCP(s.dstNetwork, nil, s.dst)
|
||||
}
|
||||
if err != nil {
|
||||
if !s.closed.Load() {
|
||||
logErr(s, err, "failed to dial destination")
|
||||
@@ -144,7 +154,7 @@ func (s *TCPTCPStream) handle(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
src := conn
|
||||
dst := net.Conn(dstConn)
|
||||
dst := dstConn
|
||||
if s.onRead != nil {
|
||||
src = &wrapperConn{
|
||||
Conn: conn,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/yusing/godoxy/internal/acl"
|
||||
"github.com/yusing/godoxy/internal/agentpool"
|
||||
nettypes "github.com/yusing/godoxy/internal/net/types"
|
||||
"github.com/yusing/goutils/synk"
|
||||
"go.uber.org/atomic"
|
||||
@@ -24,6 +25,7 @@ type UDPUDPStream struct {
|
||||
|
||||
laddr *net.UDPAddr
|
||||
dst *net.UDPAddr
|
||||
agent *agentpool.Agent
|
||||
|
||||
preDial nettypes.HookFunc
|
||||
onRead nettypes.HookFunc
|
||||
@@ -53,7 +55,7 @@ const (
|
||||
|
||||
var bufPool = synk.GetSizedBytesPool()
|
||||
|
||||
func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.Stream, error) {
|
||||
func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string, agent *agentpool.Agent) (nettypes.Stream, error) {
|
||||
dst, err := net.ResolveUDPAddr(dstNetwork, dstAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -67,6 +69,7 @@ func NewUDPUDPStream(network, dstNetwork, listenAddr, dstAddr string) (nettypes.
|
||||
dstNetwork: dstNetwork,
|
||||
laddr: laddr,
|
||||
dst: dst,
|
||||
agent: agent,
|
||||
conns: make(map[string]*udpUDPConn),
|
||||
}, nil
|
||||
}
|
||||
@@ -195,7 +198,11 @@ func (s *UDPUDPStream) createConnection(ctx context.Context, srcAddr *net.UDPAdd
|
||||
dstConn net.Conn
|
||||
err error
|
||||
)
|
||||
dstConn, err = net.DialUDP(s.dstNetwork, nil, s.dst)
|
||||
if s.agent != nil {
|
||||
dstConn, err = s.agent.NewUDPClient(s.dst.String())
|
||||
} else {
|
||||
dstConn, err = net.DialUDP(s.dst.Network(), nil, s.dst)
|
||||
}
|
||||
if err != nil {
|
||||
logErr(s, err, "failed to dial dst")
|
||||
return nil, false
|
||||
|
||||
Reference in New Issue
Block a user