From 00c41b642227fdac7ebe174d34679073babafa6c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 17 Mar 2026 14:32:35 +0000 Subject: [PATCH] hscontrol/servertest: add race, stress, and poll race tests Add three test files designed to stress the control plane under concurrent and adversarial conditions: - race_test.go: 14 tests exercising concurrent mutations, session replacement, batcher contention, NodeStore access, and map response delivery during disconnect. All pass the Go race detector. - poll_race_test.go: 8 tests targeting the poll.go grace period interleaving. These confirm a logical TOCTOU race: when a node disconnects and reconnects within the grace period, the old session's deferred Disconnect() can overwrite the new session's Connect(), leaving IsOnline=false despite an active poll session. - stress_test.go: sustained churn, rapid mutations, rolling replacement, data integrity checks under load, and verification that rapid reconnects do not leak false-offline notifications. 
Known failing tests (grace period TOCTOU race): - server_state_online_after_reconnect_within_grace - update_history_no_false_offline - rapid_reconnect_peer_never_sees_offline --- hscontrol/servertest/poll_race_test.go | 375 +++++++++++++ hscontrol/servertest/race_test.go | 702 +++++++++++++++++++++++ hscontrol/servertest/stress_test.go | 740 +++++++++++++++++++++++++ 3 files changed, 1817 insertions(+) create mode 100644 hscontrol/servertest/poll_race_test.go create mode 100644 hscontrol/servertest/race_test.go create mode 100644 hscontrol/servertest/stress_test.go diff --git a/hscontrol/servertest/poll_race_test.go b/hscontrol/servertest/poll_race_test.go new file mode 100644 index 00000000..500bf3fb --- /dev/null +++ b/hscontrol/servertest/poll_race_test.go @@ -0,0 +1,375 @@ +package servertest_test + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/servertest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/netmap" +) + +// TestPollRace targets logical race conditions specifically in the +// poll.go session lifecycle and the batcher's handling of concurrent +// sessions for the same node. + +func TestPollRace(t *testing.T) { + t.Parallel() + + // The core race: when a node disconnects, poll.go starts a + // grace period goroutine (10s ticker loop). If the node + // reconnects during this period, the new session calls + // Connect() to mark the node online. But the old grace period + // goroutine is still running and may call Disconnect() AFTER + // the new Connect(), setting IsOnline=false incorrectly. + // + // This test verifies the exact symptom: after reconnect within + // the grace period, the server-side node state should be online. 
+ t.Run("server_state_online_after_reconnect_within_grace", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "gracerace-user") + + c1 := servertest.NewClient(t, srv, "gracerace-node1", + servertest.WithUser(user)) + servertest.NewClient(t, srv, "gracerace-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + nodeID := findNodeID(t, srv, "gracerace-node1") + + // Disconnect and immediately reconnect. + c1.Disconnect(t) + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // Check server-side state immediately. + nv, ok := srv.State().GetNodeByID(nodeID) + require.True(t, ok) + + isOnline, known := nv.IsOnline().GetOk() + assert.True(t, known, + "server should know online status after reconnect") + assert.True(t, isOnline, + "server should show node as online after reconnect within grace period") + }) + + // Same test but wait a few seconds after reconnect. The old + // grace period goroutine may still be running. + t.Run("server_state_online_2s_after_reconnect", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "gracewait-user") + + c1 := servertest.NewClient(t, srv, "gracewait-node1", + servertest.WithUser(user)) + servertest.NewClient(t, srv, "gracewait-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + nodeID := findNodeID(t, srv, "gracewait-node1") + + c1.Disconnect(t) + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // Wait 2 seconds for the old grace period to potentially fire. 
+ timer := time.NewTimer(2 * time.Second) + defer timer.Stop() + + <-timer.C + + nv, ok := srv.State().GetNodeByID(nodeID) + require.True(t, ok) + + isOnline, known := nv.IsOnline().GetOk() + assert.True(t, known, + "server should know online status 2s after reconnect") + assert.True(t, isOnline, + "server should STILL show node as online 2s after reconnect (grace period goroutine should not overwrite)") + }) + + // Wait the full grace period (10s) after reconnect. The old + // grace period goroutine should have checked IsConnected + // and found the node connected, so should NOT have called + // Disconnect(). + t.Run("server_state_online_12s_after_reconnect", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "gracelong-user") + + c1 := servertest.NewClient(t, srv, "gracelong-node1", + servertest.WithUser(user)) + servertest.NewClient(t, srv, "gracelong-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + nodeID := findNodeID(t, srv, "gracelong-node1") + + c1.Disconnect(t) + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // Wait past the full grace period. + timer := time.NewTimer(12 * time.Second) + defer timer.Stop() + + <-timer.C + + nv, ok := srv.State().GetNodeByID(nodeID) + require.True(t, ok) + + isOnline, known := nv.IsOnline().GetOk() + assert.True(t, known, + "server should know online status after grace period expires") + assert.True(t, isOnline, + "server should show node as online after grace period -- the reconnect should have prevented the Disconnect() call") + }) + + // Peer's view: after rapid reconnect, the peer should see + // the reconnected node as online, not offline. 
+ t.Run("peer_sees_online_after_rapid_reconnect", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "peeronl-user") + + c1 := servertest.NewClient(t, srv, "peeronl-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "peeronl-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + // Wait for online status to propagate first. + c2.WaitForCondition(t, "peer initially online", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "peeronl-node1" { + isOnline, known := p.Online().GetOk() + + return known && isOnline + } + } + + return false + }) + + // Rapid reconnect. + c1.Disconnect(t) + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // Wait 3 seconds for any stale updates to propagate. + timer := time.NewTimer(3 * time.Second) + defer timer.Stop() + + <-timer.C + + // At this point, c2 should see c1 as ONLINE. + // If the grace period race is present, c2 might + // temporarily see offline and then online again. + nm := c2.Netmap() + require.NotNil(t, nm) + + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "peeronl-node1" { + isOnline, known := p.Online().GetOk() + assert.True(t, known, + "peer online status should be known") + assert.True(t, isOnline, + "peer should be online 3s after rapid reconnect") + } + } + }) + + // The batcher's IsConnected check: when the grace period + // goroutine calls IsConnected(), it should return true if + // a new session has been added for the same node. 
+ t.Run("batcher_knows_reconnected_during_grace", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "batchknow-user") + + c1 := servertest.NewClient(t, srv, "batchknow-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "batchknow-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + c2.WaitForPeers(t, 1, 10*time.Second) + + // Disconnect and reconnect. + c1.Disconnect(t) + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // The mesh should be complete with both nodes seeing + // each other as online. + c2.WaitForCondition(t, "c1 online after reconnect", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "batchknow-node1" { + isOnline, known := p.Online().GetOk() + + return known && isOnline + } + } + + return false + }) + }) + + // Test that the update history shows a clean transition: + // the peer should never appear in the history with + // online=false if the reconnect was fast enough. + t.Run("update_history_no_false_offline", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "histroff-user") + + c1 := servertest.NewClient(t, srv, "histroff-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "histroff-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + c2.WaitForPeers(t, 1, 10*time.Second) + + // Record c2's update count before reconnect. + countBefore := c2.UpdateCount() + + // Rapid reconnect. + c1.Disconnect(t) + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // Wait a moment for all updates to arrive. + timer := time.NewTimer(3 * time.Second) + defer timer.Stop() + + <-timer.C + + // Check c2's update history for any false offline. 
+ history := c2.History() + sawOffline := false + + for i := countBefore; i < len(history); i++ { + nm := history[i] + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "histroff-node1" { + isOnline, known := p.Online().GetOk() + if known && !isOnline { + sawOffline = true + + t.Logf("update %d: saw peer offline (should not happen during rapid reconnect)", i) + } + } + } + } + + assert.False(t, sawOffline, + "peer should never appear offline in update history during rapid reconnect") + }) + + // Multiple rapid reconnects should not cause the peer count + // to be wrong. After N reconnects, the reconnecting node should + // still see the right number of peers and vice versa. + t.Run("peer_count_stable_after_many_reconnects", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "peercount-user") + + const n = 4 + + clients := make([]*servertest.TestClient, n) + for i := range n { + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("peercount-%d", i), + servertest.WithUser(user)) + } + + for _, c := range clients { + c.WaitForPeers(t, n-1, 20*time.Second) + } + + // Reconnect client 0 five times. + for range 5 { + clients[0].Disconnect(t) + clients[0].Reconnect(t) + } + + // All clients should still see n-1 peers. + for _, c := range clients { + c.WaitForPeers(t, n-1, 15*time.Second) + } + + servertest.AssertMeshComplete(t, clients) + }) + + // Route approval during reconnect: approve a route while a + // node is reconnecting. Both the reconnecting node and peers + // should eventually see the correct state. 
+ t.Run("route_approval_during_reconnect", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "rtrecon-user") + + c1 := servertest.NewClient(t, srv, "rtrecon-node1", + servertest.WithUser(user)) + servertest.NewClient(t, srv, "rtrecon-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + nodeID1 := findNodeID(t, srv, "rtrecon-node1") + + // Disconnect c1. + c1.Disconnect(t) + + // While c1 is disconnected, approve a route for it. + route := netip.MustParsePrefix("10.55.0.0/24") + _, routeChange, err := srv.State().SetApprovedRoutes( + nodeID1, []netip.Prefix{route}) + require.NoError(t, err) + srv.App.Change(routeChange) + + // Reconnect c1. + c1.Reconnect(t) + c1.WaitForPeers(t, 1, 15*time.Second) + + // c1 should receive a self-update with the new route. + c1.WaitForCondition(t, "self-update after route+reconnect", + 10*time.Second, + func(nm *netmap.NetworkMap) bool { + return nm != nil && nm.SelfNode.Valid() + }) + + // Verify server state is correct. + nv, ok := srv.State().GetNodeByID(nodeID1) + require.True(t, ok) + + routes := nv.ApprovedRoutes().AsSlice() + assert.Contains(t, routes, route, + "approved route should persist through reconnect") + }) +} diff --git a/hscontrol/servertest/race_test.go b/hscontrol/servertest/race_test.go new file mode 100644 index 00000000..81f30f23 --- /dev/null +++ b/hscontrol/servertest/race_test.go @@ -0,0 +1,702 @@ +package servertest_test + +import ( + "context" + "fmt" + "net/netip" + "sync" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/servertest" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +// TestRace contains tests designed to trigger race conditions in +// the control plane. Run with -race to detect data races. 
+// These tests stress concurrent access patterns in poll.go, +// the batcher, the NodeStore, and the mapper. + +// TestRacePollSessionReplacement tests the race between an old +// poll session's deferred cleanup and a new session starting. +func TestRacePollSessionReplacement(t *testing.T) { + t.Parallel() + + // Rapidly replace the poll session by doing immediate + // disconnect+reconnect. This races the old session's + // deferred cleanup (RemoveNode, Disconnect, grace period + // goroutine) with the new session's setup (AddNode, Connect, + // initial map send). + t.Run("immediate_session_replace_10x", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "sessrepl-user") + + c1 := servertest.NewClient(t, srv, "sessrepl-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "sessrepl-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + for range 10 { + c1.Disconnect(t) + // Reconnect immediately -- no sleep. This creates the + // tightest possible race between old session cleanup + // and new session setup. + c1.Reconnect(t) + } + + c1.WaitForPeers(t, 1, 15*time.Second) + c2.WaitForPeers(t, 1, 15*time.Second) + + // Both clients should still have a consistent view. + servertest.AssertMeshComplete(t, + []*servertest.TestClient{c1, c2}) + }) + + // Two nodes rapidly reconnecting simultaneously. + t.Run("two_nodes_reconnect_simultaneously", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "simrecon-user") + + c1 := servertest.NewClient(t, srv, "simrecon-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "simrecon-node2", + servertest.WithUser(user)) + c3 := servertest.NewClient(t, srv, "simrecon-node3", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 2, 15*time.Second) + + for range 5 { + // Both disconnect at the same time. 
+ c1.Disconnect(t) + c2.Disconnect(t) + + // Both reconnect at the same time. + c1.Reconnect(t) + c2.Reconnect(t) + } + + // Mesh should recover. + c1.WaitForPeers(t, 2, 15*time.Second) + c2.WaitForPeers(t, 2, 15*time.Second) + c3.WaitForPeers(t, 2, 15*time.Second) + + servertest.AssertConsistentState(t, + []*servertest.TestClient{c1, c2, c3}) + }) +} + +// TestRaceConcurrentServerMutations tests concurrent mutations +// on the server side while nodes are connected and polling. +func TestRaceConcurrentServerMutations(t *testing.T) { + t.Parallel() + + // Rename, route approval, and policy change all happening + // concurrently while nodes are connected. + t.Run("concurrent_rename_route_policy", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "conmut-user") + + c1 := servertest.NewClient(t, srv, "conmut-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "conmut-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + nodeID1 := findNodeID(t, srv, "conmut-node1") + + var wg sync.WaitGroup + + // Concurrent renames. + + wg.Go(func() { + for i := range 5 { + name := fmt.Sprintf("conmut-renamed-%d", i) + srv.State().RenameNode(nodeID1, name) //nolint:errcheck + } + }) + + // Concurrent route changes. + + wg.Go(func() { + for i := range 5 { + route := netip.MustParsePrefix( + fmt.Sprintf("10.%d.0.0/24", i)) + _, c, _ := srv.State().SetApprovedRoutes( + nodeID1, []netip.Prefix{route}) + srv.App.Change(c) + } + }) + + // Concurrent policy changes. + + wg.Go(func() { + for range 5 { + changed, err := srv.State().SetPolicy([]byte(`{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] + }`)) + if err == nil && changed { + changes, err := srv.State().ReloadPolicy() + if err == nil { + srv.App.Change(changes...) + } + } + } + }) + + wg.Wait() + + // Server should not have panicked, and clients should still + // be getting updates. 
+ c2.WaitForCondition(t, "still receiving updates", + 10*time.Second, + func(nm *netmap.NetworkMap) bool { + return nm != nil && len(nm.Peers) > 0 + }) + }) + + // Delete a node while simultaneously changing policy. + t.Run("delete_during_policy_change", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "delpol-user") + + c1 := servertest.NewClient(t, srv, "delpol-node1", + servertest.WithUser(user)) + servertest.NewClient(t, srv, "delpol-node2", + servertest.WithUser(user)) + c3 := servertest.NewClient(t, srv, "delpol-node3", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 2, 15*time.Second) + + nodeID2 := findNodeID(t, srv, "delpol-node2") + nv2, ok := srv.State().GetNodeByID(nodeID2) + require.True(t, ok) + + var wg sync.WaitGroup + + // Delete node2 and change policy simultaneously. + + wg.Go(func() { + delChange, err := srv.State().DeleteNode(nv2) + if err == nil { + srv.App.Change(delChange) + } + }) + + wg.Go(func() { + changed, err := srv.State().SetPolicy([]byte(`{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] + }`)) + if err == nil && changed { + changes, err := srv.State().ReloadPolicy() + if err == nil { + srv.App.Change(changes...) + } + } + }) + + wg.Wait() + + // c1 and c3 should converge -- both should see each other + // but not node2. + c1.WaitForCondition(t, "node2 gone from c1", 10*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "delpol-node2" { + return false + } + } + + return true + }) + + c3.WaitForCondition(t, "node2 gone from c3", 10*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "delpol-node2" { + return false + } + } + + return true + }) + }) + + // Many clients sending hostinfo updates simultaneously. 
+ t.Run("concurrent_hostinfo_updates", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "chiupd-user") + + const n = 6 + + clients := make([]*servertest.TestClient, n) + for i := range n { + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("chiupd-%d", i), + servertest.WithUser(user)) + } + + for _, c := range clients { + c.WaitForPeers(t, n-1, 20*time.Second) + } + + // All clients update their hostinfo simultaneously. + var wg sync.WaitGroup + for i, c := range clients { + wg.Go(func() { + c.Direct().SetHostinfo(&tailcfg.Hostinfo{ + BackendLogID: fmt.Sprintf("servertest-chiupd-%d", i), + Hostname: fmt.Sprintf("chiupd-%d", i), + OS: fmt.Sprintf("ConcurrentOS-%d", i), + }) + + ctx, cancel := context.WithTimeout( + context.Background(), 5*time.Second) + defer cancel() + + _ = c.Direct().SendUpdate(ctx) + }) + } + + wg.Wait() + + // Each client should eventually see all others' updated OS. + for _, observer := range clients { + observer.WaitForCondition(t, "all OS updates visible", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + seenOS := 0 + + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.OS() != "" && + len(hi.OS()) > 12 { // "ConcurrentOS-" prefix + seenOS++ + } + } + // Should see n-1 peers with updated OS. + return seenOS >= n-1 + }) + } + }) +} + +// TestRaceConnectDuringGracePeriod tests connecting a new node +// while another node is in its grace period. +func TestRaceConnectDuringGracePeriod(t *testing.T) { + t.Parallel() + + // A node disconnects, and during the 10-second grace period + // a new node joins. The new node should see the disconnecting + // node as a peer (it hasn't been removed yet). 
+ t.Run("new_node_during_grace_period", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "grace-user") + + c1 := servertest.NewClient(t, srv, "grace-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "grace-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + // Disconnect c1 -- starts grace period. + c1.Disconnect(t) + + // Immediately add a new node while c1 is in grace period. + c3 := servertest.NewClient(t, srv, "grace-node3", + servertest.WithUser(user)) + + // c3 should see c2 for sure. Whether it sees c1 depends on + // whether c1's grace period has expired. Either way it should + // not panic or hang. + c3.WaitForPeers(t, 1, 15*time.Second) + + // c2 should see c3. + c2.WaitForCondition(t, "c2 sees c3", 10*time.Second, + func(nm *netmap.NetworkMap) bool { + _, found := c2.PeerByName("grace-node3") + + return found + }) + }) + + // Multiple nodes disconnect and new ones connect simultaneously, + // creating a mixed grace-period race. + t.Run("multi_disconnect_multi_connect_race", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "mixgrace-user") + + const n = 4 + + originals := make([]*servertest.TestClient, n) + for i := range n { + originals[i] = servertest.NewClient(t, srv, + fmt.Sprintf("mixgrace-orig-%d", i), + servertest.WithUser(user)) + } + + for _, c := range originals { + c.WaitForPeers(t, n-1, 20*time.Second) + } + + // Disconnect half. + for i := range n / 2 { + originals[i].Disconnect(t) + } + + // Add new nodes during grace period. + replacements := make([]*servertest.TestClient, n/2) + for i := range n / 2 { + replacements[i] = servertest.NewClient(t, srv, + fmt.Sprintf("mixgrace-new-%d", i), + servertest.WithUser(user)) + } + + // The surviving originals + new nodes should form a mesh. + surviving := originals[n/2:] + allActive := append(surviving, replacements...) 
+ + for _, c := range allActive { + c.WaitForPeers(t, len(allActive)-1, 30*time.Second) + } + + servertest.AssertConsistentState(t, allActive) + }) +} + +// TestRaceBatcherContention tests race conditions in the batcher +// when many changes arrive simultaneously. +func TestRaceBatcherContention(t *testing.T) { + t.Parallel() + + // Many nodes connecting at the same time generates many + // concurrent Change() calls. The batcher must handle this + // without dropping updates or panicking. + t.Run("many_simultaneous_connects", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "batchcon-user") + + const n = 8 + + clients := make([]*servertest.TestClient, n) + + // Create all clients as fast as possible. + for i := range n { + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("batchcon-%d", i), + servertest.WithUser(user)) + } + + // All should converge. + for _, c := range clients { + c.WaitForPeers(t, n-1, 30*time.Second) + } + + servertest.AssertMeshComplete(t, clients) + }) + + // Rapid connect + disconnect + connect of different nodes + // generates interleaved AddNode/RemoveNode/AddNode in the + // batcher. + t.Run("interleaved_add_remove_add", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "intleave-user") + + observer := servertest.NewClient(t, srv, "intleave-obs", + servertest.WithUser(user)) + observer.WaitForUpdate(t, 10*time.Second) + + // Rapidly create, disconnect, create nodes. + for i := range 5 { + c := servertest.NewClient(t, srv, + fmt.Sprintf("intleave-temp-%d", i), + servertest.WithUser(user)) + c.WaitForUpdate(t, 10*time.Second) + c.Disconnect(t) + } + + // Add a final persistent node. + final := servertest.NewClient(t, srv, "intleave-final", + servertest.WithUser(user)) + + // Observer should see at least the final node. 
+ observer.WaitForCondition(t, "sees final node", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + _, found := observer.PeerByName("intleave-final") + + return found + }) + + // Final should see observer. + final.WaitForCondition(t, "sees observer", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + _, found := final.PeerByName("intleave-obs") + + return found + }) + }) + + // Route changes and node connect happening at the same time. + t.Run("route_change_during_connect", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "rtcon-user") + + c1 := servertest.NewClient(t, srv, "rtcon-node1", + servertest.WithUser(user)) + c1.WaitForUpdate(t, 10*time.Second) + + nodeID1 := findNodeID(t, srv, "rtcon-node1") + + // Approve routes while c2 is connecting. + var wg sync.WaitGroup + + wg.Go(func() { + route := netip.MustParsePrefix("10.88.0.0/24") + _, c, _ := srv.State().SetApprovedRoutes( + nodeID1, []netip.Prefix{route}) + srv.App.Change(c) + }) + + wg.Add(1) + + var c2 *servertest.TestClient + + go func() { + defer wg.Done() + + c2 = servertest.NewClient(t, srv, "rtcon-node2", + servertest.WithUser(user)) + }() + + wg.Wait() + + // Both should converge. + c1.WaitForPeers(t, 1, 10*time.Second) + c2.WaitForPeers(t, 1, 10*time.Second) + }) +} + +// TestRaceMapResponseDuringDisconnect tests what happens when a +// map response is being written while the session is being torn down. +func TestRaceMapResponseDuringDisconnect(t *testing.T) { + t.Parallel() + + // Generate a lot of updates for a node, then disconnect it + // while updates are still being delivered. The disconnect + // should be clean -- no panics, no hangs. 
+ t.Run("disconnect_during_update_storm", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "updstorm-user") + + victim := servertest.NewClient(t, srv, "updstorm-victim", + servertest.WithUser(user)) + victim.WaitForUpdate(t, 10*time.Second) + + // Create several nodes to generate connection updates. + for i := range 5 { + servertest.NewClient(t, srv, + fmt.Sprintf("updstorm-gen-%d", i), + servertest.WithUser(user)) + } + + // While updates are flying, disconnect the victim. + victim.Disconnect(t) + + // No panic, no hang = success. The other nodes should + // still be working. + remaining := servertest.NewClient(t, srv, "updstorm-check", + servertest.WithUser(user)) + remaining.WaitForPeers(t, 5, 15*time.Second) + }) + + // Send a hostinfo update and disconnect almost simultaneously. + t.Run("hostinfo_update_then_immediate_disconnect", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "hidc-user") + + c1 := servertest.NewClient(t, srv, "hidc-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "hidc-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + // Fire a hostinfo update. + c1.Direct().SetHostinfo(&tailcfg.Hostinfo{ + BackendLogID: "servertest-hidc-node1", + Hostname: "hidc-node1", + OS: "DisconnectOS", + }) + + ctx, cancel := context.WithTimeout( + context.Background(), 5*time.Second) + defer cancel() + + _ = c1.Direct().SendUpdate(ctx) + + // Immediately disconnect. + c1.Disconnect(t) + + // c2 might or might not see the OS update, but it should + // not panic or hang. Verify c2 is still functional. + c2.WaitForCondition(t, "c2 still functional", 10*time.Second, + func(nm *netmap.NetworkMap) bool { + return nm != nil + }) + }) +} + +// TestRaceNodeStoreContention tests concurrent access to the NodeStore. 
+func TestRaceNodeStoreContention(t *testing.T) { + t.Parallel() + + // Many GetNodeByID calls while nodes are connecting and + // disconnecting. This tests the NodeStore's read/write locking. + t.Run("concurrent_reads_during_mutations", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "nsrace-user") + + const n = 4 + + clients := make([]*servertest.TestClient, n) + for i := range n { + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("nsrace-%d", i), + servertest.WithUser(user)) + } + + for _, c := range clients { + c.WaitForPeers(t, n-1, 15*time.Second) + } + + nodeIDs := make([]types.NodeID, n) + for i := range n { + nodeIDs[i] = findNodeID(t, srv, + fmt.Sprintf("nsrace-%d", i)) + } + + // Concurrently: read nodes, disconnect/reconnect, read again. + var wg sync.WaitGroup + + // Readers. + for range 4 { + wg.Go(func() { + for range 100 { + for _, id := range nodeIDs { + nv, ok := srv.State().GetNodeByID(id) + if ok { + _ = nv.Hostname() + _ = nv.IsOnline() + _ = nv.ApprovedRoutes() + } + } + } + }) + } + + // Mutators: disconnect and reconnect nodes. + for i := range 2 { + wg.Go(func() { + clients[i].Disconnect(t) + clients[i].Reconnect(t) + }) + } + + wg.Wait() + + // Everything should still be working. + for i := 2; i < n; i++ { + _, ok := srv.State().GetNodeByID(nodeIDs[i]) + assert.True(t, ok, + "node %d should still be in NodeStore", i) + } + }) + + // ListNodes while nodes are being added and removed. + t.Run("list_nodes_during_churn", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "listrace-user") + + var wg sync.WaitGroup + + // Continuously list nodes. + stop := make(chan struct{}) + + wg.Go(func() { + for { + select { + case <-stop: + return + default: + nodes := srv.State().ListNodes() + // Access each node to exercise read paths. 
			// Touch each node's fields so concurrent reads race-check
			// against the writer goroutines; values are discarded.
			for i := range nodes.Len() {
				n := nodes.At(i)
				_ = n.Hostname()
				_ = n.IPs()
			}
		}
	}
	})

	// Add and remove nodes.
	for i := range 5 {
		c := servertest.NewClient(t, srv,
			fmt.Sprintf("listrace-%d", i),
			servertest.WithUser(user))
		c.WaitForUpdate(t, 10*time.Second)

		if i%2 == 0 {
			c.Disconnect(t)
		}
	}

	close(stop)
	wg.Wait()
	})
}
diff --git a/hscontrol/servertest/stress_test.go b/hscontrol/servertest/stress_test.go
new file mode 100644
index 00000000..ad22a3ac
--- /dev/null
+++ b/hscontrol/servertest/stress_test.go
package servertest_test

