feat(agent/stream): remove STREAM_PORT and use tcp multiplexing on the same port

This commit is contained in:
yusing
2026-01-07 18:30:31 +08:00
parent cc406921cb
commit a605d56a4c
15 changed files with 410 additions and 81 deletions

View File

@@ -1,15 +1,21 @@
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"net/http"
"os"
"sync"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/agent/pkg/agent/stream"
"github.com/yusing/godoxy/agent/pkg/env"
"github.com/yusing/godoxy/agent/pkg/server"
"github.com/yusing/godoxy/agent/pkg/handler"
"github.com/yusing/godoxy/internal/metrics/systeminfo"
socketproxy "github.com/yusing/godoxy/socketproxy/pkg"
gperr "github.com/yusing/goutils/errs"
@@ -19,6 +25,53 @@ import (
"github.com/yusing/goutils/version"
)
// errListenerClosed is returned by push and Accept once Close has been called.
var errListenerClosed = errors.New("listener closed")

// connQueueListener is a net.Listener backed by an in-memory queue instead of
// a socket. Producers hand over already-established connections with push and
// consumers (for example http.Server.Serve) receive them from Accept.
type connQueueListener struct {
	addr net.Addr
	// conns carries the queued connections. It is intentionally never closed:
	// push may run concurrently with Close, and a send on a closed channel
	// panics. Shutdown is signalled through closed only.
	conns chan net.Conn
	// closed is closed exactly once by Close.
	closed    chan struct{}
	closeOnce sync.Once
}

// Compile-time check that the queue can be used wherever a listener is needed.
var _ net.Listener = (*connQueueListener)(nil)

// newConnQueueListener returns a listener that reports addr from Addr and can
// hold up to buffer connections that have been pushed but not yet accepted.
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
	return &connQueueListener{
		addr:   addr,
		conns:  make(chan net.Conn, buffer),
		closed: make(chan struct{}),
	}
}

// isClosed reports whether Close has been called.
func (l *connQueueListener) isClosed() bool {
	select {
	case <-l.closed:
		return true
	default:
		return false
	}
}

// drain closes every connection currently sitting in the queue.
func (l *connQueueListener) drain() {
	for {
		select {
		case conn := <-l.conns:
			_ = conn.Close()
		default:
			return
		}
	}
}

// push queues conn for a later Accept, blocking while the queue is full.
//
// push always takes ownership of conn. It returns nil when conn was queued.
// It returns errListenerClosed when the listener is (or becomes) closed; in
// that case conn has either been closed or was claimed by a concurrent Accept,
// so the caller must not use it any further.
func (l *connQueueListener) push(conn net.Conn) error {
	// Check first so that a push after Close is rejected deterministically;
	// a bare two-way select could still pick the send case.
	if l.isClosed() {
		_ = conn.Close()
		return errListenerClosed
	}
	select {
	case <-l.closed:
		_ = conn.Close()
		return errListenerClosed
	case l.conns <- conn:
	}
	// Close may have run between the check above and the send. Nobody will
	// accept from a closed listener, so make sure nothing stays stranded.
	if l.isClosed() {
		l.drain()
		return errListenerClosed
	}
	return nil
}

// Accept blocks until a connection is pushed or the listener is closed.
// After Close it returns errListenerClosed.
func (l *connQueueListener) Accept() (net.Conn, error) {
	if l.isClosed() {
		return nil, errListenerClosed
	}
	select {
	case <-l.closed:
		return nil, errListenerClosed
	case conn := <-l.conns:
		return conn, nil
	}
}

// Close marks the listener as closed, wakes blocked push and Accept calls and
// closes every connection that was queued but never accepted. It is safe to
// call more than once and always returns nil.
func (l *connQueueListener) Close() error {
	l.closeOnce.Do(func() {
		close(l.closed)
		l.drain()
	})
	return nil
}

