refactor(agent/stream): remove connQueueListener and simplify the handshaking flow

This commit is contained in:
yusing
2026-01-08 00:15:47 +08:00
parent d3a8b3c0e6
commit 751d73da7c
3 changed files with 66 additions and 193 deletions

View File

@@ -8,7 +8,6 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"sync"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -25,53 +24,6 @@ import (
"github.com/yusing/goutils/version" "github.com/yusing/goutils/version"
) )
var errListenerClosed = errors.New("listener closed")
type connQueueListener struct {
addr net.Addr
conns chan net.Conn
closed chan struct{}
closeOnce sync.Once
}
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
return &connQueueListener{
addr: addr,
conns: make(chan net.Conn, buffer),
closed: make(chan struct{}),
}
}
func (l *connQueueListener) push(conn net.Conn) error {
select {
case <-l.closed:
_ = conn.Close()
return errListenerClosed
case l.conns <- conn:
return nil
}
}
func (l *connQueueListener) Accept() (net.Conn, error) {
conn, ok := <-l.conns
if !ok {
return nil, errListenerClosed
}
return conn, nil
}
func (l *connQueueListener) Close() error {
l.closeOnce.Do(func() {
close(l.closed)
close(l.conns)
})
return nil
}
func (l *connQueueListener) Addr() net.Addr {
return l.addr
}
func main() { func main() {
writer := zerolog.ConsoleWriter{ writer := zerolog.ConsoleWriter{
Out: os.Stderr, Out: os.Stderr,
@@ -113,7 +65,7 @@ Tips:
t := task.RootTask("agent", false) t := task.RootTask("agent", false)
// One TCP listener on AGENT_PORT, then multiplex by TLS ALPN: // One TCP listener on AGENT_PORT, then multiplex by TLS ALPN:
// - Stream ALPN: route to TCP stream tunnel handler // - Stream ALPN: route to TCP stream tunnel handler (via http.Server.TLSNextProto)
// - Otherwise: route to HTTPS API handler // - Otherwise: route to HTTPS API handler
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort}) tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort})
if err != nil { if err != nil {
@@ -136,44 +88,43 @@ Tips:
muxTLSConfig.ClientAuth = tls.NoClientCert muxTLSConfig.ClientAuth = tls.NoClientCert
} }
httpLn := newConnQueueListener(tcpListener.Addr(), 128) // TLS listener feeds the HTTP server. ALPN stream connections are intercepted
streamLn := newConnQueueListener(tcpListener.Addr(), 128) // using http.Server.TLSNextProto.
tlsLn := tls.NewListener(tcpListener, muxTLSConfig)
streamSrv := stream.NewTCPServerHandler(t.Context())
httpSrv := &http.Server{ httpSrv := &http.Server{
Handler: handler.NewAgentHandler(), Handler: handler.NewAgentHandler(),
BaseContext: func(net.Listener) context.Context { BaseContext: func(net.Listener) context.Context {
return t.Context() return t.Context()
}, },
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
// When a client negotiates StreamALPN, net/http will call this hook instead
// of treating the connection as HTTP.
stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
// ServeConn blocks until the tunnel finishes.
streamSrv.ServeConn(conn)
},
},
} }
{ {
subtask := t.Subtask("agent-http", true) subtask := t.Subtask("agent-http", true)
t.OnCancel("stop_http", func() { t.OnCancel("stop_http", func() {
_ = httpSrv.Shutdown(context.Background()) _ = streamSrv.Close()
_ = httpLn.Close() _ = httpSrv.Close()
_ = tlsLn.Close()
}) })
go func() { go func() {
err := httpSrv.Serve(httpLn) err := httpSrv.Serve(tlsLn)
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error().Err(err).Msg("agent HTTP server stopped with error") log.Error().Err(err).Msg("agent HTTP server stopped with error")
} }
subtask.Finish(err) subtask.Finish(err)
}() }()
log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started") log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started (ALPN mux enabled)")
}
{
tcpServer := stream.NewTCPServerFromListener(t.Context(), streamLn)
subtask := t.Subtask("agent-stream-tcp", true)
t.OnCancel("stop_stream_tcp", func() {
_ = tcpServer.Close()
_ = streamLn.Close()
})
go func() {
err := tcpServer.Start()
subtask.Finish(err)
}()
log.Info().Int("port", env.AgentPort).Msg("TCP stream server started")
} }
log.Info().Int("port", env.AgentPort).Msg("TCP stream handler started (via TLSNextProto)")
{ {
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert) udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert)
@@ -188,50 +139,6 @@ Tips:
log.Info().Int("port", env.AgentPort).Msg("UDP stream server started") log.Info().Int("port", env.AgentPort).Msg("UDP stream server started")
} }
// Accept raw TCP connections, terminate TLS once, and dispatch by ALPN.
{
subtask := t.Subtask("agent-tls-mux", true)
t.OnCancel("stop_mux", func() {
_ = httpLn.Close()
_ = streamLn.Close()
_ = tcpListener.Close()
})
go func() {
defer subtask.Finish(subtask.FinishCause())
for {
select {
case <-t.Context().Done():
return
default:
}
conn, err := tcpListener.Accept()
if err != nil {
if t.Context().Err() != nil {
return
}
log.Error().Err(err).Msg("failed to accept connection")
continue
}
tlsConn := tls.Server(conn, muxTLSConfig)
if err := tlsConn.HandshakeContext(t.Context()); err != nil {
_ = tlsConn.Close()
log.Debug().Err(err).Msg("TLS handshake failed")
continue
}
alpn := tlsConn.ConnectionState().NegotiatedProtocol
switch alpn {
case stream.StreamALPN:
_ = streamLn.push(tlsConn)
default:
_ = httpLn.push(tlsConn)
}
}
}()
}
if socketproxy.ListenAddr != "" { if socketproxy.ListenAddr != "" {
runtime := strutils.Title(string(env.Runtime)) runtime := strutils.Title(string(env.Runtime))

View File

@@ -5,11 +5,9 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"sync"
"testing" "testing"
"time" "time"
@@ -19,51 +17,6 @@ import (
"github.com/yusing/godoxy/agent/pkg/agent/stream" "github.com/yusing/godoxy/agent/pkg/agent/stream"
) )
var errListenerClosed = errors.New("listener closed")
type connQueueListener struct {
addr net.Addr
conns chan net.Conn
closed chan struct{}
closeOnce sync.Once
}
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
return &connQueueListener{
addr: addr,
conns: make(chan net.Conn, buffer),
closed: make(chan struct{}),
}
}
func (l *connQueueListener) push(conn net.Conn) error {
select {
case <-l.closed:
_ = conn.Close()
return errListenerClosed
case l.conns <- conn:
return nil
}
}
func (l *connQueueListener) Accept() (net.Conn, error) {
conn, ok := <-l.conns
if !ok {
return nil, errListenerClosed
}
return conn, nil
}
func (l *connQueueListener) Close() error {
l.closeOnce.Do(func() {
close(l.closed)
close(l.conns)
})
return nil
}
func (l *connQueueListener) Addr() net.Addr { return l.addr }
func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) { func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
caPEM, srvPEM, clientPEM, err := agent.NewAgent() caPEM, srvPEM, clientPEM, err := agent.NewAgent()
require.NoError(t, err, "generate agent certs") require.NoError(t, err, "generate agent certs")
@@ -91,49 +44,32 @@ func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
NextProtos: []string{"http/1.1", stream.StreamALPN}, NextProtos: []string{"http/1.1", stream.StreamALPN},
} }
httpLn := newConnQueueListener(baseLn.Addr(), 16)
streamLn := newConnQueueListener(baseLn.Addr(), 16)
defer httpLn.Close()
defer streamLn.Close()
ctx, cancel := context.WithCancel(t.Context()) ctx, cancel := context.WithCancel(t.Context())
defer cancel() defer cancel()
streamSrv := stream.NewTCPServerHandler(ctx)
defer func() { _ = streamSrv.Close() }()
tlsLn := tls.NewListener(baseLn, serverTLS)
defer func() { _ = tlsLn.Close() }()
// HTTP server // HTTP server
httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("ok")) _, _ = w.Write([]byte("ok"))
})} }),
go func() { _ = httpSrv.Serve(httpLn) }() TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
defer func() { _ = httpSrv.Shutdown(context.Background()) }() stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
streamSrv.ServeConn(conn)
},
},
}
go func() { _ = httpSrv.Serve(tlsLn) }()
defer func() { _ = httpSrv.Close() }()
// Stream server // Stream destination
dstAddr, closeDst := startTCPEcho(t) dstAddr, closeDst := startTCPEcho(t)
defer closeDst() defer closeDst()
tcpStreamSrv := stream.NewTCPServerFromListener(ctx, streamLn)
go func() { _ = tcpStreamSrv.Start() }()
defer func() { _ = tcpStreamSrv.Close() }()
// Mux loop
go func() {
for {
conn, err := baseLn.Accept()
if err != nil {
return
}
tlsConn := tls.Server(conn, serverTLS)
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = tlsConn.Close()
continue
}
if tlsConn.ConnectionState().NegotiatedProtocol == stream.StreamALPN {
_ = streamLn.push(tlsConn)
} else {
_ = httpLn.push(tlsConn)
}
}
}()
// HTTP client over the same port // HTTP client over the same port
clientTLS := &tls.Config{ clientTLS := &tls.Config{
Certificates: []tls.Certificate{*clientCert}, Certificates: []tls.Certificate{*clientCert},

View File

@@ -16,6 +16,17 @@ type TCPServer struct {
connMgr *ConnectionManager[net.Conn] connMgr *ConnectionManager[net.Conn]
} }
// NewTCPServerHandler creates a TCP stream server that can serve already-accepted
// connections (e.g. handed off by an ALPN multiplexer).
//
// This variant does not require a listener. Use TCPServer.ServeConn to handle
// each incoming stream connection.
func NewTCPServerHandler(ctx context.Context) *TCPServer {
s := &TCPServer{ctx: ctx}
s.connMgr = NewConnectionManager(s.createDestConnection)
return s
}
// NewTCPServerFromListener creates a TCP stream server from an already-prepared // NewTCPServerFromListener creates a TCP stream server from an already-prepared
// listener. // listener.
// //
@@ -48,6 +59,9 @@ func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.C
} }
func (s *TCPServer) Start() error { func (s *TCPServer) Start() error {
if s.listener == nil {
return net.ErrClosed
}
for { for {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
@@ -62,12 +76,28 @@ func (s *TCPServer) Start() error {
} }
} }
// ServeConn serves a single stream connection.
//
// The provided connection is expected to be already secured (TLS/mTLS) and to
// speak the stream protocol (i.e. the client will send the stream header first).
//
// This method blocks until the stream finishes.
func (s *TCPServer) ServeConn(conn net.Conn) {
s.handle(conn)
}
func (s *TCPServer) Addr() net.Addr { func (s *TCPServer) Addr() net.Addr {
if s.listener == nil {
return nil
}
return s.listener.Addr() return s.listener.Addr()
} }
func (s *TCPServer) Close() error { func (s *TCPServer) Close() error {
s.connMgr.CloseAllConnections() s.connMgr.CloseAllConnections()
if s.listener == nil {
return nil
}
return s.listener.Close() return s.listener.Close()
} }