// Mirror of https://github.com/yusing/godoxy.git, synced 2026-03-21 08:59:47 +01:00.
// Go source file: 201 lines, 5.1 KiB.

package stream_test
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/dtls/v3"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/yusing/godoxy/agent/pkg/agent"
|
|
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
|
)
|
|
|
|
func TestTCPServer_FullFlow(t *testing.T) {
|
|
certs := genTestCerts(t)
|
|
|
|
dstAddr, closeDst := startTCPEcho(t)
|
|
defer closeDst()
|
|
|
|
srv := startTCPServer(t, certs)
|
|
|
|
client := NewTCPClient(t, srv.Addr.String(), dstAddr, certs)
|
|
defer client.Close()
|
|
|
|
// Ensure ALPN is negotiated as expected (required for multiplexing).
|
|
withState, ok := client.(interface{ ConnectionState() tls.ConnectionState })
|
|
require.True(t, ok, "tcp client should expose TLS connection state")
|
|
require.Equal(t, stream.StreamALPN, withState.ConnectionState().NegotiatedProtocol)
|
|
|
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
|
msg := []byte("ping over tcp")
|
|
_, err := client.Write(msg)
|
|
require.NoError(t, err, "write to client")
|
|
|
|
buf := make([]byte, len(msg))
|
|
_, err = io.ReadFull(client, buf)
|
|
require.NoError(t, err, "read from client")
|
|
require.Equal(t, string(msg), string(buf), "unexpected echo")
|
|
}
|
|
|
|
func TestTCPServer_ConcurrentConnections(t *testing.T) {
|
|
certs := genTestCerts(t)
|
|
|
|
dstAddr, closeDst := startTCPEcho(t)
|
|
defer closeDst()
|
|
|
|
srv := startTCPServer(t, certs)
|
|
|
|
const nClients = 25
|
|
|
|
errs := make(chan error, nClients)
|
|
var wg sync.WaitGroup
|
|
wg.Add(nClients)
|
|
|
|
for i := range nClients {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
client := NewTCPClient(t, srv.Addr.String(), dstAddr, certs)
|
|
defer client.Close()
|
|
|
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
|
msg := fmt.Appendf(nil, "ping over tcp %d", i)
|
|
if _, err := client.Write(msg); err != nil {
|
|
errs <- fmt.Errorf("write to client: %w", err)
|
|
return
|
|
}
|
|
|
|
buf := make([]byte, len(msg))
|
|
if _, err := io.ReadFull(client, buf); err != nil {
|
|
errs <- fmt.Errorf("read from client: %w", err)
|
|
return
|
|
}
|
|
if string(msg) != string(buf) {
|
|
errs <- fmt.Errorf("unexpected echo: got=%q want=%q", string(buf), string(msg))
|
|
return
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errs)
|
|
for err := range errs {
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
func TestUDPServer_RejectInvalidClient(t *testing.T) {
|
|
certs := genTestCerts(t)
|
|
|
|
// Generate a self-signed client cert that is NOT signed by the CA
|
|
_, _, invalidClientPEM, err := agent.NewAgent()
|
|
require.NoError(t, err, "generate invalid client certs")
|
|
invalidClientCert, err := invalidClientPEM.ToTLSCert()
|
|
require.NoError(t, err, "parse invalid client cert")
|
|
|
|
dstAddr, closeDst := startUDPEcho(t)
|
|
defer closeDst()
|
|
|
|
srv := startUDPServer(t, certs)
|
|
|
|
// Try to connect with a client cert from a different CA
|
|
_, err = stream.NewUDPClient(srv.Addr.String(), dstAddr, certs.CaCert, invalidClientCert)
|
|
require.Error(t, err, "expected error when connecting with client cert from different CA")
|
|
|
|
var handshakeErr *dtls.HandshakeError
|
|
require.ErrorAs(t, err, &handshakeErr, "expected handshake error")
|
|
}
|
|
|
|
func TestUDPServer_RejectClientWithoutCert(t *testing.T) {
|
|
certs := genTestCerts(t)
|
|
|
|
dstAddr, closeDst := startUDPEcho(t)
|
|
defer closeDst()
|
|
|
|
srv := startUDPServer(t, certs)
|
|
|
|
time.Sleep(time.Second)
|
|
|
|
// Try to connect without any client certificate
|
|
// Create a TLS cert without a private key to simulate no client cert
|
|
emptyCert := &tls.Certificate{}
|
|
_, err := stream.NewUDPClient(srv.Addr.String(), dstAddr, certs.CaCert, emptyCert)
|
|
require.Error(t, err, "expected error when connecting without client cert")
|
|
|
|
require.ErrorContains(t, err, "no certificate provided", "expected no cert error")
|
|
}
|
|
|
|
func TestUDPServer_FullFlow(t *testing.T) {
|
|
certs := genTestCerts(t)
|
|
|
|
dstAddr, closeDst := startUDPEcho(t)
|
|
defer closeDst()
|
|
|
|
srv := startUDPServer(t, certs)
|
|
|
|
client := NewUDPClient(t, srv.Addr.String(), dstAddr, certs)
|
|
defer client.Close()
|
|
|
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
|
msg := []byte("ping over udp")
|
|
_, err := client.Write(msg)
|
|
require.NoError(t, err, "write to client")
|
|
|
|
buf := make([]byte, 2048)
|
|
n, err := client.Read(buf)
|
|
require.NoError(t, err, "read from client")
|
|
require.Equal(t, string(msg), string(buf[:n]), "unexpected echo")
|
|
}
|
|
|
|
func TestUDPServer_ConcurrentConnections(t *testing.T) {
|
|
certs := genTestCerts(t)
|
|
|
|
dstAddr, closeDst := startUDPEcho(t)
|
|
defer closeDst()
|
|
|
|
srv := startUDPServer(t, certs)
|
|
|
|
const nClients = 25
|
|
|
|
errs := make(chan error, nClients)
|
|
var wg sync.WaitGroup
|
|
wg.Add(nClients)
|
|
|
|
for i := range nClients {
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
client := NewUDPClient(t, srv.Addr.String(), dstAddr, certs)
|
|
defer client.Close()
|
|
|
|
_ = client.SetDeadline(time.Now().Add(5 * time.Second))
|
|
msg := fmt.Appendf(nil, "ping over udp %d", i)
|
|
if _, err := client.Write(msg); err != nil {
|
|
errs <- fmt.Errorf("write to client: %w", err)
|
|
return
|
|
}
|
|
|
|
buf := make([]byte, 2048)
|
|
n, err := client.Read(buf)
|
|
if err != nil {
|
|
errs <- fmt.Errorf("read from client: %w", err)
|
|
return
|
|
}
|
|
if string(msg) != string(buf[:n]) {
|
|
errs <- fmt.Errorf("unexpected echo: got=%q want=%q", string(buf[:n]), string(msg))
|
|
return
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errs)
|
|
for err := range errs {
|
|
require.NoError(t, err)
|
|
}
|
|
}
|