// Addr returns the address given to newConnQueueListener.
func (l *connQueueListener) Addr() net.Addr {
	return l.addr
}
func main() {
writer := zerolog.ConsoleWriter{
Out: os.Stderr,
@@ -55,28 +108,129 @@ func main() {
Tips:
1. To change the agent name, you can set the AGENT_NAME environment variable.
2. To change the agent port, you can set the AGENT_PORT environment variable.
`)
`)
t := task.RootTask("agent", false)
opts := server.Options{
CACert: caCert,
ServerCert: srvCert,
Port: env.AgentPort,
}
server.StartAgentServer(t, opts)
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentStreamPort})
// One TCP listener on AGENT_PORT, then multiplex by TLS ALPN:
// - Stream ALPN: route to TCP stream tunnel handler
// - Otherwise: route to HTTPS API handler
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort})
if err != nil {
gperr.LogFatal("failed to listen on port", err)
}
tcpServer := stream.NewTCPServer(t.Context(), tcpListener, caCert.Leaf, srvCert)
go tcpServer.Start()
log.Info().Int("port", env.AgentStreamPort).Msg("TCP stream server started")
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentStreamPort}, caCert.Leaf, srvCert)
go udpServer.Start()
log.Info().Int("port", env.AgentStreamPort).Msg("UDP stream server started")
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert.Leaf)
muxTLSConfig := &tls.Config{
Certificates: []tls.Certificate{*srvCert},
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
// Keep HTTP limited to HTTP/1.1 (matching current agent server behavior)
// and add the stream tunnel ALPN for multiplexing.
NextProtos: []string{"http/1.1", stream.StreamALPN},
}
if env.AgentSkipClientCertCheck {
muxTLSConfig.ClientAuth = tls.NoClientCert
}
httpLn := newConnQueueListener(tcpListener.Addr(), 128)
streamLn := newConnQueueListener(tcpListener.Addr(), 128)
httpSrv := &http.Server{
Handler: handler.NewAgentHandler(),
BaseContext: func(net.Listener) context.Context {
return t.Context()
},
}
{
subtask := t.Subtask("agent-http", true)
t.OnCancel("stop_http", func() {
_ = httpSrv.Shutdown(context.Background())
_ = httpLn.Close()
})
go func() {
err := httpSrv.Serve(httpLn)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error().Err(err).Msg("agent HTTP server stopped with error")
}
subtask.Finish(err)
}()
log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started")
}
{
tcpServer := stream.NewTCPServerFromListener(t.Context(), streamLn)
subtask := t.Subtask("agent-stream-tcp", true)
t.OnCancel("stop_stream_tcp", func() {
_ = tcpServer.Close()
_ = streamLn.Close()
})
go func() {
err := tcpServer.Start()
subtask.Finish(err)
}()
log.Info().Int("port", env.AgentPort).Msg("TCP stream server started")
}
{
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert)
subtask := t.Subtask("agent-stream-udp", true)
t.OnCancel("stop_stream_udp", func() {
_ = udpServer.Close()
})
go func() {
err := udpServer.Start()
subtask.Finish(err)
}()
log.Info().Int("port", env.AgentPort).Msg("UDP stream server started")
}
// Accept raw TCP connections, terminate TLS once, and dispatch by ALPN.
{
	subtask := t.Subtask("agent-tls-mux", true)
	t.OnCancel("stop_mux", func() {
		_ = tcpListener.Close()
		_ = httpLn.Close()
		_ = streamLn.Close()
	})

	// dispatch finishes the TLS handshake for a single connection and hands
	// it to the listener that matches the negotiated ALPN protocol.
	//
	// It runs in its own goroutine per connection: both the handshake and
	// push (when the target queue is full) can block, and neither may stall
	// the accept loop for other clients.
	// NOTE(review): the handshake is bounded only by the task context; a
	// per-connection deadline would additionally cap idle handshakes.
	dispatch := func(conn net.Conn) {
		tlsConn := tls.Server(conn, muxTLSConfig)
		if err := tlsConn.HandshakeContext(t.Context()); err != nil {
			_ = tlsConn.Close()
			log.Debug().Err(err).Msg("TLS handshake failed")
			return
		}
		target := httpLn
		if tlsConn.ConnectionState().NegotiatedProtocol == stream.StreamALPN {
			target = streamLn
		}
		if err := target.push(tlsConn); err != nil {
			log.Debug().Err(err).Msg("connection dropped, listener closed")
		}
	}

	go func() {
		defer subtask.Finish(subtask.FinishCause())
		for {
			select {
			case <-t.Context().Done():
				return
			default:
			}
			conn, err := tcpListener.Accept()
			if err != nil {
				// A closed listener never recovers; stop instead of
				// spinning on the same error.
				if t.Context().Err() != nil || errors.Is(err, net.ErrClosed) {
					return
				}
				log.Error().Err(err).Msg("failed to accept connection")
				continue
			}
			go dispatch(conn)
		}
	}()
}
if socketproxy.ListenAddr != "" {
runtime := strutils.Title(string(env.Runtime))

View File

@@ -8,7 +8,6 @@ import (
var (
installScript = `AGENT_NAME="{{.Name}}" \
AGENT_PORT="{{.Port}}" \
AGENT_STREAM_PORT="{{.StreamPort}}" \
AGENT_CA_CERT="{{.CACert}}" \
AGENT_SSL_CERT="{{.SSLCert}}" \
{{ if eq .ContainerRuntime "nerdctl" -}}

View File

@@ -11,7 +11,6 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
@@ -28,27 +27,26 @@ import (
type AgentConfig struct {
AgentInfo
Addr string `json:"addr"`
Addr string `json:"addr"`
IsTCPStreamSupported bool `json:"supports_tcp_stream"`
IsUDPStreamSupported bool `json:"supports_udp_stream"`
httpClient *http.Client
fasthttpClientHealthCheck *fasthttp.Client
tlsConfig tls.Config
// for stream
caCert *x509.Certificate
clientCert *tls.Certificate
isTCPStreamSupported bool
isUDPStreamSupported bool
streamServerAddr string
caCert *x509.Certificate
clientCert *tls.Certificate
streamServerAddr string
l zerolog.Logger
} // @name Agent
type AgentInfo struct {
Version version.Version `json:"version" swaggertype:"string"`
Name string `json:"name"`
Runtime ContainerRuntime `json:"runtime"`
StreamPort int `json:"stream_port,omitempty"`
Version version.Version `json:"version" swaggertype:"string"`
Name string `json:"name"`
Runtime ContainerRuntime `json:"runtime"`
}
// Deprecated: replaced by EndpointInfo.
@@ -154,14 +152,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
var streamUnsupportedErrs gperr.Builder
if status == http.StatusOK {
if cfg.StreamPort <= 0 {
return fmt.Errorf("invalid agent stream port: %d", cfg.StreamPort)
}
host, _, err := net.SplitHostPort(cfg.Addr)
if err != nil {
return err
}
cfg.streamServerAddr = net.JoinHostPort(host, strconv.Itoa(cfg.StreamPort))
cfg.streamServerAddr = cfg.Addr
// test stream server connection
const fakeAddress = "localhost:8080" // it won't be used, just for testing
@@ -171,7 +162,7 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
streamUnsupportedErrs.Addf("failed to connect to stream server via TCP: %w", err)
} else {
conn.Close()
cfg.isTCPStreamSupported = true
cfg.IsTCPStreamSupported = true
}
// test UDP stream support
@@ -180,13 +171,13 @@ func (cfg *AgentConfig) StartWithCerts(ctx context.Context, ca, crt, key []byte)
streamUnsupportedErrs.Addf("failed to connect to stream server via UDP: %w", err)
} else {
conn.Close()
cfg.isUDPStreamSupported = true
cfg.IsUDPStreamSupported = true
}
} else {
// old agent does not support EndpointInfo
// fallback with old logic
cfg.isTCPStreamSupported = false
cfg.isUDPStreamSupported = false
cfg.IsTCPStreamSupported = false
cfg.IsUDPStreamSupported = false
streamUnsupportedErrs.Adds("agent version is too old, does not support stream tunneling")
// get agent name
@@ -262,7 +253,7 @@ func (cfg *AgentConfig) NewTCPClient(targetAddress string) (net.Conn, error) {
if cfg.caCert == nil || cfg.clientCert == nil {
return nil, errors.New("agent is not initialized")
}
if !cfg.isTCPStreamSupported {
if !cfg.IsTCPStreamSupported {
return nil, errors.New("agent does not support TCP stream tunneling")
}
serverAddr, err := cfg.getStreamServerAddr()
@@ -282,7 +273,7 @@ func (cfg *AgentConfig) NewUDPClient(targetAddress string) (net.Conn, error) {
if cfg.caCert == nil || cfg.clientCert == nil {
return nil, errors.New("agent is not initialized")
}
if !cfg.isUDPStreamSupported {
if !cfg.IsUDPStreamSupported {
return nil, errors.New("agent does not support UDP stream tunneling")
}
serverAddr, err := cfg.getStreamServerAddr()

View File

@@ -5,7 +5,6 @@ type (
AgentEnvConfig struct {
Name string
Port int
StreamPort int
CACert string
SSLCert string
ContainerRuntime ContainerRuntime

View File

@@ -13,6 +13,13 @@ const (
readDeadline = 10 * time.Second
)
// StreamALPN is the TLS ALPN protocol ID used to multiplex the TCP stream tunnel
// and the HTTPS API on the same TCP port.
//
// A stream client offers this value in its tls.Config.NextProtos. When the
// handshake negotiates it, the agent routes the connection to the stream tunnel
// handler; any other negotiated protocol goes to the HTTP handler.
//
// Client and agent must agree on this exact string.
const StreamALPN = "godoxy-agent-stream/1"
var sizedPool = synk.GetSizedBytesPool()
type CreateConnFunc[Conn net.Conn] func(host, port string) (Conn, error)

View File

@@ -0,0 +1,168 @@
package stream_test
import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"net/http"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/agent/pkg/agent/common"
"github.com/yusing/godoxy/agent/pkg/agent/stream"
)
// errListenerClosed is returned by push and Accept once Close has been called.
var errListenerClosed = errors.New("listener closed")

// connQueueListener is a net.Listener backed by an in-memory queue instead of
// a socket. The test's mux loop hands over handshaked connections with push
// and the servers under test receive them from Accept.
type connQueueListener struct {
	addr net.Addr
	// conns carries the queued connections. It is intentionally never closed:
	// push may run concurrently with Close, and a send on a closed channel
	// panics. Shutdown is signalled through closed only.
	conns chan net.Conn
	// closed is closed exactly once by Close.
	closed    chan struct{}
	closeOnce sync.Once
}

// Compile-time check that the queue can be used wherever a listener is needed.
var _ net.Listener = (*connQueueListener)(nil)

// newConnQueueListener returns a listener that reports addr from Addr and can
// hold up to buffer connections that have been pushed but not yet accepted.
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
	return &connQueueListener{
		addr:   addr,
		conns:  make(chan net.Conn, buffer),
		closed: make(chan struct{}),
	}
}

// isClosed reports whether Close has been called.
func (l *connQueueListener) isClosed() bool {
	select {
	case <-l.closed:
		return true
	default:
		return false
	}
}

// drain closes every connection currently sitting in the queue.
func (l *connQueueListener) drain() {
	for {
		select {
		case conn := <-l.conns:
			_ = conn.Close()
		default:
			return
		}
	}
}

// push queues conn for a later Accept, blocking while the queue is full.
//
// push always takes ownership of conn. It returns nil when conn was queued.
// It returns errListenerClosed when the listener is (or becomes) closed; in
// that case conn has either been closed or was claimed by a concurrent Accept,
// so the caller must not use it any further.
func (l *connQueueListener) push(conn net.Conn) error {
	// Check first so that a push after Close is rejected deterministically;
	// a bare two-way select could still pick the send case.
	if l.isClosed() {
		_ = conn.Close()
		return errListenerClosed
	}
	select {
	case <-l.closed:
		_ = conn.Close()
		return errListenerClosed
	case l.conns <- conn:
	}
	// Close may have run between the check above and the send. Nobody will
	// accept from a closed listener, so make sure nothing stays stranded.
	if l.isClosed() {
		l.drain()
		return errListenerClosed
	}
	return nil
}

// Accept blocks until a connection is pushed or the listener is closed.
// After Close it returns errListenerClosed.
func (l *connQueueListener) Accept() (net.Conn, error) {
	if l.isClosed() {
		return nil, errListenerClosed
	}
	select {
	case <-l.closed:
		return nil, errListenerClosed
	case conn := <-l.conns:
		return conn, nil
	}
}

// Close marks the listener as closed, wakes blocked push and Accept calls and
// closes every connection that was queued but never accepted. It is safe to
// call more than once and always returns nil.
func (l *connQueueListener) Close() error {
	l.closeOnce.Do(func() {
		close(l.closed)
		l.drain()
	})
	return nil
}

// Addr returns the address given to newConnQueueListener.
func (l *connQueueListener) Addr() net.Addr { return l.addr }
// TestTLSALPNMux_HTTPAndStreamShareOnePort verifies that one TCP listener can
// serve both an HTTP/1.1 request and a stream-tunnel connection when accepted
// connections are dispatched by the ALPN protocol negotiated during the TLS
// handshake.
func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
// Generate a fresh CA, server and client certificate set for this test.
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
require.NoError(t, err, "generate agent certs")
caCert, err := caPEM.ToTLSCert()
require.NoError(t, err, "parse CA cert")
srvCert, err := srvPEM.ToTLSCert()
require.NoError(t, err, "parse server cert")
clientCert, err := clientPEM.ToTLSCert()
require.NoError(t, err, "parse client cert")
// Port 0 asks the OS for a free loopback port; both clients dial baseAddr.
baseLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err, "listen tcp")
defer baseLn.Close()
baseAddr := baseLn.Addr().String()
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert.Leaf)
// Server side requires a client certificate signed by the test CA and
// advertises both protocols so either kind of client can negotiate.
serverTLS := &tls.Config{
Certificates: []tls.Certificate{*srvCert},
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
NextProtos: []string{"http/1.1", stream.StreamALPN},
}
// In-memory listeners that receive the connections sorted by the mux loop.
httpLn := newConnQueueListener(baseLn.Addr(), 16)
streamLn := newConnQueueListener(baseLn.Addr(), 16)
defer httpLn.Close()
defer streamLn.Close()
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// HTTP server: answers every request with "ok" on connections from httpLn.
httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("ok"))
})}
go func() { _ = httpSrv.Serve(httpLn) }()
defer func() { _ = httpSrv.Shutdown(context.Background()) }()
// Stream server: tunnels connections from streamLn to dstAddr.
// startTCPEcho is defined elsewhere; presumably it echoes back what it
// receives, which is what the final assertion relies on.
dstAddr, closeDst := startTCPEcho(t)
defer closeDst()
tcpStreamSrv := stream.NewTCPServerFromListener(ctx, streamLn)
go func() { _ = tcpStreamSrv.Start() }()
defer func() { _ = tcpStreamSrv.Close() }()
// Mux loop: accept, complete the TLS handshake, then route by negotiated
// ALPN. It exits when baseLn is closed and Accept returns an error.
go func() {
for {
conn, err := baseLn.Accept()
if err != nil {
return
}
tlsConn := tls.Server(conn, serverTLS)
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = tlsConn.Close()
continue
}
// StreamALPN goes to the stream server, everything else to HTTP.
if tlsConn.ConnectionState().NegotiatedProtocol == stream.StreamALPN {
_ = streamLn.push(tlsConn)
} else {
_ = httpLn.push(tlsConn)
}
}
}()
// HTTP client over the same port: it offers only "http/1.1", so the mux
// loop must route it to the HTTP server.
clientTLS := &tls.Config{
Certificates: []tls.Certificate{*clientCert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
NextProtos: []string{"http/1.1"},
ServerName: common.CertsDNSName,
}
hc, err := tls.Dial("tcp", baseAddr, clientTLS)
require.NoError(t, err, "dial https")
defer hc.Close()
// Bound the whole exchange so a routing failure fails fast instead of hanging.
_ = hc.SetDeadline(time.Now().Add(2 * time.Second))
// Send a raw request and only inspect the status line of the response.
_, err = hc.Write([]byte("GET / HTTP/1.1\r\nHost: godoxy-agent\r\n\r\n"))
require.NoError(t, err, "write http request")
r := bufio.NewReader(hc)
statusLine, err := r.ReadString('\n')
require.NoError(t, err, "read status line")
require.Contains(t, statusLine, "200", "expected 200")
// Stream client over the same port: stream.NewTCPClient is expected to
// negotiate stream.StreamALPN, so the mux loop routes it to the stream server.
client, err := stream.NewTCPClient(baseAddr, dstAddr, caCert.Leaf, clientCert)
require.NoError(t, err, "create stream tcp client")
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
// The payload must come back unchanged through the tunnel.
msg := []byte("ping over mux")
_, err = client.Write(msg)
require.NoError(t, err, "write stream payload")
buf := make([]byte, len(msg))
_, err = io.ReadFull(client, buf)
require.NoError(t, err, "read stream payload")
require.Equal(t, msg, buf)
}

View File

@@ -101,6 +101,11 @@ func TestTCPServer_FullFlow(t *testing.T) {
require.NoError(t, err, "create tcp client")
defer client.Close()
// Ensure ALPN is negotiated as expected (required for multiplexing).
withState, ok := client.(interface{ ConnectionState() tls.ConnectionState })
require.True(t, ok, "tcp client should expose TLS connection state")
require.Equal(t, stream.StreamALPN, withState.ConnectionState().NegotiatedProtocol)
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over tcp")
_, err = client.Write(msg)

View File

@@ -42,6 +42,7 @@ func NewTCPClient(serverAddr, targetAddress string, caCert *x509.Certificate, cl
Certificates: []tls.Certificate{*clientCert},
RootCAs: caCertPool,
MinVersion: tls.VersionTLS12,
NextProtos: []string{StreamALPN},
ServerName: common.CertsDNSName,
}
@@ -92,3 +93,14 @@ func (c *TCPClient) SetWriteDeadline(t time.Time) error {
// Close closes the underlying connection and returns its error, if any.
func (c *TCPClient) Close() error {
return c.conn.Close()
}
// ConnectionState reports the TLS state of the client's connection.
//
// When the underlying connection is not a *tls.Conn, the zero
// tls.ConnectionState is returned. Intended mainly for tests and diagnostics.
func (c *TCPClient) ConnectionState() tls.ConnectionState {
	tlsConn, isTLS := c.conn.(*tls.Conn)
	if !isTLS {
		return tls.ConnectionState{}
	}
	return tlsConn.ConnectionState()
}

View File

@@ -16,6 +16,21 @@ type TCPServer struct {
connMgr *ConnectionManager[net.Conn]
}
// NewTCPServerFromListener builds a TCP stream server on top of a listener the
// caller has already set up.
//
// Connections returned by the listener are expected to be secured already,
// e.g. by a TLS/mTLS listener or as pre-handshaked *tls.Conn values. The agent
// uses this when HTTPS and stream-tunnel traffic share a single port.
func NewTCPServerFromListener(ctx context.Context, listener net.Listener) *TCPServer {
	srv := new(TCPServer)
	srv.ctx = ctx
	srv.listener = listener
	// The connection manager dials destinations through the server itself.
	srv.connMgr = NewConnectionManager(srv.createDestConnection)
	return srv
}
func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.Certificate, serverCert *tls.Certificate) *TCPServer {
caCertPool := x509.NewCertPool()
caCertPool.AddCert(caCert)
@@ -25,15 +40,11 @@ func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.C
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
NextProtos: []string{StreamALPN},
}
tcpListener := tls.NewListener(listener, tlsConfig)
s := &TCPServer{
ctx: ctx,
listener: tcpListener,
}
s.connMgr = NewConnectionManager(s.createDestConnection)
return s
return NewTCPServerFromListener(ctx, tcpListener)
}
func (s *TCPServer) Start() error {

View File

@@ -6,8 +6,7 @@ services:
{{ if eq .ContainerRuntime "podman" -}}
ports:
- "{{.Port}}:{{.Port}}/tcp"
- "{{.StreamPort}}:{{.StreamPort}}/tcp"
- "{{.StreamPort}}:{{.StreamPort}}/udp"
- "{{.Port}}:{{.Port}}/udp"
{{ else -}}
network_mode: host # do not change this
{{ end -}}
@@ -24,7 +23,6 @@ services:
{{ end -}}
AGENT_NAME: "{{.Name}}"
AGENT_PORT: "{{.Port}}"
AGENT_STREAM_PORT: "{{.StreamPort}}"
AGENT_CA_CERT: "{{.CACert}}"
AGENT_SSL_CERT: "{{.SSLCert}}"
# use agent as a docker socket proxy: [host]:port

View File

@@ -20,7 +20,6 @@ func DefaultAgentName() string {
var (
AgentName string
AgentPort int
AgentStreamPort int
AgentSkipClientCertCheck bool
AgentCACert string
AgentSSLCert string
@@ -36,7 +35,6 @@ func Load() {
DockerSocket = env.GetEnvString("DOCKER_SOCKET", "/var/run/docker.sock")
AgentName = env.GetEnvString("AGENT_NAME", DefaultAgentName())
AgentPort = env.GetEnvInt("AGENT_PORT", 8890)
AgentStreamPort = env.GetEnvInt("AGENT_STREAM_PORT", AgentPort+1)
AgentSkipClientCertCheck = env.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
AgentCACert = env.GetEnvString("AGENT_CA_CERT", "")

View File

@@ -46,10 +46,9 @@ func NewAgentHandler() http.Handler {
mux.HandleFunc(agent.EndpointProxyHTTP+"/{path...}", ProxyHTTP)
mux.HandleFunc(agent.EndpointInfo, func(w http.ResponseWriter, r *http.Request) {
agentInfo := agent.AgentInfo{
Version: version.Get(),
Name: env.AgentName,
Runtime: env.Runtime,
StreamPort: env.AgentStreamPort,
Version: version.Get(),
Name: env.AgentName,
Runtime: env.Runtime,
}
sonic.ConfigDefault.NewEncoder(w).Encode(agentInfo)
})

View File

@@ -16,7 +16,6 @@ type NewAgentRequest struct {
Name string `json:"name" binding:"required"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
StreamPort int `json:"stream_port" binding:"omitempty,min=1,max=65535"`
Type string `json:"type" binding:"required,oneof=docker system"`
Nightly bool `json:"nightly" binding:"omitempty"`
ContainerRuntime agent.ContainerRuntime `json:"container_runtime" binding:"omitempty,oneof=docker podman" default:"docker"`
@@ -69,18 +68,9 @@ func Create(c *gin.Context) {
return
}
if request.StreamPort <= 0 {
request.StreamPort = request.Port + 1
if request.StreamPort > 65535 {
c.JSON(http.StatusBadRequest, apitypes.Error("stream port is out of range"))
return
}
}
var cfg agent.Generator = &agent.AgentEnvConfig{
Name: request.Name,
Port: request.Port,
StreamPort: request.StreamPort,
CACert: ca.String(),
SSLCert: srv.String(),
ContainerRuntime: request.ContainerRuntime,

View File

@@ -2356,8 +2356,13 @@
"x-nullable": false,
"x-omitempty": false
},
"stream_port": {
"type": "integer",
"supports_tcp_stream": {
"type": "boolean",
"x-nullable": false,
"x-omitempty": false
},
"supports_udp_stream": {
"type": "boolean",
"x-nullable": false,
"x-omitempty": false
},
@@ -3859,11 +3864,6 @@
"x-nullable": false,
"x-omitempty": false
},
"stream_port": {
"type": "integer",
"maximum": 65535,
"minimum": 1
},
"type": {
"type": "string",
"enum": [

View File

@@ -8,8 +8,10 @@ definitions:
type: string
runtime:
$ref: '#/definitions/agent.ContainerRuntime'
stream_port:
type: integer
supports_tcp_stream:
type: boolean
supports_udp_stream:
type: boolean
version:
type: string
type: object
@@ -724,10 +726,6 @@ definitions:
maximum: 65535
minimum: 1
type: integer
stream_port:
maximum: 65535
minimum: 1
type: integer
type:
enum:
- docker