import (
	"context"
	"fmt"
	"net/netip"
	"sync"
	"testing"
	"time"

	"github.com/juanfont/headscale/hscontrol/servertest"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"tailscale.com/tailcfg"
	"tailscale.com/types/netmap"
)

// TestStress hammers the control plane with concurrent operations,
// rapid mutations, and edge cases to surface race conditions and
// consistency bugs.

// TestStressConnectDisconnect exercises rapid connect/disconnect
// patterns that stress the grace period, batcher, and NodeStore.
func TestStressConnectDisconnect(t *testing.T) {
	t.Parallel()

	// A node that disconnects and reconnects faster than the
	// grace period should never cause a second node to see
	// the first node as offline.
	t.Run("rapid_reconnect_peer_never_sees_offline", func(t *testing.T) {
		t.Parallel()
		h := servertest.NewHarness(t, 2)

		// Wait for both to be online.
		h.Client(0).WaitForCondition(t, "peer online", 15*time.Second,
			func(nm *netmap.NetworkMap) bool {
				for _, p := range nm.Peers {
					isOnline, known := p.Online().GetOk()
					if known && isOnline {
						return true
					}
				}

				return false
			})

		// Do 10 rapid reconnects and check that client 0 never
		// sees client 1 as offline during the process.
		sawOffline := false

		var offlineMu sync.Mutex

		// Monitor client 0's view of client 1 in the background.
		// NOTE(review): this goroutine busy-spins on Netmap() with no
		// sleep between iterations; a short pause per loop would cut
		// CPU load under -race without weakening the check.
		stopMonitor := make(chan struct{})
		monitorDone := make(chan struct{})

		go func() {
			defer close(monitorDone)

			for {
				select {
				case <-stopMonitor:
					return
				default:
				}

				nm := h.Client(0).Netmap()
				if nm == nil {
					continue
				}

				for _, p := range nm.Peers {
					isOnline, known := p.Online().GetOk()
					if known && !isOnline {
						offlineMu.Lock()
						sawOffline = true
						offlineMu.Unlock()
					}
				}
			}
		}()

		for range 10 {
			h.Client(1).Disconnect(t)
			h.Client(1).Reconnect(t)
		}

		// Give the monitor a moment to catch up, then stop it.
		h.Client(0).WaitForPeers(t, 1, 10*time.Second)
		close(stopMonitor)
		<-monitorDone

		offlineMu.Lock()
		defer offlineMu.Unlock()

		assert.False(t, sawOffline,
			"peer should never appear offline during rapid reconnect cycles")
	})

	// Delete a node while it has an active poll session. The poll
	// session should terminate cleanly and other peers should see
	// the node disappear.
	t.Run("delete_node_during_active_poll", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "delpoll-user")

		c1 := servertest.NewClient(t, srv, "delpoll-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "delpoll-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)
		c2.WaitForPeers(t, 1, 10*time.Second)

		// Delete c1 while it's actively polling.
		nodeID := findNodeID(t, srv, "delpoll-node1")
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)

		deleteChange, err := srv.State().DeleteNode(nv)
		require.NoError(t, err)
		srv.App.Change(deleteChange)

		// c2 should see c1 disappear.
		c2.WaitForCondition(t, "deleted node gone", 10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				for _, p := range nm.Peers {
					hi := p.Hostinfo()
					if hi.Valid() && hi.Hostname() == "delpoll-node1" {
						return false
					}
				}

				return true
			})

		assert.Empty(t, c2.Peers(),
			"c2 should have no peers after c1 is deleted")
	})

	// Connect many nodes, then disconnect half simultaneously.
	// The remaining half should converge to see only each other.
	t.Run("disconnect_half_remaining_converge", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "halfdisc-user")

		const total = 6

		clients := make([]*servertest.TestClient, total)
		for i := range total {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("halfdisc-%d", i),
				servertest.WithUser(user))
		}

		// Wait for full mesh.
		for _, c := range clients {
			c.WaitForPeers(t, total-1, 30*time.Second)
		}

		// Disconnect the first half.
		for i := range total / 2 {
			clients[i].Disconnect(t)
		}

		// The remaining half should eventually converge.
		remaining := clients[total/2:]

		for _, c := range remaining {
			c.WaitForCondition(t, "remaining converge",
				30*time.Second,
				func(nm *netmap.NetworkMap) bool {
					// Should see at least the other remaining peers.
					onlinePeers := 0

					for _, p := range nm.Peers {
						isOnline, known := p.Online().GetOk()
						if known && isOnline {
							onlinePeers++
						}
					}
					// Remaining peers minus self = total/2 - 1
					return onlinePeers >= len(remaining)-1
				})
		}
	})
}

