diff --git a/agent/pkg/agent/stream/server_flow_test.go b/agent/pkg/agent/stream/server_flow_test.go index 59a1636f..b366e34c 100644 --- a/agent/pkg/agent/stream/server_flow_test.go +++ b/agent/pkg/agent/stream/server_flow_test.go @@ -1,90 +1,25 @@ -package stream +package stream_test import ( "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" "errors" "io" - "math/big" "net" "testing" "time" + + "github.com/pion/dtls/v3" + "github.com/pion/transport/v3/udp" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/agent/pkg/agent" + "github.com/yusing/godoxy/agent/pkg/agent/stream" ) -func newSerial(t *testing.T) *big.Int { - t.Helper() - sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) - if err != nil { - t.Fatalf("rand serial: %v", err) - } - return sn -} - -func genCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) { - t.Helper() - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("GenerateKey: %v", err) - } - - tmpl := &x509.Certificate{ - SerialNumber: newSerial(t), - Subject: pkix.Name{CommonName: "stream-test-ca"}, - NotBefore: time.Now().Add(-time.Minute), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) - if err != nil { - t.Fatalf("CreateCertificate(CA): %v", err) - } - cert, err := x509.ParseCertificate(der) - if err != nil { - t.Fatalf("ParseCertificate(CA): %v", err) - } - return cert, key -} - -func genLeafCert(t *testing.T, ca *x509.Certificate, caKey *ecdsa.PrivateKey, cn string, eku x509.ExtKeyUsage) *tls.Certificate { - t.Helper() - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("GenerateKey: %v", err) - } - - tmpl := &x509.Certificate{ - SerialNumber: newSerial(t), - Subject: pkix.Name{CommonName: cn}, - NotBefore: time.Now().Add(-time.Minute), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{eku}, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - } - der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, &key.PublicKey, caKey) - if err != nil { - t.Fatalf("CreateCertificate(%s): %v", cn, err) - } - leaf, err := x509.ParseCertificate(der) - if err != nil { - t.Fatalf("ParseCertificate(%s): %v", cn, err) - } - return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: key, Leaf: leaf} -} - func startTCPEcho(t *testing.T) (addr string, closeFn func()) { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Listen tcp: %v", err) - } + require.NoError(t, err, "listen tcp") done := make(chan struct{}) go func() { @@ -110,9 +45,7 @@ func startTCPEcho(t *testing.T) (addr string, closeFn func()) { func startUDPEcho(t *testing.T) (addr string, closeFn func()) { t.Helper() pc, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Listen udp: %v", err) - } + require.NoError(t, err, "listen udp") uc := pc.(*net.UDPConn) done := make(chan struct{}) @@ -135,23 +68,27 @@ func startUDPEcho(t *testing.T) (addr string, closeFn func()) { } func TestTCPServer_FullFlow(t *testing.T) { - ca, caKey := genCA(t) - serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth) - clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth) + caPEM, srvPEM, clientPEM, err := agent.NewAgent() + require.NoError(t, err, "generate agent certs") + + caCert, err := caPEM.ToTLSCert() + require.NoError(t, err, "parse CA cert") + srvCert, err := srvPEM.ToTLSCert() + require.NoError(t, err, "parse server cert") + clientCert, err := clientPEM.ToTLSCert() + require.NoError(t, err, "parse client cert") dstAddr, closeDst := startTCPEcho(t) defer closeDst() tcpLn, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) - if err != nil { - t.Fatalf("ListenTCP: %v", err) - } + require.NoError(t, err, "listen tcp") defer tcpLn.Close() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() - srv := NewTCPServer(ctx, tcpLn, ca, serverCert) + srv := stream.NewTCPServer(ctx, tcpLn, caCert.Leaf, srvCert) errCh := make(chan error, 1) go func() { errCh <- srv.Start() }() defer func() { @@ -160,76 +97,138 @@ func TestTCPServer_FullFlow(t *testing.T) { _ = <-errCh }() - client, err := NewTCPClient(srv.Addr().String(), dstAddr, ca, clientCert) - if err != nil { - t.Fatalf("NewTCPClient: %v", err) - } + client, err := stream.NewTCPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert) + require.NoError(t, err, "create tcp client") defer client.Close() _ = client.SetDeadline(time.Now().Add(2 * time.Second)) msg := []byte("ping over tcp") - if _, err := client.Write(msg); err != nil { - t.Fatalf("client.Write: %v", err) - } + _, err = client.Write(msg) + require.NoError(t, err, "write to client") buf := make([]byte, len(msg)) - if _, err := io.ReadFull(client, buf); err != nil { - t.Fatalf("client.ReadFull: %v", err) - } - if string(buf) != string(msg) { - t.Fatalf("unexpected echo: got %q want %q", string(buf), string(msg)) - } + _, err = io.ReadFull(client, buf) + require.NoError(t, err, "read from client") + require.Equal(t, string(msg), string(buf), "unexpected echo") } -func TestUDPServer_FullFlow(t *testing.T) { - ca, caKey := genCA(t) - serverCert := genLeafCert(t, ca, caKey, "stream-server", x509.ExtKeyUsageServerAuth) - clientCert := genLeafCert(t, ca, caKey, "stream-client", x509.ExtKeyUsageClientAuth) +func TestUDPServer_RejectInvalidClient(t *testing.T) { + caPEM, srvPEM, _, err := agent.NewAgent() + require.NoError(t, err, "generate agent certs") + + caCert, err := caPEM.ToTLSCert() + require.NoError(t, err, "parse CA cert") + srvCert, err := srvPEM.ToTLSCert() + require.NoError(t, err, "parse server cert") + + // Generate a self-signed client cert that is NOT signed by the CA + _, _, invalidClientPEM, err := agent.NewAgent() + require.NoError(t, err, "generate invalid client certs") + invalidClientCert, err := invalidClientPEM.ToTLSCert() + require.NoError(t, err, "parse invalid client cert") dstAddr, closeDst := startUDPEcho(t) defer closeDst() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) defer cancel() - srv := NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, ca, serverCert) + srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert) + errCh := make(chan error, 1) + go func() { errCh <- srv.Start() }() + defer func() { + cancel() + _ = srv.Close() + _ = <-errCh + }() + + time.Sleep(100 * time.Millisecond) + + // Try to connect with a client cert from a different CA + _, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, invalidClientCert) + require.Error(t, err, "expected error when connecting with client cert from different CA") + + var handshakeErr *dtls.HandshakeError + require.ErrorAs(t, err, &handshakeErr, "expected handshake error") +} + +func TestUDPServer_RejectClientWithoutCert(t *testing.T) { + caPEM, srvPEM, _, err := agent.NewAgent() + require.NoError(t, err, "generate agent certs") + + caCert, err := caPEM.ToTLSCert() + require.NoError(t, err, "parse CA cert") + srvCert, err := srvPEM.ToTLSCert() + require.NoError(t, err, "parse server cert") + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert) + errCh := make(chan error, 1) + go func() { errCh <- srv.Start() }() + defer func() { + cancel() + _ = srv.Close() + _ = <-errCh + }() + + time.Sleep(time.Second) + + // Try to connect without any client certificate + // Create a TLS cert without a private key to simulate no client cert + emptyCert := &tls.Certificate{} + _, err = stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, emptyCert) + require.Error(t, err, "expected error when connecting without client cert") + + require.ErrorContains(t, err, "no certificate provided", "expected no cert error") +} + +func TestUDPServer_FullFlow(t *testing.T) { + caPEM, srvPEM, clientPEM, err := agent.NewAgent() + require.NoError(t, err, "generate agent certs") + + caCert, err := caPEM.ToTLSCert() + require.NoError(t, err, "parse CA cert") + srvCert, err := srvPEM.ToTLSCert() + require.NoError(t, err, "parse server cert") + clientCert, err := clientPEM.ToTLSCert() + require.NoError(t, err, "parse client cert") + + dstAddr, closeDst := startUDPEcho(t) + defer closeDst() + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + srv := stream.NewUDPServer(ctx, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, caCert.Leaf, srvCert) errCh := make(chan error, 1) go func() { errCh <- srv.Start() }() defer func() { cancel() _ = srv.Close() err := <-errCh - if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, net.ErrClosed) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, udp.ErrClosedListener) { t.Logf("udp server exit: %v", err) } }() - deadline := time.Now().Add(2 * time.Second) - for srv.listener == nil { - if time.Now().After(deadline) { - t.Fatalf("udp server listener did not start") - } - time.Sleep(10 * time.Millisecond) - } + time.Sleep(100 * time.Millisecond) - client, err := NewUDPClient(srv.Addr().String(), dstAddr, ca, clientCert) - if err != nil { - t.Fatalf("NewUDPClient: %v", err) - } + client, err := stream.NewUDPClient(srv.Addr().String(), dstAddr, caCert.Leaf, clientCert) + require.NoError(t, err, "create udp client") defer client.Close() _ = client.SetDeadline(time.Now().Add(2 * time.Second)) msg := []byte("ping over udp") - if _, err := client.Write(msg); err != nil { - t.Fatalf("client.Write: %v", err) - } + _, err = client.Write(msg) + require.NoError(t, err, "write to client") buf := make([]byte, 2048) n, err := client.Read(buf) - if err != nil { - t.Fatalf("client.Read: %v", err) - } - if string(buf[:n]) != string(msg) { - t.Fatalf("unexpected echo: got %q want %q", string(buf[:n]), string(msg)) - } + require.NoError(t, err, "read from client") + require.Equal(t, string(msg), string(buf[:n]), "unexpected echo") }