mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 09:18:31 +02:00
refactor(agent/stream): remove connQueueListener and simplify the handshaking flow
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -25,53 +24,6 @@ import (
|
|||||||
"github.com/yusing/goutils/version"
|
"github.com/yusing/goutils/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errListenerClosed = errors.New("listener closed")
|
|
||||||
|
|
||||||
type connQueueListener struct {
|
|
||||||
addr net.Addr
|
|
||||||
conns chan net.Conn
|
|
||||||
closed chan struct{}
|
|
||||||
closeOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
|
|
||||||
return &connQueueListener{
|
|
||||||
addr: addr,
|
|
||||||
conns: make(chan net.Conn, buffer),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) push(conn net.Conn) error {
|
|
||||||
select {
|
|
||||||
case <-l.closed:
|
|
||||||
_ = conn.Close()
|
|
||||||
return errListenerClosed
|
|
||||||
case l.conns <- conn:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) Accept() (net.Conn, error) {
|
|
||||||
conn, ok := <-l.conns
|
|
||||||
if !ok {
|
|
||||||
return nil, errListenerClosed
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) Close() error {
|
|
||||||
l.closeOnce.Do(func() {
|
|
||||||
close(l.closed)
|
|
||||||
close(l.conns)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) Addr() net.Addr {
|
|
||||||
return l.addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
writer := zerolog.ConsoleWriter{
|
writer := zerolog.ConsoleWriter{
|
||||||
Out: os.Stderr,
|
Out: os.Stderr,
|
||||||
@@ -113,7 +65,7 @@ Tips:
|
|||||||
t := task.RootTask("agent", false)
|
t := task.RootTask("agent", false)
|
||||||
|
|
||||||
// One TCP listener on AGENT_PORT, then multiplex by TLS ALPN:
|
// One TCP listener on AGENT_PORT, then multiplex by TLS ALPN:
|
||||||
// - Stream ALPN: route to TCP stream tunnel handler
|
// - Stream ALPN: route to TCP stream tunnel handler (via http.Server.TLSNextProto)
|
||||||
// - Otherwise: route to HTTPS API handler
|
// - Otherwise: route to HTTPS API handler
|
||||||
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort})
|
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: env.AgentPort})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -136,44 +88,43 @@ Tips:
|
|||||||
muxTLSConfig.ClientAuth = tls.NoClientCert
|
muxTLSConfig.ClientAuth = tls.NoClientCert
|
||||||
}
|
}
|
||||||
|
|
||||||
httpLn := newConnQueueListener(tcpListener.Addr(), 128)
|
// TLS listener feeds the HTTP server. ALPN stream connections are intercepted
|
||||||
streamLn := newConnQueueListener(tcpListener.Addr(), 128)
|
// using http.Server.TLSNextProto.
|
||||||
|
tlsLn := tls.NewListener(tcpListener, muxTLSConfig)
|
||||||
|
|
||||||
|
streamSrv := stream.NewTCPServerHandler(t.Context())
|
||||||
|
|
||||||
httpSrv := &http.Server{
|
httpSrv := &http.Server{
|
||||||
Handler: handler.NewAgentHandler(),
|
Handler: handler.NewAgentHandler(),
|
||||||
BaseContext: func(net.Listener) context.Context {
|
BaseContext: func(net.Listener) context.Context {
|
||||||
return t.Context()
|
return t.Context()
|
||||||
},
|
},
|
||||||
|
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
|
||||||
|
// When a client negotiates StreamALPN, net/http will call this hook instead
|
||||||
|
// of treating the connection as HTTP.
|
||||||
|
stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
|
||||||
|
// ServeConn blocks until the tunnel finishes.
|
||||||
|
streamSrv.ServeConn(conn)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
subtask := t.Subtask("agent-http", true)
|
subtask := t.Subtask("agent-http", true)
|
||||||
t.OnCancel("stop_http", func() {
|
t.OnCancel("stop_http", func() {
|
||||||
_ = httpSrv.Shutdown(context.Background())
|
_ = streamSrv.Close()
|
||||||
_ = httpLn.Close()
|
_ = httpSrv.Close()
|
||||||
|
_ = tlsLn.Close()
|
||||||
})
|
})
|
||||||
go func() {
|
go func() {
|
||||||
err := httpSrv.Serve(httpLn)
|
err := httpSrv.Serve(tlsLn)
|
||||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Error().Err(err).Msg("agent HTTP server stopped with error")
|
log.Error().Err(err).Msg("agent HTTP server stopped with error")
|
||||||
}
|
}
|
||||||
subtask.Finish(err)
|
subtask.Finish(err)
|
||||||
}()
|
}()
|
||||||
log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started")
|
log.Info().Int("port", env.AgentPort).Msg("HTTPS API server started (ALPN mux enabled)")
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
tcpServer := stream.NewTCPServerFromListener(t.Context(), streamLn)
|
|
||||||
subtask := t.Subtask("agent-stream-tcp", true)
|
|
||||||
t.OnCancel("stop_stream_tcp", func() {
|
|
||||||
_ = tcpServer.Close()
|
|
||||||
_ = streamLn.Close()
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
err := tcpServer.Start()
|
|
||||||
subtask.Finish(err)
|
|
||||||
}()
|
|
||||||
log.Info().Int("port", env.AgentPort).Msg("TCP stream server started")
|
|
||||||
}
|
}
|
||||||
|
log.Info().Int("port", env.AgentPort).Msg("TCP stream handler started (via TLSNextProto)")
|
||||||
|
|
||||||
{
|
{
|
||||||
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert)
|
udpServer := stream.NewUDPServer(t.Context(), &net.UDPAddr{Port: env.AgentPort}, caCert.Leaf, srvCert)
|
||||||
@@ -188,50 +139,6 @@ Tips:
|
|||||||
log.Info().Int("port", env.AgentPort).Msg("UDP stream server started")
|
log.Info().Int("port", env.AgentPort).Msg("UDP stream server started")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept raw TCP connections, terminate TLS once, and dispatch by ALPN.
|
|
||||||
{
|
|
||||||
subtask := t.Subtask("agent-tls-mux", true)
|
|
||||||
t.OnCancel("stop_mux", func() {
|
|
||||||
_ = httpLn.Close()
|
|
||||||
_ = streamLn.Close()
|
|
||||||
_ = tcpListener.Close()
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
defer subtask.Finish(subtask.FinishCause())
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-t.Context().Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := tcpListener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if t.Context().Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error().Err(err).Msg("failed to accept connection")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConn := tls.Server(conn, muxTLSConfig)
|
|
||||||
if err := tlsConn.HandshakeContext(t.Context()); err != nil {
|
|
||||||
_ = tlsConn.Close()
|
|
||||||
log.Debug().Err(err).Msg("TLS handshake failed")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
alpn := tlsConn.ConnectionState().NegotiatedProtocol
|
|
||||||
switch alpn {
|
|
||||||
case stream.StreamALPN:
|
|
||||||
_ = streamLn.push(tlsConn)
|
|
||||||
default:
|
|
||||||
_ = httpLn.push(tlsConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
if socketproxy.ListenAddr != "" {
|
if socketproxy.ListenAddr != "" {
|
||||||
runtime := strutils.Title(string(env.Runtime))
|
runtime := strutils.Title(string(env.Runtime))
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,51 +17,6 @@ import (
|
|||||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errListenerClosed = errors.New("listener closed")
|
|
||||||
|
|
||||||
type connQueueListener struct {
|
|
||||||
addr net.Addr
|
|
||||||
conns chan net.Conn
|
|
||||||
closed chan struct{}
|
|
||||||
closeOnce sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
|
|
||||||
return &connQueueListener{
|
|
||||||
addr: addr,
|
|
||||||
conns: make(chan net.Conn, buffer),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) push(conn net.Conn) error {
|
|
||||||
select {
|
|
||||||
case <-l.closed:
|
|
||||||
_ = conn.Close()
|
|
||||||
return errListenerClosed
|
|
||||||
case l.conns <- conn:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) Accept() (net.Conn, error) {
|
|
||||||
conn, ok := <-l.conns
|
|
||||||
if !ok {
|
|
||||||
return nil, errListenerClosed
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) Close() error {
|
|
||||||
l.closeOnce.Do(func() {
|
|
||||||
close(l.closed)
|
|
||||||
close(l.conns)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *connQueueListener) Addr() net.Addr { return l.addr }
|
|
||||||
|
|
||||||
func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
|
func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
|
||||||
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||||
require.NoError(t, err, "generate agent certs")
|
require.NoError(t, err, "generate agent certs")
|
||||||
@@ -91,49 +44,32 @@ func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
|
|||||||
NextProtos: []string{"http/1.1", stream.StreamALPN},
|
NextProtos: []string{"http/1.1", stream.StreamALPN},
|
||||||
}
|
}
|
||||||
|
|
||||||
httpLn := newConnQueueListener(baseLn.Addr(), 16)
|
|
||||||
streamLn := newConnQueueListener(baseLn.Addr(), 16)
|
|
||||||
defer httpLn.Close()
|
|
||||||
defer streamLn.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(t.Context())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
streamSrv := stream.NewTCPServerHandler(ctx)
|
||||||
|
defer func() { _ = streamSrv.Close() }()
|
||||||
|
|
||||||
|
tlsLn := tls.NewListener(baseLn, serverTLS)
|
||||||
|
defer func() { _ = tlsLn.Close() }()
|
||||||
|
|
||||||
// HTTP server
|
// HTTP server
|
||||||
httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
_, _ = w.Write([]byte("ok"))
|
_, _ = w.Write([]byte("ok"))
|
||||||
})}
|
}),
|
||||||
go func() { _ = httpSrv.Serve(httpLn) }()
|
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
|
||||||
defer func() { _ = httpSrv.Shutdown(context.Background()) }()
|
stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
|
||||||
|
streamSrv.ServeConn(conn)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go func() { _ = httpSrv.Serve(tlsLn) }()
|
||||||
|
defer func() { _ = httpSrv.Close() }()
|
||||||
|
|
||||||
// Stream server
|
// Stream destination
|
||||||
dstAddr, closeDst := startTCPEcho(t)
|
dstAddr, closeDst := startTCPEcho(t)
|
||||||
defer closeDst()
|
defer closeDst()
|
||||||
|
|
||||||
tcpStreamSrv := stream.NewTCPServerFromListener(ctx, streamLn)
|
|
||||||
go func() { _ = tcpStreamSrv.Start() }()
|
|
||||||
defer func() { _ = tcpStreamSrv.Close() }()
|
|
||||||
|
|
||||||
// Mux loop
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
conn, err := baseLn.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tlsConn := tls.Server(conn, serverTLS)
|
|
||||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
|
||||||
_ = tlsConn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if tlsConn.ConnectionState().NegotiatedProtocol == stream.StreamALPN {
|
|
||||||
_ = streamLn.push(tlsConn)
|
|
||||||
} else {
|
|
||||||
_ = httpLn.push(tlsConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// HTTP client over the same port
|
// HTTP client over the same port
|
||||||
clientTLS := &tls.Config{
|
clientTLS := &tls.Config{
|
||||||
Certificates: []tls.Certificate{*clientCert},
|
Certificates: []tls.Certificate{*clientCert},
|
||||||
|
|||||||
@@ -16,6 +16,17 @@ type TCPServer struct {
|
|||||||
connMgr *ConnectionManager[net.Conn]
|
connMgr *ConnectionManager[net.Conn]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewTCPServerHandler creates a TCP stream server that can serve already-accepted
|
||||||
|
// connections (e.g. handed off by an ALPN multiplexer).
|
||||||
|
//
|
||||||
|
// This variant does not require a listener. Use TCPServer.ServeConn to handle
|
||||||
|
// each incoming stream connection.
|
||||||
|
func NewTCPServerHandler(ctx context.Context) *TCPServer {
|
||||||
|
s := &TCPServer{ctx: ctx}
|
||||||
|
s.connMgr = NewConnectionManager(s.createDestConnection)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// NewTCPServerFromListener creates a TCP stream server from an already-prepared
|
// NewTCPServerFromListener creates a TCP stream server from an already-prepared
|
||||||
// listener.
|
// listener.
|
||||||
//
|
//
|
||||||
@@ -48,6 +59,9 @@ func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TCPServer) Start() error {
|
func (s *TCPServer) Start() error {
|
||||||
|
if s.listener == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-s.ctx.Done():
|
case <-s.ctx.Done():
|
||||||
@@ -62,12 +76,28 @@ func (s *TCPServer) Start() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ServeConn serves a single stream connection.
|
||||||
|
//
|
||||||
|
// The provided connection is expected to be already secured (TLS/mTLS) and to
|
||||||
|
// speak the stream protocol (i.e. the client will send the stream header first).
|
||||||
|
//
|
||||||
|
// This method blocks until the stream finishes.
|
||||||
|
func (s *TCPServer) ServeConn(conn net.Conn) {
|
||||||
|
s.handle(conn)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *TCPServer) Addr() net.Addr {
|
func (s *TCPServer) Addr() net.Addr {
|
||||||
|
if s.listener == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return s.listener.Addr()
|
return s.listener.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TCPServer) Close() error {
|
func (s *TCPServer) Close() error {
|
||||||
s.connMgr.CloseAllConnections()
|
s.connMgr.CloseAllConnections()
|
||||||
|
if s.listener == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return s.listener.Close()
|
return s.listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user