// TestStressStateMutations tests rapid server-side state changes.
func TestStressStateMutations(t *testing.T) {
	t.Parallel()

	// Rapidly approve and remove routes. The final state should
	// be consistent.
	t.Run("rapid_route_changes_final_state_correct", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rapidrt-user")

		c1 := servertest.NewClient(t, srv, "rapidrt-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "rapidrt-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)

		nodeID := findNodeID(t, srv, "rapidrt-node1")

		// Rapidly change routes 10 times.
		for i := range 10 {
			route := netip.MustParsePrefix(
				fmt.Sprintf("10.%d.0.0/24", i))

			_, routeChange, err := srv.State().SetApprovedRoutes(
				nodeID, []netip.Prefix{route})
			require.NoError(t, err)
			srv.App.Change(routeChange)
		}

		// Final route should be 10.9.0.0/24.
		// Verify server state is correct.
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)

		finalRoutes := nv.ApprovedRoutes().AsSlice()
		expected := netip.MustParsePrefix("10.9.0.0/24")
		assert.Contains(t, finalRoutes, expected,
			"final approved routes should contain the last route set")
		assert.Len(t, finalRoutes, 1,
			"should have exactly 1 approved route (the last one set)")

		// c2 should eventually see the update.
		// NOTE(review): the condition ignores nm and only checks
		// UpdateCount() > 2 — it proves updates flowed, not that the
		// final route is actually visible in c2's netmap.
		c2.WaitForCondition(t, "final route update received",
			10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				return c2.UpdateCount() > 2
			})
	})

	// Rename a node multiple times rapidly. The final name should
	// be correct in the server state and visible to peers.
	t.Run("rapid_rename_final_state_correct", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rapidname-user")

		c1 := servertest.NewClient(t, srv, "rapidname-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "rapidname-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)

		nodeID := findNodeID(t, srv, "rapidname-node1")

		// Rename 5 times rapidly.
		var finalName string
		for i := range 5 {
			finalName = fmt.Sprintf("renamed-%d", i)

			_, renameChange, err := srv.State().RenameNode(nodeID, finalName)
			require.NoError(t, err)
			srv.App.Change(renameChange)
		}

		// Server state should have the final name.
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)
		assert.Equal(t, finalName, nv.AsStruct().GivenName,
			"server should have the final renamed value")

		// c2 should see the final name.
		c2.WaitForCondition(t, "final name visible", 10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				for _, p := range nm.Peers {
					if p.Name() == finalName {
						return true
					}
				}

				return false
			})
	})

	// Multiple policy changes in rapid succession. The final
	// policy should be applied correctly.
	t.Run("rapid_policy_changes", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rapidpol-user")

		c1 := servertest.NewClient(t, srv, "rapidpol-node1",
			servertest.WithUser(user))
		c1.WaitForUpdate(t, 10*time.Second)

		countBefore := c1.UpdateCount()

		// Change policy 5 times rapidly.
		// NOTE(review): the same policy bytes are submitted each
		// iteration, so SetPolicy presumably reports changed=false
		// after the first pass — confirm against its contract.
		for range 5 {
			changed, err := srv.State().SetPolicy([]byte(`{
				"acls": [
					{"action": "accept", "src": ["*"], "dst": ["*:*"]}
				]
			}`))
			require.NoError(t, err)

			if changed {
				changes, err := srv.State().ReloadPolicy()
				require.NoError(t, err)
				srv.App.Change(changes...)
			}
		}

		// Client should have received at least some updates.
		c1.WaitForCondition(t, "updates after policy changes",
			10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				return c1.UpdateCount() > countBefore
			})
	})
}

