Files
headscale/hscontrol/servertest/stress_test.go
Kristoffer Dalby 00c41b6422 hscontrol/servertest: add race, stress, and poll race tests
Add three test files designed to stress the control plane under
concurrent and adversarial conditions:

- race_test.go: 14 tests exercising concurrent mutations, session
  replacement, batcher contention, NodeStore access, and map response
  delivery during disconnect. All pass the Go race detector.

- poll_race_test.go: 8 tests targeting the poll.go grace period
  interleaving. These confirm a logical TOCTOU race: when a node
  disconnects and reconnects within the grace period, the old
  session's deferred Disconnect() can overwrite the new session's
  Connect(), leaving IsOnline=false despite an active poll session;
  a minimal sketch of this interleaving follows below.

- stress_test.go: sustained churn, rapid mutations, rolling
  replacement, data integrity checks under load, and verification
  that rapid reconnects do not leak false-offline notifications.

Known failing tests (grace period TOCTOU race):
- server_state_online_after_reconnect_within_grace
- update_history_no_false_offline
- rapid_reconnect_peer_never_sees_offline
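
A minimal, self-contained sketch of the suspected interleaving. The
nodeState/oldSession names below are hypothetical stand-ins, not
headscale's real State or poll-session API:

  // Hypothetical sketch of the grace-period TOCTOU: the old poll session's
  // deferred Disconnect() fires after the new session's Connect() and wins.
  package main

  import "fmt"

  // nodeState stands in for the server-side per-node online flag.
  type nodeState struct{ online bool }

  func (n *nodeState) Connect()    { n.online = true }
  func (n *nodeState) Disconnect() { n.online = false }

  // oldSession models a poll handler: it marks the node online, waits for
  // the connection to drop, and only then runs its deferred cleanup.
  func oldSession(n *nodeState, connected, dropped chan struct{}) {
    defer n.Disconnect() // deferred cleanup, runs last
    n.Connect()
    close(connected)
    <-dropped
  }

  func main() {
    n := &nodeState{}
    connected := make(chan struct{})
    dropped := make(chan struct{})
    finished := make(chan struct{})

    go func() { oldSession(n, connected, dropped); close(finished) }()
    <-connected // old session is online

    n.Connect()    // new session reconnects within the grace period
    close(dropped) // old session's handler now unwinds...
    <-finished     // ...and its deferred Disconnect() overwrites Connect()

    fmt.Printf("IsOnline=%v despite an active poll session\n", n.online)
  }
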
2026-03-19 07:05:58 +01:00


package servertest_test

import (
	"context"
	"fmt"
	"net/netip"
	"sync"
	"testing"
	"time"

	"github.com/juanfont/headscale/hscontrol/servertest"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"tailscale.com/tailcfg"
	"tailscale.com/types/netmap"
)

// The TestStress* tests in this file hammer the control plane with
// concurrent operations, rapid mutations, and edge cases to surface
// race conditions and consistency bugs.

// TestStressConnectDisconnect exercises rapid connect/disconnect
// patterns that stress the grace period, batcher, and NodeStore.
func TestStressConnectDisconnect(t *testing.T) {
	t.Parallel()

	// A node that disconnects and reconnects faster than the
	// grace period should never cause a second node to see
	// the first node as offline.
	t.Run("rapid_reconnect_peer_never_sees_offline", func(t *testing.T) {
		t.Parallel()

		h := servertest.NewHarness(t, 2)

		// Wait for both to be online.
		h.Client(0).WaitForCondition(t, "peer online", 15*time.Second,
			func(nm *netmap.NetworkMap) bool {
				for _, p := range nm.Peers {
					isOnline, known := p.Online().GetOk()
					if known && isOnline {
						return true
					}
				}
				return false
			})

		// Do 10 rapid reconnects and check that client 0 never
		// sees client 1 as offline during the process.
		sawOffline := false
		var offlineMu sync.Mutex

		// Monitor client 0's view of client 1 in the background.
		stopMonitor := make(chan struct{})
		monitorDone := make(chan struct{})
		go func() {
			defer close(monitorDone)
			for {
				select {
				case <-stopMonitor:
					return
				default:
				}

				nm := h.Client(0).Netmap()
				if nm == nil {
					continue
				}
				for _, p := range nm.Peers {
					isOnline, known := p.Online().GetOk()
					if known && !isOnline {
						offlineMu.Lock()
						sawOffline = true
						offlineMu.Unlock()
					}
				}
			}
		}()

		for range 10 {
			h.Client(1).Disconnect(t)
			h.Client(1).Reconnect(t)
		}

		// Give the monitor a moment to catch up, then stop it.
		h.Client(0).WaitForPeers(t, 1, 10*time.Second)
		close(stopMonitor)
		<-monitorDone

		offlineMu.Lock()
		defer offlineMu.Unlock()
		assert.False(t, sawOffline,
			"peer should never appear offline during rapid reconnect cycles")
	})

	// Delete a node while it has an active poll session. The poll
	// session should terminate cleanly and other peers should see
	// the node disappear.
	t.Run("delete_node_during_active_poll", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "delpoll-user")

		c1 := servertest.NewClient(t, srv, "delpoll-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "delpoll-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)
		c2.WaitForPeers(t, 1, 10*time.Second)

		// Delete c1 while it's actively polling.
		nodeID := findNodeID(t, srv, "delpoll-node1")
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)

		deleteChange, err := srv.State().DeleteNode(nv)
		require.NoError(t, err)
		srv.App.Change(deleteChange)

		// c2 should see c1 disappear.
		c2.WaitForCondition(t, "deleted node gone", 10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				for _, p := range nm.Peers {
					hi := p.Hostinfo()
					if hi.Valid() && hi.Hostname() == "delpoll-node1" {
						return false
					}
				}
				return true
			})

		assert.Empty(t, c2.Peers(),
			"c2 should have no peers after c1 is deleted")
	})

	// Connect many nodes, then disconnect half simultaneously.
	// The remaining half should converge to see only each other.
	t.Run("disconnect_half_remaining_converge", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "halfdisc-user")

		const total = 6
		clients := make([]*servertest.TestClient, total)
		for i := range total {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("halfdisc-%d", i),
				servertest.WithUser(user))
		}

		// Wait for full mesh.
		for _, c := range clients {
			c.WaitForPeers(t, total-1, 30*time.Second)
		}

		// Disconnect the first half.
		for i := range total / 2 {
			clients[i].Disconnect(t)
		}

		// The remaining half should eventually converge.
		remaining := clients[total/2:]
		for _, c := range remaining {
			c.WaitForCondition(t, "remaining converge",
				30*time.Second,
				func(nm *netmap.NetworkMap) bool {
					// Should see at least the other remaining peers.
					onlinePeers := 0
					for _, p := range nm.Peers {
						isOnline, known := p.Online().GetOk()
						if known && isOnline {
							onlinePeers++
						}
					}

					// Remaining peers minus self = total/2 - 1.
					return onlinePeers >= len(remaining)-1
				})
		}
	})
}

// TestStressStateMutations tests rapid server-side state changes.
func TestStressStateMutations(t *testing.T) {
	t.Parallel()

	// Rapidly approve and remove routes. The final state should
	// be consistent.
	t.Run("rapid_route_changes_final_state_correct", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rapidrt-user")

		c1 := servertest.NewClient(t, srv, "rapidrt-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "rapidrt-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)

		nodeID := findNodeID(t, srv, "rapidrt-node1")

		// Rapidly change routes 10 times.
		for i := range 10 {
			route := netip.MustParsePrefix(
				fmt.Sprintf("10.%d.0.0/24", i))
			_, routeChange, err := srv.State().SetApprovedRoutes(
				nodeID, []netip.Prefix{route})
			require.NoError(t, err)
			srv.App.Change(routeChange)
		}

		// Final route should be 10.9.0.0/24.
		// Verify server state is correct.
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)

		finalRoutes := nv.ApprovedRoutes().AsSlice()
		expected := netip.MustParsePrefix("10.9.0.0/24")
		assert.Contains(t, finalRoutes, expected,
			"final approved routes should contain the last route set")
		assert.Len(t, finalRoutes, 1,
			"should have exactly 1 approved route (the last one set)")

		// c2 should eventually see the update.
		c2.WaitForCondition(t, "final route update received",
			10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				return c2.UpdateCount() > 2
			})
	})

	// Rename a node multiple times rapidly. The final name should
	// be correct in the server state and visible to peers.
	t.Run("rapid_rename_final_state_correct", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rapidname-user")

		c1 := servertest.NewClient(t, srv, "rapidname-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "rapidname-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)

		nodeID := findNodeID(t, srv, "rapidname-node1")

		// Rename 5 times rapidly.
		var finalName string
		for i := range 5 {
			finalName = fmt.Sprintf("renamed-%d", i)
			_, renameChange, err := srv.State().RenameNode(nodeID, finalName)
			require.NoError(t, err)
			srv.App.Change(renameChange)
		}

		// Server state should have the final name.
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)
		assert.Equal(t, finalName, nv.AsStruct().GivenName,
			"server should have the final renamed value")

		// c2 should see the final name.
		c2.WaitForCondition(t, "final name visible", 10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				for _, p := range nm.Peers {
					if p.Name() == finalName {
						return true
					}
				}
				return false
			})
	})

	// Multiple policy changes in rapid succession. The final
	// policy should be applied correctly.
	t.Run("rapid_policy_changes", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rapidpol-user")

		c1 := servertest.NewClient(t, srv, "rapidpol-node1",
			servertest.WithUser(user))
		c1.WaitForUpdate(t, 10*time.Second)

		countBefore := c1.UpdateCount()

		// Change policy 5 times rapidly.
		for range 5 {
			changed, err := srv.State().SetPolicy([]byte(`{
				"acls": [
					{"action": "accept", "src": ["*"], "dst": ["*:*"]}
				]
			}`))
			require.NoError(t, err)

			if changed {
				changes, err := srv.State().ReloadPolicy()
				require.NoError(t, err)
				srv.App.Change(changes...)
			}
		}

		// Client should have received at least some updates.
		c1.WaitForCondition(t, "updates after policy changes",
			10*time.Second,
			func(nm *netmap.NetworkMap) bool {
				return c1.UpdateCount() > countBefore
			})
	})
}

// TestStressDataIntegrity verifies data correctness under various conditions.
func TestStressDataIntegrity(t *testing.T) {
	t.Parallel()

	// Every node's self-addresses should match what peers see
	// as that node's Addresses.
	t.Run("self_addresses_match_peer_view", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "addrmatch-user")

		const n = 5
		clients := make([]*servertest.TestClient, n)
		for i := range n {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("addrmatch-%d", i),
				servertest.WithUser(user))
		}

		for _, c := range clients {
			c.WaitForPeers(t, n-1, 20*time.Second)
		}

		// Build a map of hostname -> self-addresses.
		selfAddrs := make(map[string][]netip.Prefix)
		for _, c := range clients {
			nm := c.Netmap()
			require.NotNil(t, nm)
			require.True(t, nm.SelfNode.Valid())

			addrs := make([]netip.Prefix, 0, nm.SelfNode.Addresses().Len())
			for i := range nm.SelfNode.Addresses().Len() {
				addrs = append(addrs, nm.SelfNode.Addresses().At(i))
			}
			selfAddrs[c.Name] = addrs
		}

		// Now verify each client's peers have the same addresses
		// as those peers' self-view.
		for _, c := range clients {
			nm := c.Netmap()
			require.NotNil(t, nm)

			for _, peer := range nm.Peers {
				hi := peer.Hostinfo()
				if !hi.Valid() {
					continue
				}

				peerName := hi.Hostname()
				expected, ok := selfAddrs[peerName]
				if !ok {
					continue
				}

				peerAddrs := make([]netip.Prefix, 0, peer.Addresses().Len())
				for i := range peer.Addresses().Len() {
					peerAddrs = append(peerAddrs, peer.Addresses().At(i))
				}

				assert.Equal(t, expected, peerAddrs,
					"client %s: peer %s addresses should match that peer's self-view",
					c.Name, peerName)
			}
		}
	})

	// After mesh formation, no peer should have Expired=true.
	t.Run("no_peers_expired_after_mesh_formation", func(t *testing.T) {
		t.Parallel()

		h := servertest.NewHarness(t, 3)

		for _, c := range h.Clients() {
			nm := c.Netmap()
			require.NotNil(t, nm)

			assert.False(t, nm.SelfNode.Expired(),
				"client %s: self should not be expired", c.Name)

			for _, peer := range nm.Peers {
				assert.False(t, peer.Expired(),
					"client %s: peer %d should not be expired",
					c.Name, peer.ID())
			}
		}
	})

	// Self node should always be machine-authorized.
	t.Run("self_always_machine_authorized", func(t *testing.T) {
		t.Parallel()

		h := servertest.NewHarness(t, 2)

		for _, c := range h.Clients() {
			nm := c.Netmap()
			require.NotNil(t, nm)
			assert.True(t, nm.SelfNode.MachineAuthorized(),
				"client %s: self should be machine-authorized", c.Name)
		}

		// After reconnect, should still be authorized.
		h.Client(0).Disconnect(t)
		h.Client(0).Reconnect(t)
		h.Client(0).WaitForPeers(t, 1, 10*time.Second)

		nm := h.Client(0).Netmap()
		require.NotNil(t, nm)
		assert.True(t, nm.SelfNode.MachineAuthorized(),
			"after reconnect: self should be machine-authorized")
	})

	// Node IDs in the server state should match what clients see.
	t.Run("node_ids_consistent_between_server_and_client", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "idcheck-user")

		c1 := servertest.NewClient(t, srv, "idcheck-node1",
			servertest.WithUser(user))
		c2 := servertest.NewClient(t, srv, "idcheck-node2",
			servertest.WithUser(user))

		c1.WaitForPeers(t, 1, 10*time.Second)
		c2.WaitForPeers(t, 1, 10*time.Second)

		// Get server-side node IDs.
		serverID1 := findNodeID(t, srv, "idcheck-node1")
		serverID2 := findNodeID(t, srv, "idcheck-node2")

		// Get client-side node IDs.
		nm1 := c1.Netmap()
		nm2 := c2.Netmap()
		require.NotNil(t, nm1)
		require.NotNil(t, nm2)

		clientID1 := nm1.SelfNode.ID()
		clientID2 := nm2.SelfNode.ID()

		//nolint:gosec // G115: test-only, IDs won't overflow
		assert.Equal(t, int64(serverID1), int64(clientID1),
			"node 1: server ID should match client self ID")
		//nolint:gosec // G115: test-only, IDs won't overflow
		assert.Equal(t, int64(serverID2), int64(clientID2),
			"node 2: server ID should match client self ID")

		// c1's view of c2's ID should also match.
		require.Len(t, nm1.Peers, 1)
		//nolint:gosec // G115: test-only, IDs won't overflow
		assert.Equal(t, int64(serverID2), int64(nm1.Peers[0].ID()),
			"c1's view of c2's ID should match server")
	})

	// After hostinfo update, ALL peers should see the updated
	// hostinfo, not just some.
	t.Run("hostinfo_update_reaches_all_peers", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "hiall-user")

		const n = 5
		clients := make([]*servertest.TestClient, n)
		for i := range n {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("hiall-%d", i),
				servertest.WithUser(user))
		}

		for _, c := range clients {
			c.WaitForPeers(t, n-1, 20*time.Second)
		}

		// Client 0 updates its OS.
		clients[0].Direct().SetHostinfo(&tailcfg.Hostinfo{
			BackendLogID: "servertest-hiall-0",
			Hostname:     "hiall-0",
			OS:           "StressTestOS",
		})

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = clients[0].Direct().SendUpdate(ctx)

		// ALL other clients should see the updated OS.
		for i := 1; i < n; i++ {
			clients[i].WaitForCondition(t,
				fmt.Sprintf("client %d sees OS update", i),
				15*time.Second,
				func(nm *netmap.NetworkMap) bool {
					for _, p := range nm.Peers {
						hi := p.Hostinfo()
						if hi.Valid() && hi.Hostname() == "hiall-0" {
							return hi.OS() == "StressTestOS"
						}
					}
					return false
				})
		}
	})

	// MachineKey should be consistent: the server should track
	// the same machine key the client registered with.
	t.Run("machine_key_consistent", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "mkey-user")

		c1 := servertest.NewClient(t, srv, "mkey-node1",
			servertest.WithUser(user))
		c1.WaitForUpdate(t, 10*time.Second)

		nm := c1.Netmap()
		require.NotNil(t, nm)

		// The client's MachineKey in the netmap should be non-zero.
		assert.False(t, nm.MachineKey.IsZero(),
			"client's MachineKey should be non-zero")

		// Server should have the same key.
		nodeID := findNodeID(t, srv, "mkey-node1")
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)
		assert.Equal(t, nm.MachineKey.String(), nv.MachineKey().String(),
			"client and server should agree on MachineKey")
	})

	// NodeKey should be consistent between client and server.
	t.Run("node_key_consistent", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "nkey-user")

		c1 := servertest.NewClient(t, srv, "nkey-node1",
			servertest.WithUser(user))
		c1.WaitForUpdate(t, 10*time.Second)

		nm := c1.Netmap()
		require.NotNil(t, nm)
		assert.False(t, nm.NodeKey.IsZero(),
			"client's NodeKey should be non-zero")

		nodeID := findNodeID(t, srv, "nkey-node1")
		nv, ok := srv.State().GetNodeByID(nodeID)
		require.True(t, ok)
		assert.Equal(t, nm.NodeKey.String(), nv.NodeKey().String(),
			"client and server should agree on NodeKey")
	})
}

// TestStressChurn tests behavior under sustained connect/disconnect churn.
func TestStressChurn(t *testing.T) {
	t.Parallel()

	// Connect a set of nodes, then replace them all one by one.
	// Each replacement disconnects an old node and connects a new one.
	// The remaining nodes should always see a consistent mesh.
	t.Run("rolling_replacement", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "rolling-user")

		const n = 5
		clients := make([]*servertest.TestClient, n)
		for i := range n {
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("rolling-%d", i),
				servertest.WithUser(user))
		}

		for _, c := range clients {
			c.WaitForPeers(t, n-1, 20*time.Second)
		}

		// Replace each node one at a time.
		for i := range n {
			clients[i].Disconnect(t)
			clients[i] = servertest.NewClient(t, srv,
				fmt.Sprintf("rolling-new-%d", i),
				servertest.WithUser(user))
		}

		// Wait for the new set to converge.
		for _, c := range clients {
			c.WaitForPeers(t, n-1, 30*time.Second)
		}

		servertest.AssertSymmetricVisibility(t, clients)
	})

	// Add nodes one at a time and verify the mesh grows correctly
	// at each step.
	t.Run("incremental_mesh_growth", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "incr-user")

		clients := make([]*servertest.TestClient, 0, 8)
		for i := range 8 {
			c := servertest.NewClient(t, srv,
				fmt.Sprintf("incr-%d", i),
				servertest.WithUser(user))
			clients = append(clients, c)

			// After each addition, verify all existing clients see
			// the correct number of peers.
			expectedPeers := i // adding the i-th node (0-indexed) means i peers for existing nodes
			for _, existing := range clients {
				existing.WaitForPeers(t, expectedPeers, 15*time.Second)
			}
		}

		// Final check.
		servertest.AssertMeshComplete(t, clients)
	})

	// Connect/disconnect the same node many times. The server
	// should handle this without leaking state.
	t.Run("repeated_connect_disconnect_same_node", func(t *testing.T) {
		t.Parallel()

		srv := servertest.NewServer(t)
		user := srv.CreateUser(t, "repeat-user")

		observer := servertest.NewClient(t, srv, "repeat-observer",
			servertest.WithUser(user))
		flapper := servertest.NewClient(t, srv, "repeat-flapper",
			servertest.WithUser(user))

		observer.WaitForPeers(t, 1, 10*time.Second)

		for i := range 10 {
			flapper.Disconnect(t)
			flapper.Reconnect(t)
			flapper.WaitForPeers(t, 1, 10*time.Second)

			if i%3 == 0 {
				t.Logf("cycle %d: flapper sees %d peers, observer sees %d peers",
					i, len(flapper.Peers()), len(observer.Peers()))
			}
		}

		// After all cycles, mesh should be healthy.
		observer.WaitForPeers(t, 1, 10*time.Second)

		_, found := observer.PeerByName("repeat-flapper")
		assert.True(t, found,
			"observer should still see flapper after 10 reconnect cycles")
	})

	// All nodes disconnect and reconnect simultaneously.
	t.Run("mass_reconnect", func(t *testing.T) {
		t.Parallel()

		sizes := []int{4, 6}
		for _, n := range sizes {
			t.Run(fmt.Sprintf("%d_nodes", n), func(t *testing.T) {
				t.Parallel()

				srv := servertest.NewServer(t)
				user := srv.CreateUser(t, "massrecon-user")

				clients := make([]*servertest.TestClient, n)
				for i := range n {
					clients[i] = servertest.NewClient(t, srv,
						fmt.Sprintf("massrecon-%d", i),
						servertest.WithUser(user))
				}

				for _, c := range clients {
					c.WaitForPeers(t, n-1, 20*time.Second)
				}

				// All disconnect.
				for _, c := range clients {
					c.Disconnect(t)
				}

				// All reconnect.
				for _, c := range clients {
					c.Reconnect(t)
				}

				// Should re-form mesh.
				for _, c := range clients {
					c.WaitForPeers(t, n-1, 30*time.Second)
				}

				servertest.AssertMeshComplete(t, clients)
				servertest.AssertConsistentState(t, clients)
			})
		}
	})
}