mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-25 02:09:01 +02:00
test(agent/stream): refactor server flow tests to use testify and real certificate generation
- Use agent.NewAgent() for properly configured certificates matching real usage - Migrate to testify/require for assertions - Add tests for UDP server rejecting clients with invalid certificates - Use t.Context() for lifecycle management
This commit is contained in:
@@ -1,90 +1,25 @@
|
|||||||
package stream
|
package stream_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/dtls/v3"
|
||||||
|
"github.com/pion/transport/v3/udp"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/godoxy/agent/pkg/agent"
|
||||||
|
"github.com/yusing/godoxy/agent/pkg/agent/stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newSerial(t *testing.T) *big.Int {
|
|
||||||
t.Helper()
|
|
||||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("rand serial: %v", err)
|
|
||||||
}
|
|
||||||
return sn
|
|
||||||
}
|
|
||||||
|
|
||||||
func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
|
||||||
t.Helper()
|
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GenerateKey: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl := &x509.Certificate{
|
|
||||||
SerialNumber: newSerial(t),
|
|
||||||
Subject: pkix.Name{CommonName: "stream-test-ca"},
|
|
||||||
NotBefore: time.Now().Add(-time.Minute),
|
|
||||||
NotAfter: time.Now().Add(24 * time.Hour),
|
|
||||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
IsCA: true,
|
|
||||||
}
|
|
||||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("CreateCertificate(CA): %v", err)
|
|
||||||
}
|
|
||||||
cert, err := x509.ParseCertificate(der)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParseCertificate(CA): %v", err)
|
|
||||||
}
|
|
||||||
return cert, key
|
|
||||||
}
|
|
||||||
|
|
||||||
func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate {
|
|
||||||
t.Helper()
|
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GenerateKey: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl := &x509.Certificate{
|
|
||||||
SerialNumber: newSerial(t),
|
|
||||||
Subject: pkix.Name{CommonName: cn},
|
|
||||||
NotBefore: time.Now().Add(-time.Minute),
|
|
||||||
NotAfter: time.Now().Add(24 * time.Hour),
|
|
||||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{eku},
|
|
||||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
||||||
}
|
|
||||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("CreateCertificate(%s): %v", cn, err)
|
|
||||||
}
|
|
||||||
leaf, err := x509.ParseCertificate(der)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParseCertificate(%s): %v", cn, err)
|
|
||||||
}
|
|
||||||
return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
require.NoError(t, err, "listen tcp")
|
||||||
t.Fatalf("Listen tcp: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
@@ -110,9 +45,7 @@ func startTCPEcho(t *testing.T) (addr string, closeFn func()) {
|
|||||||
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||||
if err != nil {
|
require.NoError(t, err, "listen udp")
|
||||||
t.Fatalf("Listen udp: %v", err)
|
|
||||||
}
|
|
||||||
uc := pc.(*net.UDPConn)
|
uc := pc.(*net.UDPConn)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@@ -135,23 +68,27 @@ func startUDPEcho(t *testing.T) (addr string, closeFn func()) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPServer_FullFlow(t *testing.T) {
|
func TestTCPServer_FullFlow(t *testing.T) {
|
||||||
ca, caKey := genCA(t)
|
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||||
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
require.NoError(t, err, "generate agent certs")
|
||||||
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
|
||||||
|
caCert, err := caPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse CA cert")
|
||||||
|
srvCert, err := srvPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse server cert")
|
||||||
|
clientCert, err := clientPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse client cert")
|
||||||
|
|
||||||
dstAddr, closeDst := startTCPEcho(t)
|
dstAddr, closeDst := startTCPEcho(t)
|
||||||
defer closeDst()
|
defer closeDst()
|
||||||
|
|
||||||
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||||
if err != nil {
|
require.NoError(t, err, "listen tcp")
|
||||||
t.Fatalf("ListenTCP: %v", err)
|
|
||||||
}
|
|
||||||
defer tcpLn.Close()
|
defer tcpLn.Close()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
srv := NewTCPServer(ctx, tcpLn, ca, serverCert)
|
srv := stream.NewTCPServer(ctx, tcpLn, caCert.Leaf, srvCert)
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() { errCh <- srv.Start() }()
|
go func() { errCh <- srv.Start() }()
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -160,76 +97,138 @@ func TestTCPServer_FullFlow(t *testing.T) {
|
|||||||
_ = <-errCh
|
_ = <-errCh
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
client, err := stream.NewTCPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert)
|
||||||
if err != nil {
|
require.NoError(t, err, "create tcp client")
|
||||||
t.Fatalf("NewTCPClient: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
msg := []byte("ping over tcp")
|
msg := []byte("ping over tcp")
|
||||||
if _, err := client.Write(msg); err != nil {
|
_, err = client.Write(msg)
|
||||||
t.Fatalf("client.Write: %v", err)
|
require.NoError(t, err, "write to client")
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, len(msg))
|
buf := make([]byte, len(msg))
|
||||||
if _, err := io.ReadFull(client, buf); err != nil {
|
_, err = io.ReadFull(client, buf)
|
||||||
t.Fatalf("client.ReadFull: %v", err)
|
require.NoError(t, err, "read from client")
|
||||||
}
|
require.Equal(t, string(msg), string(buf), "unexpected echo")
|
||||||
if string(buf) != string(msg) {
|
|
||||||
t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPServer_FullFlow(t *testing.T) {
|
func TestUDPServer_RejectInvalidClient(t *testing.T) {
|
||||||
ca, caKey := genCA(t)
|
caPEM, srvPEM, _, err := agent.NewAgent()
|
||||||
serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth)
|
require.NoError(t, err, "generate agent certs")
|
||||||
clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth)
|
|
||||||
|
caCert, err := caPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse CA cert")
|
||||||
|
srvCert, err := srvPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse server cert")
|
||||||
|
|
||||||
|
// Generate a self-signed client cert that is NOT signed by the CA
|
||||||
|
_, _, invalidClientPEM, err := agent.NewAgent()
|
||||||
|
require.NoError(t, err, "generate invalid client certs")
|
||||||
|
invalidClientCert, err := invalidClientPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse invalid client cert")
|
||||||
|
|
||||||
dstAddr, closeDst := startUDPEcho(t)
|
dstAddr, closeDst := startUDPEcho(t)
|
||||||
defer closeDst()
|
defer closeDst()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert)
|
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() { errCh <- srv.Start() }()
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
_ = srv.Close()
|
||||||
|
_ = <-errCh
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Try to connect with a client cert from a different CA
|
||||||
|
_, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, invalidClientCert)
|
||||||
|
require.Error(t, err, "expected error when connecting with client cert from different CA")
|
||||||
|
|
||||||
|
var handshakeErr *dtls.HandshakeError
|
||||||
|
require.ErrorAs(t, err, &handshakeErr, "expected handshake error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPServer_RejectClientWithoutCert(t *testing.T) {
|
||||||
|
caPEM, srvPEM, _, err := agent.NewAgent()
|
||||||
|
require.NoError(t, err, "generate agent certs")
|
||||||
|
|
||||||
|
caCert, err := caPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse CA cert")
|
||||||
|
srvCert, err := srvPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse server cert")
|
||||||
|
|
||||||
|
dstAddr, closeDst := startUDPEcho(t)
|
||||||
|
defer closeDst()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() { errCh <- srv.Start() }()
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
_ = srv.Close()
|
||||||
|
_ = <-errCh
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
// Try to connect without any client certificate
|
||||||
|
// Create a TLS cert without a private key to simulate no client cert
|
||||||
|
emptyCert := &tls.Certificate{}
|
||||||
|
_, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, emptyCert)
|
||||||
|
require.Error(t, err, "expected error when connecting without client cert")
|
||||||
|
|
||||||
|
require.ErrorContains(t, err, "no certificate provided", "expected no cert error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPServer_FullFlow(t *testing.T) {
|
||||||
|
caPEM, srvPEM, clientPEM, err := agent.NewAgent()
|
||||||
|
require.NoError(t, err, "generate agent certs")
|
||||||
|
|
||||||
|
caCert, err := caPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse CA cert")
|
||||||
|
srvCert, err := srvPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse server cert")
|
||||||
|
clientCert, err := clientPEM.ToTLSCert()
|
||||||
|
require.NoError(t, err, "parse client cert")
|
||||||
|
|
||||||
|
dstAddr, closeDst := startUDPEcho(t)
|
||||||
|
defer closeDst()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert)
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() { errCh <- srv.Start() }()
|
go func() { errCh <- srv.Start() }()
|
||||||
defer func() {
|
defer func() {
|
||||||
cancel()
|
cancel()
|
||||||
_ = srv.Close()
|
_ = srv.Close()
|
||||||
err := <-errCh
|
err := <-errCh
|
||||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) {
|
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, udp.ErrClosedListener) {
|
||||||
t.Logf("udp server exit: %v", err)
|
t.Logf("udp server exit: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
deadline := time.Now().Add(2 * time.Second)
|
time.Sleep(100 * time.Millisecond)
|
||||||
for srv.listener == nil {
|
|
||||||
if time.Now().After(deadline) {
|
|
||||||
t.Fatalf("udp server listener did not start")
|
|
||||||
}
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert)
|
client, err := stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert)
|
||||||
if err != nil {
|
require.NoError(t, err, "create udp client")
|
||||||
t.Fatalf("NewUDPClient: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
_ = client.SetDeadline(time.Now().Add(2 * time.Second))
|
||||||
msg := []byte("ping over udp")
|
msg := []byte("ping over udp")
|
||||||
if _, err := client.Write(msg); err != nil {
|
_, err = client.Write(msg)
|
||||||
t.Fatalf("client.Write: %v", err)
|
require.NoError(t, err, "write to client")
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, 2048)
|
buf := make([]byte, 2048)
|
||||||
n, err := client.Read(buf)
|
n, err := client.Read(buf)
|
||||||
if err != nil {
|
require.NoError(t, err, "read from client")
|
||||||
t.Fatalf("client.Read: %v", err)
|
require.Equal(t, string(msg), string(buf[:n]), "unexpected echo")
|
||||||
}
|
|
||||||
if string(buf[:n]) != string(msg) {
|
|
||||||
t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user