// TestStressDataIntegrity verifies data correctness under various conditions.
func TestStressDataIntegrity(t *testing.T) {
	t.Parallel()

	// Every node's self-addresses should match what peers see
	// as that node's Addresses.
	t.Run("self_addresses_match_peer_view", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "addrmatch-user")

		const n = 5

		clients := make([]*servertest.TestClient, n)
		for i := range n {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("addrmatch-%d", i),
				servertest.WithUser(user))
		}

		for _, c := range clients {
			c.WaitForPeers(t, n-1, 20*time.Second)
		}

		// Build a map of hostname -> self-addresses.
		// NOTE(review): keyed by c.Name but looked up below via
		// hostinfo Hostname(); assumes the two always agree — verify
		// against the servertest client constructor.
		selfAddrs := make(map[string][]netip.Prefix)

		for _, c := range clients {
			nm := c.Netmap()
			require.NotNil(t, nm)
			require.True(t, nm.SelfNode.Valid())

			addrs := make([]netip.Prefix, 0, nm.SelfNode.Addresses().Len())
			for i := range nm.SelfNode.Addresses().Len() {
				addrs = append(addrs, nm.SelfNode.Addresses().At(i))
			}

			selfAddrs[c.Name] = addrs
		}

		// Now verify each client's peers have the same addresses
		// as those peers' self-view.
		for _, c := range clients {
			nm := c.Netmap()
			require.NotNil(t, nm)

			for _, peer := range nm.Peers {
				hi := peer.Hostinfo()
				if !hi.Valid() {
					continue
				}

				peerName := hi.Hostname()

				expected, ok := selfAddrs[peerName]
				if !ok {
					continue
				}

				peerAddrs := make([]netip.Prefix, 0, peer.Addresses().Len())
				for i := range peer.Addresses().Len() {
					peerAddrs = append(peerAddrs, peer.Addresses().At(i))
				}

				assert.Equal(t, expected, peerAddrs,
					"client %s: peer %s addresses should match that peer's self-view",
					c.Name, peerName)
			}
		}
	})

	// After mesh formation, no peer should have Expired=true.
	t.Run("no_peers_expired_after_mesh_formation", func(t *testing.T) {
		t.Parallel()
		h := servertest.NewHarness(t, 3)

		for _, c := range h.Clients() {
			nm := c.Netmap()
			require.NotNil(t, nm)

			assert.False(t, nm.SelfNode.Expired(),
				"client %s: self should not be expired", c.Name)

			for _, peer := range nm.Peers {
				assert.False(t, peer.Expired(),
					"client %s: peer %d should not be expired",
					c.Name, peer.ID())
			}
		}
	})

	// Self node should always be machine-authorized.
	t.Run("self_always_machine_authorized", func(t *testing.T) {
		t.Parallel()
		h := servertest.NewHarness(t, 2)

		for _, c := range h.Clients() {
			nm := c.Netmap()
			require.NotNil(t, nm)
			assert.True(t, nm.SelfNode.MachineAuthorized(),
				"client %s: self should be machine-authorized", c.Name)
		}

		// After reconnect, should still be authorized.
		h.Client(0).Disconnect(t)
		h.Client(0).Reconnect(t)
		h.Client(0).WaitForPeers(t, 1, 10*time.Second)

		nm := h.Client(0).Netmap()
		require.NotNil(t, nm)
		assert.True(t, nm.SelfNode.MachineAuthorized(),
			"after reconnect: self should be machine-authorized")
	})

	// Node IDs in the server state should match what clients see.
	t.Run("node_ids_consistent_between_server_and_client", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "idcheck-user")

		c1 := servertest.NewClient(t, srv, "idcheck-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "idcheck-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)
		c2.WaitForPeers(t, 1, 10*time.Second)

		// Get server-side node IDs.
		serverID1 := findNodeID(t, srv, "idcheck-node1")
		serverID2 := findNodeID(t, srv, "idcheck-node2")

		// Get client-side node IDs.
		nm1 := c1.Netmap()
		nm2 := c2.Netmap()

		require.NotNil(t, nm1)
		require.NotNil(t, nm2)

		clientID1 := nm1.SelfNode.ID()
		clientID2 := nm2.SelfNode.ID()

		//nolint:gosec // G115: test-only, IDs won't overflow
		assert.Equal(t, int64(serverID1), int64(clientID1),
			"node 1: server ID should match client self ID")
		//nolint:gosec // G115: test-only, IDs won't overflow
		assert.Equal(t, int64(serverID2), int64(clientID2),
			"node 2: server ID should match client self ID")

		// c1's view of c2's ID should also match.
		require.Len(t, nm1.Peers, 1)
		//nolint:gosec // G115: test-only, IDs won't overflow
		assert.Equal(t, int64(serverID2), int64(nm1.Peers[0].ID()),
			"c1's view of c2's ID should match server")
	})

	// After hostinfo update, ALL peers should see the updated
	// hostinfo, not just some.
	t.Run("hostinfo_update_reaches_all_peers", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "hiall-user")

		const n = 5

		clients := make([]*servertest.TestClient, n)
		for i := range n {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("hiall-%d", i),
				servertest.WithUser(user))
		}

		for _, c := range clients {
			c.WaitForPeers(t, n-1, 20*time.Second)
		}

		// Client 0 updates its OS.
		clients[0].Direct().SetHostinfo(&tailcfg.Hostinfo{
			BackendLogID: "servertest-hiall-0",
			Hostname:     "hiall-0",
			OS:           "StressTestOS",
		})

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		// Error deliberately ignored: delivery is verified below via
		// WaitForCondition rather than via SendUpdate's return.
		_ = clients[0].Direct().SendUpdate(ctx)

		// ALL other clients should see the updated OS.
		for i := 1; i < n; i++ {
			clients[i].WaitForCondition(t,
				fmt.Sprintf("client %d sees OS update", i),
				15*time.Second,
				func(nm *netmap.NetworkMap) bool {
					for _, p := range nm.Peers {
						hi := p.Hostinfo()
						if hi.Valid() && hi.Hostname() == "hiall-0" {
							return hi.OS() == "StressTestOS"
						}
					}

					return false
				})
		}
	})

	// MachineKey should be consistent: the server should track
	// the same machine key the client registered with.
	t.Run("machine_key_consistent", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "mkey-user")

		c1 := servertest.NewClient(t, srv, "mkey-node1",
			servertest.WithUser(user))
		c1.WaitForUpdate(t, 10*time.Second)

		nm := c1.Netmap()
		require.NotNil(t, nm)

		// The client's MachineKey in the netmap should be non-zero.
		assert.False(t, nm.MachineKey.IsZero(),
			"client's MachineKey should be non-zero")

		// Server should have the same key.
		nodeID := findNodeID(t, srv, "mkey-node1")
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)

		assert.Equal(t, nm.MachineKey.String(), nv.MachineKey().String(),
			"client and server should agree on MachineKey")
	})

	// NodeKey should be consistent between client and server.
	t.Run("node_key_consistent", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "nkey-user")

		c1 := servertest.NewClient(t, srv, "nkey-node1",
			servertest.WithUser(user))
		c1.WaitForUpdate(t, 10*time.Second)

		nm := c1.Netmap()
		require.NotNil(t, nm)

		assert.False(t, nm.NodeKey.IsZero(),
			"client's NodeKey should be non-zero")

		nodeID := findNodeID(t, srv, "nkey-node1")
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)

		assert.Equal(t, nm.NodeKey.String(), nv.NodeKey().String(),
			"client and server should agree on NodeKey")
	})
}

