mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-10 10:53:36 +02:00
refactor(agent/stream): remove connQueueListener and simplify the handshaking flow
This commit is contained in:
@@ -5,11 +5,9 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -19,51 +17,6 @@ import (
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
)
|
||||
|
||||
var errListenerClosed = errors.New("listener closed")
|
||||
|
||||
type connQueueListener struct {
|
||||
addr net.Addr
|
||||
conns chan net.Conn
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newConnQueueListener(addr net.Addr, buffer int) *connQueueListener {
|
||||
return &connQueueListener{
|
||||
addr: addr,
|
||||
conns: make(chan net.Conn, buffer),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *connQueueListener) push(conn net.Conn) error {
|
||||
select {
|
||||
case <-l.closed:
|
||||
_ = conn.Close()
|
||||
return errListenerClosed
|
||||
case l.conns <- conn:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *connQueueListener) Accept() (net.Conn, error) {
|
||||
conn, ok := <-l.conns
|
||||
if !ok {
|
||||
return nil, errListenerClosed
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *connQueueListener) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.closed)
|
||||
close(l.conns)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *connQueueListener) Addr() net.Addr { return l.addr }
|
||||
|
||||
func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
|
||||
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate agent certs")
|
||||
@@ -91,49 +44,32 @@ func TestTLSALPNMux_HTTPAndStreamShareOnePort(t *testing.T) {
|
||||
NextProtos: []string{"http/1.1", stream.StreamALPN},
|
||||
}
|
||||
|
||||
httpLn := newConnQueueListener(baseLn.Addr(), 16)
|
||||
streamLn := newConnQueueListener(baseLn.Addr(), 16)
|
||||
defer httpLn.Close()
|
||||
defer streamLn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
streamSrv := stream.NewTCPServerHandler(ctx)
|
||||
defer func() { _ = streamSrv.Close() }()
|
||||
|
||||
tlsLn := tls.NewListener(baseLn, serverTLS)
|
||||
defer func() { _ = tlsLn.Close() }()
|
||||
|
||||
// HTTP server
|
||||
httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})}
|
||||
go func() { _ = httpSrv.Serve(httpLn) }()
|
||||
defer func() { _ = httpSrv.Shutdown(context.Background()) }()
|
||||
}),
|
||||
TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){
|
||||
stream.StreamALPN: func(_ *http.Server, conn *tls.Conn, _ http.Handler) {
|
||||
streamSrv.ServeConn(conn)
|
||||
},
|
||||
},
|
||||
}
|
||||
go func() { _ = httpSrv.Serve(tlsLn) }()
|
||||
defer func() { _ = httpSrv.Close() }()
|
||||
|
||||
// Stream server
|
||||
// Stream destination
|
||||
dstAddr, closeDst := startTCPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
tcpStreamSrv := stream.NewTCPServerFromListener(ctx, streamLn)
|
||||
go func() { _ = tcpStreamSrv.Start() }()
|
||||
defer func() { _ = tcpStreamSrv.Close() }()
|
||||
|
||||
// Mux loop
|
||||
go func() {
|
||||
for {
|
||||
conn, err := baseLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tlsConn := tls.Server(conn, serverTLS)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = tlsConn.Close()
|
||||
continue
|
||||
}
|
||||
if tlsConn.ConnectionState().NegotiatedProtocol == stream.StreamALPN {
|
||||
_ = streamLn.push(tlsConn)
|
||||
} else {
|
||||
_ = httpLn.push(tlsConn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// HTTP client over the same port
|
||||
clientTLS := &tls.Config{
|
||||
Certificates: []tls.Certificate{*clientCert},
|
||||
|
||||
@@ -16,6 +16,17 @@ type TCPServer struct {
|
||||
connMgr *ConnectionManager[net.Conn]
|
||||
}
|
||||
|
||||
// NewTCPServerHandler creates a TCP stream server that can serve already-accepted
|
||||
// connections (e.g. handed off by an ALPN multiplexer).
|
||||
//
|
||||
// This variant does not require a listener. Use TCPServer.ServeConn to handle
|
||||
// each incoming stream connection.
|
||||
func NewTCPServerHandler(ctx context.Context) *TCPServer {
|
||||
s := &TCPServer{ctx: ctx}
|
||||
s.connMgr = NewConnectionManager(s.createDestConnection)
|
||||
return s
|
||||
}
|
||||
|
||||
// NewTCPServerFromListener creates a TCP stream server from an already-prepared
|
||||
// listener.
|
||||
//
|
||||
@@ -48,6 +59,9 @@ func NewTCPServer(ctx context.Context, listener *net.TCPListener, caCert *x509.C
|
||||
}
|
||||
|
||||
func (s *TCPServer) Start() error {
|
||||
if s.listener == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
@@ -62,12 +76,28 @@ func (s *TCPServer) Start() error {
|
||||
}
|
||||
}
|
||||
|
||||
// ServeConn serves a single stream connection.
|
||||
//
|
||||
// The provided connection is expected to be already secured (TLS/mTLS) and to
|
||||
// speak the stream protocol (i.e. the client will send the stream header first).
|
||||
//
|
||||
// This method blocks until the stream finishes.
|
||||
func (s *TCPServer) ServeConn(conn net.Conn) {
|
||||
s.handle(conn)
|
||||
}
|
||||
|
||||
func (s *TCPServer) Addr() net.Addr {
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
func (s *TCPServer) Close() error {
|
||||
s.connMgr.CloseAllConnections()
|
||||
if s.listener == nil {
|
||||
return nil
|
||||
}
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user