test(agent/stream): refactor server flow tests to use testify and real certificate generation

- Use agent.NewAgent() for properly configured certificates matching real usage
- Migrate to testify/require for assertions
- Add tests for UDP server rejecting clients with invalid certificates
- Use t.Context() for lifecycle management
This commit is contained in:
yusing
2026-01-07 14:19:58 +08:00
parent 0a28d026c5
commit 56f7841eda

View File

@@ -1,90 +1,25 @@
package stream
package stream_test
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"io"
"math/big"
"net"
"testing"
"time"
"github.com/pion/dtls/v3"
"github.com/pion/transport/v3/udp"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/agent/pkg/agent/stream"
)
func newSerial(t *testing.T) *big.Int {
t.Helper()
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
t.Fatalf("rand serial: %v", err)
}
return sn
}
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: newSerial(t),
Subject: pkix.Name{CommonName: "stream-test-ca"},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate(CA): %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate(CA): %v", err)
}
return cert, key
}
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: newSerial(t),
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{eku},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
if err != nil {
t.Fatalf("CreateCertificate(%s): %v", cn, err)
}
leaf, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate(%s): %v", cn, err)
}
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
}
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen tcp: %v", err)
}
require.NoError(t, err, "listen tcp")
done := make(chan struct{})
go func() {
@@ -110,9 +45,7 @@ func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
t.Helper()
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen udp: %v", err)
}
require.NoError(t, err, "listen udp")
uc := pc.(*net.UDPConn)
done := make(chan struct{})
@@ -135,23 +68,27 @@ func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
}
func TestTCPServer_FullFlow(t *testing.T) {
ca, caKey := genCA(t)
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
require.NoError(t, err, "generate agent certs")
caCert, err := caPEM.ToTLSCert()
require.NoError(t, err, "parse CA cert")
srvCert, err := srvPEM.ToTLSCert()
require.NoError(t, err, "parse server cert")
clientCert, err := clientPEM.ToTLSCert()
require.NoError(t, err, "parse client cert")
dstAddr, closeDst := startTCPEcho(t)
defer closeDst()
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
t.Fatalf("ListenTCP: %v", err)
}
require.NoError(t, err, "listen tcp")
defer tcpLn.Close()
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
srv := stream.NewTCPServer(ctx, tcpLn, caCert.Leaf, srvCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
@@ -160,76 +97,138 @@ func TestTCPServer_FullFlow(t *testing.T) {
_ = <-errCh
}()
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
if err != nil {
t.Fatalf("NewTCPClient: %v", err)
}
client, err := stream.NewTCPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert)
require.NoError(t, err, "create tcp client")
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over tcp")
if _, err := client.Write(msg); err != nil {
t.Fatalf("client.Write: %v", err)
}
_, err = client.Write(msg)
require.NoError(t, err, "write to client")
buf := make([]byte, len(msg))
if _, err := io.ReadFull(client, buf); err != nil {
t.Fatalf("client.ReadFull: %v", err)
}
if string(buf) != string(msg) {
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
}
_, err = io.ReadFull(client, buf)
require.NoError(t, err, "read from client")
require.Equal(t, string(msg), string(buf), "unexpected echo")
}
func TestUDPServer_FullFlow(t *testing.T) {
ca, caKey := genCA(t)
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
func TestUDPServer_RejectInvalidClient(t *testing.T) {
caPEM, srvPEM, _, err := agent.NewAgent()
require.NoError(t, err, "generate agent certs")
caCert, err := caPEM.ToTLSCert()
require.NoError(t, err, "parse CA cert")
srvCert, err := srvPEM.ToTLSCert()
require.NoError(t, err, "parse server cert")
// Generate a self-signed client cert that is NOT signed by the CA
_, _, invalidClientPEM, err := agent.NewAgent()
require.NoError(t, err, "generate invalid client certs")
invalidClientCert, err := invalidClientPEM.ToTLSCert()
require.NoError(t, err, "parse invalid client cert")
dstAddr, closeDst := startUDPEcho(t)
defer closeDst()
ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
_ = <-errCh
}()
time.Sleep(100 * time.Millisecond)
// Try to connect with a client cert from a different CA
_, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, invalidClientCert)
require.Error(t, err, "expected error when connecting with client cert from different CA")
var handshakeErr *dtls.HandshakeError
require.ErrorAs(t, err, &handshakeErr, "expected handshake error")
}
func TestUDPServer_RejectClientWithoutCert(t *testing.T) {
caPEM, srvPEM, _, err := agent.NewAgent()
require.NoError(t, err, "generate agent certs")
caCert, err := caPEM.ToTLSCert()
require.NoError(t, err, "parse CA cert")
srvCert, err := srvPEM.ToTLSCert()
require.NoError(t, err, "parse server cert")
dstAddr, closeDst := startUDPEcho(t)
defer closeDst()
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
_ = <-errCh
}()
time.Sleep(time.Second)
// Try to connect without any client certificate
// Create a TLS cert without a private key to simulate no client cert
emptyCert := &tls.Certificate{}
_, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, emptyCert)
require.Error(t, err, "expected error when connecting without client cert")
require.ErrorContains(t, err, "no certificate provided", "expected no cert error")
}
func TestUDPServer_FullFlow(t *testing.T) {
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
require.NoError(t, err, "generate agent certs")
caCert, err := caPEM.ToTLSCert()
require.NoError(t, err, "parse CA cert")
srvCert, err := srvPEM.ToTLSCert()
require.NoError(t, err, "parse server cert")
clientCert, err := clientPEM.ToTLSCert()
require.NoError(t, err, "parse client cert")
dstAddr, closeDst := startUDPEcho(t)
defer closeDst()
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
errCh := make(chan error, 1)
go func() { errCh <- srv.Start() }()
defer func() {
cancel()
_ = srv.Close()
err := <-errCh
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, udp.ErrClosedListener) {
t.Logf("udp server exit: %v", err)
}
}()
deadline := time.Now().Add(2 * time.Second)
for srv.listener == nil {
if time.Now().After(deadline) {
t.Fatalf("udp server listener did not start")
}
time.Sleep(10 * time.Millisecond)
}
time.Sleep(100 * time.Millisecond)
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
if err != nil {
t.Fatalf("NewUDPClient: %v", err)
}
client, err := stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert)
require.NoError(t, err, "create udp client")
defer client.Close()
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
msg := []byte("ping over udp")
if _, err := client.Write(msg); err != nil {
t.Fatalf("client.Write: %v", err)
}
_, err = client.Write(msg)
require.NoError(t, err, "write to client")
buf := make([]byte, 2048)
n, err := client.Read(buf)
if err != nil {
t.Fatalf("client.Read: %v", err)
}
if string(buf[:n]) != string(msg) {
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
}
require.NoError(t, err, "read from client")
require.Equal(t, string(msg), string(buf[:n]), "unexpected echo")
}