From 751d73da7cf1435f74a21a185690b4d7adb9f4b5 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 8 Jan 2026 00:15:47 +0800 Subject: [PATCH] refactor(agent/stream): remove connQueueListener and simplify the handshaking flow --- agent/cmd/main.go | 133 ++++----------------------- agent/pkg/agent/stream/mux_test.go | 96 ++++--------------- agent/pkg/agent/stream/tcp_server.go | 30 ++++++ 3 files changed, 66 insertions(+), 193 deletions(-) diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 122988f1..45adfa9e 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "os" - "sync" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -25,53 +24,6 @@ import ( "github.com/yusing/goutils/version" ) -var errListenerClosed = errors.New("listener closed") - -type connQueueListener struct { - addr net.Addr - conns chan net.Conn - closed chan struct{} - closeOnce sync.Once -} - -func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener { - return &connQueueListener{ - addr: addr, - conns: make(chan net.Conn, buffer), - closed: make(chan struct{}), - } -} - -func (l *connQueueListener) push(conn net.Conn) error { - select { - case <-l.closed: - _ = conn.Close() - return errListenerClosed - case l.conns <- conn: - return nil - } -} - -func (l *connQueueListener) Accept() (net.Conn, error) { - conn, ok := <-l.conns - if !ok { - return nil, errListenerClosed - } - return conn, nil -} - -func (l *connQueueListener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) - close(l.conns) - }) - return nil -} - -func (l *connQueueListener) Addr() net.Addr { - return l.addr -} - func main() { writer := zerolog.ConsoleWriter{ Out: os.Stderr, @@ -113,7 +65,7 @@ Tips: t := task.RootTask("agent", false) // One TCP listener on AGENT_PORT, then multiplex by TLS ALPN: - // - Stream ALPN: route to TCP stream tunnel handler + // - Stream ALPN: route to TCP stream tunnel handler (via http.Server.TLSNextProto) // - Otherwise: route to HTTPS API handler tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort}) if err != nil { @@ -136,44 +88,43 @@ Tips: muxTLSConfig.ClientAuth = tls.NoClientCert } - httpLn := newConnQueueListener(tcpListener.Addr(), 128) - streamLn := newConnQueueListener(tcpListener.Addr(), 128) + // TLS listener feeds the HTTP server. ALPN stream connections are intercepted + // using http.Server.TLSNextProto. + tlsLn := tls.NewListener(tcpListener, muxTLSConfig) + + streamSrv := stream.NewTCPServerHandler(t.Context()) httpSrv := &http.Server{ Handler: handler.NewAgentHandler(), BaseContext: func(net.Listener) context.Context { return t.Context() }, + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ + // When a client negotiates StreamALPN, net/http will call this hook instead + // of treating the connection as HTTP. + stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + // ServeConn blocks until the tunnel finishes. + streamSrv.ServeConn(conn) + }, + }, } { subtask := t.Subtask("agent-http", true) t.OnCancel("stop_http", func() { - _ = httpSrv.Shutdown(context.Background()) - _ = httpLn.Close() + _ = streamSrv.Close() + _ = httpSrv.Close() + _ = tlsLn.Close() }) go func() { - err := httpSrv.Serve(httpLn) + err := httpSrv.Serve(tlsLn) if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Error().Err(err).Msg("agent HTTP server stopped with error") } subtask.Finish(err) }() - log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started") - } - - { - tcpServer := stream.NewTCPServerFromListener(t.Context(), streamLn) - subtask := t.Subtask("agent-stream-tcp", true) - t.OnCancel("stop_stream_tcp", func() { - _ = tcpServer.Close() - _ = streamLn.Close() - }) - go func() { - err := tcpServer.Start() - subtask.Finish(err) - }() - log.Info().Int("port", env.AgentPort).Msg("TCP stream server started") + log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started (ALPN mux enabled)") } + log.Info().Int("port", env.AgentPort).Msg("TCP stream handler started (via TLSNextProto)") { udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert) @@ -188,50 +139,6 @@ Tips: log.Info().Int("port", env.AgentPort).Msg("UDP stream server started") } - // Accept raw TCP connections, terminate TLS once, and dispatch by ALPN. - { - subtask := t.Subtask("agent-tls-mux", true) - t.OnCancel("stop_mux", func() { - _ = httpLn.Close() - _ = streamLn.Close() - _ = tcpListener.Close() - }) - go func() { - defer subtask.Finish(subtask.FinishCause()) - for { - select { - case <-t.Context().Done(): - return - default: - } - - conn, err := tcpListener.Accept() - if err != nil { - if t.Context().Err() != nil { - return - } - log.Error().Err(err).Msg("failed to accept connection") - continue - } - - tlsConn := tls.Server(conn, muxTLSConfig) - if err := tlsConn.HandshakeContext(t.Context()); err != nil { - _ = tlsConn.Close() - log.Debug().Err(err).Msg("TLS handshake failed") - continue - } - - alpn := tlsConn.ConnectionState().NegotiatedProtocol - switch alpn { - case stream.StreamALPN: - _ = streamLn.push(tlsConn) - default: - _ = httpLn.push(tlsConn) - } - } - }() - } - if socketproxy.ListenAddr != "" { runtime := strutils.Title(string(env.Runtime)) diff --git a/agent/pkg/agent/stream/mux_test.go b/agent/pkg/agent/stream/mux_test.go index 358c7598..afa98b3f 100644 --- a/agent/pkg/agent/stream/mux_test.go +++ b/agent/pkg/agent/stream/mux_test.go @@ -5,11 +5,9 @@ import ( "context" "crypto/tls" "crypto/x509" - "errors" "io" "net" "net/http" - "sync" "testing" "time" @@ -19,51 +17,6 @@ import ( "github.com/yusing/godoxy/agent/pkg/agent/stream" ) -var errListenerClosed = errors.New("listener closed") - -type connQueueListener struct { - addr net.Addr - conns chan net.Conn - closed chan struct{} - closeOnce sync.Once -} - -func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener { - return &connQueueListener{ - addr: addr, - conns: make(chan net.Conn, buffer), - closed: make(chan struct{}), - } -} - -func (l *connQueueListener) push(conn net.Conn) error { - select { - case <-l.closed: - _ = conn.Close() - return errListenerClosed - case l.conns <- conn: - return nil - } -} - -func (l *connQueueListener) Accept() (net.Conn, error) { - conn, ok := <-l.conns - if !ok { - return nil, errListenerClosed - } - return conn, nil -} - -func (l *connQueueListener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) - close(l.conns) - }) - return nil -} - -func (l *connQueueListener) Addr() net.Addr { return l.addr } - func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) { caPEM, srvPEM, clientPEM, err := agent.NewAgent() require.NoError(t, err, "generate agent certs") @@ -91,49 +44,32 @@ func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) { NextProtos: []string{"http/1.1", stream.StreamALPN}, } - httpLn := newConnQueueListener(baseLn.Addr(), 16) - streamLn := newConnQueueListener(baseLn.Addr(), 16) - defer httpLn.Close() - defer streamLn.Close() - ctx, cancel := context.WithCancel(t.Context()) defer cancel() + streamSrv := stream.NewTCPServerHandler(ctx) + defer func() { _ = streamSrv.Close() }() + + tlsLn := tls.NewListener(baseLn, serverTLS) + defer func() { _ = tlsLn.Close() }() + // HTTP server httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) - })} - go func() { _ = httpSrv.Serve(httpLn) }() - defer func() { _ = httpSrv.Shutdown(context.Background()) }() + }), + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){ + stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) { + streamSrv.ServeConn(conn) + }, + }, + } + go func() { _ = httpSrv.Serve(tlsLn) }() + defer func() { _ = httpSrv.Close() }() - // Stream server + // Stream destination dstAddr, closeDst := startTCPEcho(t) defer closeDst() - tcpStreamSrv := stream.NewTCPServerFromListener(ctx, streamLn) - go func() { _ = tcpStreamSrv.Start() }() - defer func() { _ = tcpStreamSrv.Close() }() - - // Mux loop - go func() { - for { - conn, err := baseLn.Accept() - if err != nil { - return - } - tlsConn := tls.Server(conn, serverTLS) - if err := tlsConn.HandshakeContext(ctx); err != nil { - _ = tlsConn.Close() - continue - } - if tlsConn.ConnectionState().NegotiatedProtocol == stream.StreamALPN { - _ = streamLn.push(tlsConn) - } else { - _ = httpLn.push(tlsConn) - } - } - }() - // HTTP client over the same port clientTLS := &tls.Config{ Certificates: []tls.Certificate{*clientCert}, diff --git a/agent/pkg/agent/stream/tcp_server.go b/agent/pkg/agent/stream/tcp_server.go index a5fc73b0..63999387 100644 --- a/agent/pkg/agent/stream/tcp_server.go +++ b/agent/pkg/agent/stream/tcp_server.go @@ -16,6 +16,17 @@ type TCPServer struct { connMgr *ConnectionManager[net.Conn] } +// NewTCPServerHandler creates a TCP stream server that can serve already-accepted +// connections (e.g. handed off by an ALPN multiplexer). +// +// This variant does not require a listener. Use TCPServer.ServeConn to handle +// each incoming stream connection. +func NewTCPServerHandler(ctx context.Context) *TCPServer { + s := &TCPServer{ctx: ctx} + s.connMgr = NewConnectionManager(s.createDestConnection) + return s +} + // NewTCPServerFromListener creates a TCP stream server from an already-prepared // listener. // @@ -48,6 +59,9 @@ func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.C } func (s *TCPServer) Start() error { + if s.listener == nil { + return net.ErrClosed + } for { select { case <-s.ctx.Done(): @@ -62,12 +76,28 @@ func (s *TCPServer) Start() error { } } +// ServeConn serves a single stream connection. +// +// The provided connection is expected to be already secured (TLS/mTLS) and to +// speak the stream protocol (i.e. the client will send the stream header first). +// +// This method blocks until the stream finishes. +func (s *TCPServer) ServeConn(conn net.Conn) { + s.handle(conn) +} + func (s *TCPServer) Addr() net.Addr { + if s.listener == nil { + return nil + } return s.listener.Addr() } func (s *TCPServer) Close() error { s.connMgr.CloseAllConnections() + if s.listener == nil { + return nil + } return s.listener.Close() }