mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-18 15:23:51 +01:00
test(agent/stream): refactor server flow tests to use testify and real certificate generation
- Use agent.NewAgent() for properly configured certificates matching real usage - Migrate to testify/require for assertions - Add tests for UDP server rejecting clients with invalid certificates - Use t.Context() for lifecycle management
This commit is contained in:
@@ -1,90 +1,25 @@
|
||||
package stream
|
||||
package stream_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/dtls/v3"
|
||||
"github.com/pion/transport/v3/udp"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||
)
|
||||
|
||||
func newSerial(t *testing.T) *big.Int {
|
||||
t.Helper()
|
||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
t.Fatalf("rand serial: %v", err)
|
||||
}
|
||||
return sn
|
||||
}
|
||||
|
||||
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: newSerial(t),
|
||||
Subject: pkix.Name{CommonName: "stream-test-ca"},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate(CA): %v", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificate(CA): %v", err)
|
||||
}
|
||||
return cert, key
|
||||
}
|
||||
|
||||
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: newSerial(t),
|
||||
Subject: pkix.Name{CommonName: cn},
|
||||
NotBefore: time.Now().Add(-time.Minute),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{eku},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate(%s): %v", cn, err)
|
||||
}
|
||||
leaf, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificate(%s): %v", cn, err)
|
||||
}
|
||||
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
|
||||
}
|
||||
|
||||
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen tcp: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "listen tcp")
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
@@ -110,9 +45,7 @@ func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
t.Helper()
|
||||
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen udp: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "listen udp")
|
||||
uc := pc.(*net.UDPConn)
|
||||
|
||||
done := make(chan struct{})
|
||||
@@ -135,23 +68,27 @@ func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||
}
|
||||
|
||||
func TestTCPServer_FullFlow(t *testing.T) {
|
||||
ca, caKey := genCA(t)
|
||||
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
||||
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
||||
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate agent certs")
|
||||
|
||||
caCert, err := caPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse CA cert")
|
||||
srvCert, err := srvPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse server cert")
|
||||
clientCert, err := clientPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse client cert")
|
||||
|
||||
dstAddr, closeDst := startTCPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTCP: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "listen tcp")
|
||||
defer tcpLn.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
|
||||
srv := stream.NewTCPServer(ctx, tcpLn, caCert.Leaf, srvCert)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
defer func() {
|
||||
@@ -160,76 +97,138 @@ func TestTCPServer_FullFlow(t *testing.T) {
|
||||
_ = <-errCh
|
||||
}()
|
||||
|
||||
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTCPClient: %v", err)
|
||||
}
|
||||
client, err := stream.NewTCPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert)
|
||||
require.NoError(t, err, "create tcp client")
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over tcp")
|
||||
if _, err := client.Write(msg); err != nil {
|
||||
t.Fatalf("client.Write: %v", err)
|
||||
}
|
||||
_, err = client.Write(msg)
|
||||
require.NoError(t, err, "write to client")
|
||||
|
||||
buf := make([]byte, len(msg))
|
||||
if _, err := io.ReadFull(client, buf); err != nil {
|
||||
t.Fatalf("client.ReadFull: %v", err)
|
||||
}
|
||||
if string(buf) != string(msg) {
|
||||
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
|
||||
}
|
||||
_, err = io.ReadFull(client, buf)
|
||||
require.NoError(t, err, "read from client")
|
||||
require.Equal(t, string(msg), string(buf), "unexpected echo")
|
||||
}
|
||||
|
||||
func TestUDPServer_FullFlow(t *testing.T) {
|
||||
ca, caKey := genCA(t)
|
||||
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
||||
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
||||
func TestUDPServer_RejectInvalidClient(t *testing.T) {
|
||||
caPEM, srvPEM, _, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate agent certs")
|
||||
|
||||
caCert, err := caPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse CA cert")
|
||||
srvCert, err := srvPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse server cert")
|
||||
|
||||
// Generate a self-signed client cert that is NOT signed by the CA
|
||||
_, _, invalidClientPEM, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate invalid client certs")
|
||||
invalidClientCert, err := invalidClientPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse invalid client cert")
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
|
||||
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
defer func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
_ = <-errCh
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to connect with a client cert from a different CA
|
||||
_, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, invalidClientCert)
|
||||
require.Error(t, err, "expected error when connecting with client cert from different CA")
|
||||
|
||||
var handshakeErr *dtls.HandshakeError
|
||||
require.ErrorAs(t, err, &handshakeErr, "expected handshake error")
|
||||
}
|
||||
|
||||
func TestUDPServer_RejectClientWithoutCert(t *testing.T) {
|
||||
caPEM, srvPEM, _, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate agent certs")
|
||||
|
||||
caCert, err := caPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse CA cert")
|
||||
srvCert, err := srvPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse server cert")
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
defer func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
_ = <-errCh
|
||||
}()
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Try to connect without any client certificate
|
||||
// Create a TLS cert without a private key to simulate no client cert
|
||||
emptyCert := &tls.Certificate{}
|
||||
_, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, emptyCert)
|
||||
require.Error(t, err, "expected error when connecting without client cert")
|
||||
|
||||
require.ErrorContains(t, err, "no certificate provided", "expected no cert error")
|
||||
}
|
||||
|
||||
func TestUDPServer_FullFlow(t *testing.T) {
|
||||
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||
require.NoError(t, err, "generate agent certs")
|
||||
|
||||
caCert, err := caPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse CA cert")
|
||||
srvCert, err := srvPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse server cert")
|
||||
clientCert, err := clientPEM.ToTLSCert()
|
||||
require.NoError(t, err, "parse client cert")
|
||||
|
||||
dstAddr, closeDst := startUDPEcho(t)
|
||||
defer closeDst()
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
|
||||
errCh := make(chan error, 1)
|
||||
go func() { errCh <- srv.Start() }()
|
||||
defer func() {
|
||||
cancel()
|
||||
_ = srv.Close()
|
||||
err := <-errCh
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, udp.ErrClosedListener) {
|
||||
t.Logf("udp server exit: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for srv.listener == nil {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("udp server listener did not start")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
||||
if err != nil {
|
||||
t.Fatalf("NewUDPClient: %v", err)
|
||||
}
|
||||
client, err := stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert)
|
||||
require.NoError(t, err, "create udp client")
|
||||
defer client.Close()
|
||||
|
||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
msg := []byte("ping over udp")
|
||||
if _, err := client.Write(msg); err != nil {
|
||||
t.Fatalf("client.Write: %v", err)
|
||||
}
|
||||
_, err = client.Write(msg)
|
||||
require.NoError(t, err, "write to client")
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := client.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("client.Read: %v", err)
|
||||
}
|
||||
if string(buf[:n]) != string(msg) {
|
||||
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
|
||||
}
|
||||
require.NoError(t, err, "read from client")
|
||||
require.Equal(t, string(msg), string(buf[:n]), "unexpected echo")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user