// TestStressChurn tests behavior under sustained connect/disconnect churn.
+func TestStressChurn(t *testing.T) { + t.Parallel() + + // Connect 10 nodes, then replace them all one by one. + // Each replacement connects a new node and disconnects the old. + // The remaining nodes should always see a consistent mesh. + t.Run("rolling_replacement", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "rolling-user") + + const n = 5 + + clients := make([]*servertest.TestClient, n) + for i := range n { + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("rolling-%d", i), + servertest.WithUser(user)) + } + + for _, c := range clients { + c.WaitForPeers(t, n-1, 20*time.Second) + } + + // Replace each node one at a time. + for i := range n { + clients[i].Disconnect(t) + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("rolling-new-%d", i), + servertest.WithUser(user)) + } + + // Wait for the new set to converge. + for _, c := range clients { + c.WaitForPeers(t, n-1, 30*time.Second) + } + + servertest.AssertSymmetricVisibility(t, clients) + }) + + // Add nodes one at a time and verify the mesh grows correctly + // at each step. + t.Run("incremental_mesh_growth", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "incr-user") + + clients := make([]*servertest.TestClient, 0, 8) + + for i := range 8 { + c := servertest.NewClient(t, srv, + fmt.Sprintf("incr-%d", i), + servertest.WithUser(user)) + clients = append(clients, c) + + // After each addition, verify all existing clients see + // the correct number of peers. + expectedPeers := i // i-th node means i peers for existing nodes + for _, existing := range clients { + existing.WaitForPeers(t, expectedPeers, 15*time.Second) + } + } + + // Final check. + servertest.AssertMeshComplete(t, clients) + }) + + // Connect/disconnect the same node many times. The server + // should handle this without leaking state. 
+ t.Run("repeated_connect_disconnect_same_node", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "repeat-user") + + observer := servertest.NewClient(t, srv, "repeat-observer", + servertest.WithUser(user)) + flapper := servertest.NewClient(t, srv, "repeat-flapper", + servertest.WithUser(user)) + + observer.WaitForPeers(t, 1, 10*time.Second) + + for i := range 10 { + flapper.Disconnect(t) + flapper.Reconnect(t) + flapper.WaitForPeers(t, 1, 10*time.Second) + + if i%3 == 0 { + t.Logf("cycle %d: flapper sees %d peers, observer sees %d peers", + i, len(flapper.Peers()), len(observer.Peers())) + } + } + + // After all cycles, mesh should be healthy. + observer.WaitForPeers(t, 1, 10*time.Second) + + _, found := observer.PeerByName("repeat-flapper") + assert.True(t, found, + "observer should still see flapper after 10 reconnect cycles") + }) + + // All nodes disconnect and reconnect simultaneously. + t.Run("mass_reconnect", func(t *testing.T) { + t.Parallel() + + sizes := []int{4, 6} + for _, n := range sizes { + t.Run(fmt.Sprintf("%d_nodes", n), func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "massrecon-user") + + clients := make([]*servertest.TestClient, n) + for i := range n { + clients[i] = servertest.NewClient(t, srv, + fmt.Sprintf("massrecon-%d", i), + servertest.WithUser(user)) + } + + for _, c := range clients { + c.WaitForPeers(t, n-1, 20*time.Second) + } + + // All disconnect. + for _, c := range clients { + c.Disconnect(t) + } + + // All reconnect. + for _, c := range clients { + c.Reconnect(t) + } + + // Should re-form mesh. + for _, c := range clients { + c.WaitForPeers(t, n-1, 30*time.Second) + } + + servertest.AssertMeshComplete(t, clients) + servertest.AssertConsistentState(t, clients) + }) + } + }) +}