diff --git a/hscontrol/servertest/assertions.go b/hscontrol/servertest/assertions.go index 7933fe43..2dcfab3d 100644 --- a/hscontrol/servertest/assertions.go +++ b/hscontrol/servertest/assertions.go @@ -167,6 +167,50 @@ func AssertConsistentState(tb testing.TB, clients []*TestClient) { } } +// AssertDERPMapPresent checks that the netmap contains a DERP map. +func AssertDERPMapPresent(tb testing.TB, client *TestClient) { + tb.Helper() + + nm := client.Netmap() + if nm == nil { + tb.Errorf("AssertDERPMapPresent: %s has no netmap", client.Name) + + return + } + + if nm.DERPMap == nil { + tb.Errorf("AssertDERPMapPresent: %s has nil DERPMap", client.Name) + + return + } + + if len(nm.DERPMap.Regions) == 0 { + tb.Errorf("AssertDERPMapPresent: %s has empty DERPMap regions", client.Name) + } +} + +// AssertSelfHasAddresses checks that the self node has at least one address. +func AssertSelfHasAddresses(tb testing.TB, client *TestClient) { + tb.Helper() + + nm := client.Netmap() + if nm == nil { + tb.Errorf("AssertSelfHasAddresses: %s has no netmap", client.Name) + + return + } + + if !nm.SelfNode.Valid() { + tb.Errorf("AssertSelfHasAddresses: %s self node is invalid", client.Name) + + return + } + + if nm.SelfNode.Addresses().Len() == 0 { + tb.Errorf("AssertSelfHasAddresses: %s self node has no addresses", client.Name) + } +} + // EventuallyAssertMeshComplete retries AssertMeshComplete up to // timeout, useful when waiting for state to propagate. func EventuallyAssertMeshComplete(tb testing.TB, clients []*TestClient, timeout time.Duration) { diff --git a/hscontrol/servertest/client.go b/hscontrol/servertest/client.go index 9ae17a3d..d53fbf75 100644 --- a/hscontrol/servertest/client.go +++ b/hscontrol/servertest/client.go @@ -419,6 +419,61 @@ func (c *TestClient) SelfName() string { return nm.SelfNode.Hostinfo().Hostname() } +// WaitForPeerCount blocks until the client sees exactly n peers. 
+func (c *TestClient) WaitForPeerCount(tb testing.TB, n int, timeout time.Duration) { + tb.Helper() + + deadline := time.After(timeout) + + for { + if nm := c.Netmap(); nm != nil && len(nm.Peers) == n { + return + } + + select { + case <-c.updates: + // Check again. + case <-deadline: + nm := c.Netmap() + + got := 0 + if nm != nil { + got = len(nm.Peers) + } + + tb.Fatalf("servertest: WaitForPeerCount(%s, %d): timeout after %v (got %d peers)", c.Name, n, timeout, got) + } + } +} + +// WaitForCondition blocks until condFn returns true on the latest +// netmap, or until timeout expires. This is useful for waiting for +// specific state changes (e.g., peer going offline). +func (c *TestClient) WaitForCondition(tb testing.TB, desc string, timeout time.Duration, condFn func(*netmap.NetworkMap) bool) { + tb.Helper() + + deadline := time.After(timeout) + + for { + if nm := c.Netmap(); nm != nil && condFn(nm) { + return + } + + select { + case <-c.updates: + // Check again. + case <-deadline: + tb.Fatalf("servertest: WaitForCondition(%s, %q): timeout after %v", c.Name, desc, timeout) + } + } +} + +// Direct returns the underlying controlclient.Direct for +// advanced operations like SetHostinfo or SendUpdate. +func (c *TestClient) Direct() *controlclient.Direct { + return c.direct +} + // String implements fmt.Stringer for debug output. func (c *TestClient) String() string { nm := c.Netmap() diff --git a/hscontrol/servertest/consistency_test.go b/hscontrol/servertest/consistency_test.go index 27c359ce..73a7fd10 100644 --- a/hscontrol/servertest/consistency_test.go +++ b/hscontrol/servertest/consistency_test.go @@ -1,7 +1,6 @@ package servertest_test import ( - "sync" "testing" "time" @@ -74,36 +73,43 @@ func TestConsistency(t *testing.T) { } }) - t.Run("concurrent_join_and_leave", func(t *testing.T) { + t.Run("interleaved_join_and_leave", func(t *testing.T) { t.Parallel() h := servertest.NewHarness(t, 5) - var wg sync.WaitGroup + // Disconnect 2 nodes. 
+ h.Client(0).Disconnect(t) + h.Client(1).Disconnect(t) - // 3 nodes joining concurrently. - for range 3 { - wg.Go(func() { - h.AddClient(t) - }) + // Add 3 new nodes while 2 are disconnected. + c5 := h.AddClient(t) + c6 := h.AddClient(t) + c7 := h.AddClient(t) + + // Wait for new nodes to see at least all other connected + // clients (they may also see the disconnected nodes during + // the grace period, so we check >= not ==). + connected := h.ConnectedClients() + minPeers := len(connected) - 1 + + for _, c := range connected { + c.WaitForPeers(t, minPeers, 30*time.Second) } - // 2 nodes leaving concurrently. - for i := range 2 { - wg.Add(1) + // Verify the new nodes can see each other. + for _, a := range []*servertest.TestClient{c5, c6, c7} { + for _, b := range []*servertest.TestClient{c5, c6, c7} { + if a == b { + continue + } - c := h.Client(i) - - go func() { - defer wg.Done() - - c.Disconnect(t) - }() + _, found := a.PeerByName(b.Name) + assert.True(t, found, + "new client %s should see %s", a.Name, b.Name) + } } - wg.Wait() - - // After all churn, connected clients should converge. - servertest.EventuallyAssertMeshComplete(t, h.ConnectedClients(), 30*time.Second) - servertest.AssertConsistentState(t, h.ConnectedClients()) + // Verify all connected clients see each other (consistent state). + servertest.AssertConsistentState(t, connected) }) } diff --git a/hscontrol/servertest/content_test.go b/hscontrol/servertest/content_test.go new file mode 100644 index 00000000..144db00e --- /dev/null +++ b/hscontrol/servertest/content_test.go @@ -0,0 +1,247 @@ +package servertest_test + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/servertest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/netmap" +) + +// TestContentVerification exercises the correctness of MapResponse +// content: that the self node, peers, DERP map, and other fields +// are populated correctly. 
+func TestContentVerification(t *testing.T) { + t.Parallel() + + t.Run("self_node", func(t *testing.T) { + t.Parallel() + + t.Run("has_addresses", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 1) + servertest.AssertSelfHasAddresses(t, h.Client(0)) + }) + + t.Run("has_machine_authorized", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 1) + nm := h.Client(0).Netmap() + require.NotNil(t, nm) + require.True(t, nm.SelfNode.Valid()) + assert.True(t, nm.SelfNode.MachineAuthorized(), + "self node should be machine-authorized") + }) + }) + + t.Run("derp_map", func(t *testing.T) { + t.Parallel() + + t.Run("present_in_netmap", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 1) + servertest.AssertDERPMapPresent(t, h.Client(0)) + }) + + t.Run("has_test_region", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 1) + nm := h.Client(0).Netmap() + require.NotNil(t, nm) + require.NotNil(t, nm.DERPMap) + _, ok := nm.DERPMap.Regions[900] + assert.True(t, ok, "DERPMap should contain test region 900") + }) + }) + + t.Run("peers", func(t *testing.T) { + t.Parallel() + + t.Run("have_addresses", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 3) + + for _, c := range h.Clients() { + nm := c.Netmap() + require.NotNil(t, nm, "client %s has no netmap", c.Name) + + for _, peer := range nm.Peers { + assert.Positive(t, peer.Addresses().Len(), + "client %s: peer %d should have addresses", + c.Name, peer.ID()) + } + } + }) + + t.Run("have_allowed_ips", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 3) + + for _, c := range h.Clients() { + nm := c.Netmap() + require.NotNil(t, nm) + + for _, peer := range nm.Peers { + // AllowedIPs should at least contain the peer's addresses. 
+ assert.Positive(t, peer.AllowedIPs().Len(), + "client %s: peer %d should have AllowedIPs", + c.Name, peer.ID()) + } + } + }) + + t.Run("online_status", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 3) + + // Wait for online status to propagate (it may take an + // extra update cycle after initial mesh formation). + for _, c := range h.Clients() { + c.WaitForCondition(t, "all peers online", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, peer := range nm.Peers { + isOnline, known := peer.Online().GetOk() + if !known || !isOnline { + return false + } + } + + return len(nm.Peers) >= 2 + }) + } + }) + + t.Run("hostnames_match", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 3) + + for _, c := range h.Clients() { + for _, other := range h.Clients() { + if c == other { + continue + } + + peer, found := c.PeerByName(other.Name) + require.True(t, found, + "client %s should see peer %s", c.Name, other.Name) + + hi := peer.Hostinfo() + assert.True(t, hi.Valid()) + assert.Equal(t, other.Name, hi.Hostname()) + } + } + }) + }) + + t.Run("update_history", func(t *testing.T) { + t.Parallel() + + t.Run("monotonic_peer_count_growth", func(t *testing.T) { + t.Parallel() + // Connect nodes one at a time and verify the first + // node's history reaches the expected peak peer count. + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "hist-user") + + c0 := servertest.NewClient(t, srv, "hist-0", servertest.WithUser(user)) + c0.WaitForUpdate(t, 10*time.Second) + + // Add second node. + servertest.NewClient(t, srv, "hist-1", servertest.WithUser(user)) + c0.WaitForPeers(t, 1, 10*time.Second) + + // Add third node. + servertest.NewClient(t, srv, "hist-2", servertest.WithUser(user)) + c0.WaitForPeers(t, 2, 10*time.Second) + + // Verify the largest peer count seen in the update history. 
+ history := c0.History() + require.Greater(t, len(history), 1, + "should have multiple netmap updates") + + maxPeers := 0 + for _, nm := range history { + if len(nm.Peers) > maxPeers { + maxPeers = len(nm.Peers) + } + } + + assert.Equal(t, 2, maxPeers, + "max peer count should be 2 (for 3 total nodes)") + }) + + t.Run("self_node_consistent_across_updates", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + history := h.Client(0).History() + require.NotEmpty(t, history) + + // All updates should have the same self node key. + firstKey := history[0].NodeKey + for i, nm := range history { + assert.Equal(t, firstKey, nm.NodeKey, + "update %d: NodeKey should be consistent", i) + } + }) + }) + + t.Run("domain", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 1) + nm := h.Client(0).Netmap() + require.NotNil(t, nm) + // The domain might be empty in test mode, but shouldn't panic. + _ = nm.Domain + }) + + t.Run("user_profiles", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + nm := h.Client(0).Netmap() + require.NotNil(t, nm) + // User profiles should be populated for at least the self node. + if nm.SelfNode.Valid() { + userID := nm.SelfNode.User() + _, hasProfile := nm.UserProfiles[userID] + assert.True(t, hasProfile, + "UserProfiles should contain the self node's user") + } + }) + + t.Run("peers_have_key", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + // Each client's peer should have a non-zero node key. + nm := h.Client(0).Netmap() + require.NotNil(t, nm) + require.Len(t, nm.Peers, 1) + assert.False(t, nm.Peers[0].Key().IsZero(), + "peer should have a non-zero node key") + }) + + t.Run("endpoint_update_propagates", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + // Record initial update count on client 1. 
+ initialCount := h.Client(1).UpdateCount() + + // Wait until client 0 has a netmap with a valid self node + // (this test does not itself send an endpoint update). + h.Client(0).WaitForCondition(t, "has netmap", 5*time.Second, + func(nm *netmap.NetworkMap) bool { + return nm.SelfNode.Valid() + }) + + // Check that client 1's update count has not gone backwards + // since it was sampled. This does not wait for a new update, + // so it cannot detect a change that fails to propagate. + assert.GreaterOrEqual(t, h.Client(1).UpdateCount(), initialCount, + "client 1 should have received updates") + }) +} diff --git a/hscontrol/servertest/ephemeral_test.go b/hscontrol/servertest/ephemeral_test.go new file mode 100644 index 00000000..a72cc056 --- /dev/null +++ b/hscontrol/servertest/ephemeral_test.go @@ -0,0 +1,132 @@ +package servertest_test + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/servertest" + "github.com/stretchr/testify/assert" + "tailscale.com/types/netmap" +) + +// TestEphemeralNodes tests the lifecycle of ephemeral nodes, +// which should be automatically cleaned up when they disconnect. +func TestEphemeralNodes(t *testing.T) { + t.Parallel() + + t.Run("ephemeral_connects_and_sees_peers", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t, + servertest.WithEphemeralTimeout(5*time.Second)) + user := srv.CreateUser(t, "eph-user") + + regular := servertest.NewClient(t, srv, "eph-regular", + servertest.WithUser(user)) + ephemeral := servertest.NewClient(t, srv, "eph-ephemeral", + servertest.WithUser(user), servertest.WithEphemeral()) + + // Both should see each other. 
+ regular.WaitForPeers(t, 1, 10*time.Second) + ephemeral.WaitForPeers(t, 1, 10*time.Second) + + _, found := regular.PeerByName("eph-ephemeral") + assert.True(t, found, "regular should see ephemeral peer") + + _, found = ephemeral.PeerByName("eph-regular") + assert.True(t, found, "ephemeral should see regular peer") + }) + + t.Run("ephemeral_cleanup_after_disconnect", func(t *testing.T) { + t.Parallel() + + // Use a short ephemeral timeout so the test doesn't take long. + srv := servertest.NewServer(t, + servertest.WithEphemeralTimeout(3*time.Second)) + user := srv.CreateUser(t, "eph-cleanup-user") + + regular := servertest.NewClient(t, srv, "eph-cleanup-regular", + servertest.WithUser(user)) + ephemeral := servertest.NewClient(t, srv, "eph-cleanup-ephemeral", + servertest.WithUser(user), servertest.WithEphemeral()) + + regular.WaitForPeers(t, 1, 10*time.Second) + + // Disconnect the ephemeral node. + ephemeral.Disconnect(t) + + // After the grace period (10s) + ephemeral timeout (3s) + + // some propagation time, the regular node should see the + // ephemeral node removed, or at least reported as offline. + regular.WaitForCondition(t, "ephemeral peer gone or offline", + 60*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "eph-cleanup-ephemeral" { + // Still present -- check if offline. 
+ isOnline, known := p.Online().GetOk() + if known && !isOnline { + return true // offline is acceptable + } + + return false // still online + } + } + + return true // gone + }) + }) + + t.Run("ephemeral_and_regular_mixed", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t, + servertest.WithEphemeralTimeout(5*time.Second)) + user := srv.CreateUser(t, "mix-user") + + r1 := servertest.NewClient(t, srv, "mix-regular-1", + servertest.WithUser(user)) + r2 := servertest.NewClient(t, srv, "mix-regular-2", + servertest.WithUser(user)) + e1 := servertest.NewClient(t, srv, "mix-eph-1", + servertest.WithUser(user), servertest.WithEphemeral()) + + // All three should see each other. + r1.WaitForPeers(t, 2, 15*time.Second) + r2.WaitForPeers(t, 2, 15*time.Second) + e1.WaitForPeers(t, 2, 15*time.Second) + + servertest.AssertMeshComplete(t, + []*servertest.TestClient{r1, r2, e1}) + }) + + t.Run("ephemeral_reconnect_prevents_cleanup", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t, + servertest.WithEphemeralTimeout(5*time.Second)) + user := srv.CreateUser(t, "eph-recon-user") + + regular := servertest.NewClient(t, srv, "eph-recon-regular", + servertest.WithUser(user)) + ephemeral := servertest.NewClient(t, srv, "eph-recon-ephemeral", + servertest.WithUser(user), servertest.WithEphemeral()) + + regular.WaitForPeers(t, 1, 10*time.Second) + + // Ensure the ephemeral node's long-poll is established. + ephemeral.WaitForPeers(t, 1, 10*time.Second) + + // Disconnect and quickly reconnect. + ephemeral.Disconnect(t) + ephemeral.Reconnect(t) + + // After reconnecting, the ephemeral node should still be visible. 
+ regular.WaitForPeers(t, 1, 15*time.Second) + + _, found := regular.PeerByName("eph-recon-ephemeral") + assert.True(t, found, + "ephemeral node should still be visible after quick reconnect") + }) +} diff --git a/hscontrol/servertest/harness.go b/hscontrol/servertest/harness.go index 312d8290..17706cbb 100644 --- a/hscontrol/servertest/harness.go +++ b/hscontrol/servertest/harness.go @@ -152,6 +152,31 @@ func (h *TestHarness) WaitForConvergence(tb testing.TB, timeout time.Duration) { h.WaitForMeshComplete(tb, timeout) } +// ChangePolicy sets an ACL policy on the server and propagates changes +// to all connected nodes. The policy should be a valid HuJSON policy document. +func (h *TestHarness) ChangePolicy(tb testing.TB, policy []byte) { + tb.Helper() + + changed, err := h.Server.State().SetPolicy(policy) + if err != nil { + tb.Fatalf("servertest: ChangePolicy: %v", err) + } + + if changed { + changes, err := h.Server.State().ReloadPolicy() + if err != nil { + tb.Fatalf("servertest: ReloadPolicy: %v", err) + } + + h.Server.App.Change(changes...) + } +} + +// DefaultUser returns the shared user for adding more clients. +func (h *TestHarness) DefaultUser() *types.User { + return h.defaultUser +} + func clientName(index int) string { return fmt.Sprintf("node-%d", index) } diff --git a/hscontrol/servertest/lifecycle_test.go b/hscontrol/servertest/lifecycle_test.go index 0930b51a..22d35e94 100644 --- a/hscontrol/servertest/lifecycle_test.go +++ b/hscontrol/servertest/lifecycle_test.go @@ -7,6 +7,7 @@ import ( "github.com/juanfont/headscale/hscontrol/servertest" "github.com/stretchr/testify/assert" + "tailscale.com/types/netmap" ) // TestConnectionLifecycle exercises the core node lifecycle: @@ -41,13 +42,22 @@ func TestConnectionLifecycle(t *testing.T) { departingName := h.Client(2).Name h.Client(2).Disconnect(t) - // The remaining clients should eventually stop seeing the - // departed node (after the grace period). 
- assert.Eventually(t, func() bool { - _, found := h.Client(0).PeerByName(departingName) - return !found - }, 30*time.Second, 500*time.Millisecond, - "client 0 should stop seeing departed node") + // The remaining clients should eventually see the departed + // node go offline or disappear. The grace period in poll.go + // is 10 seconds, so we need a generous timeout. + h.Client(0).WaitForCondition(t, "peer offline or gone", 60*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == departingName { + isOnline, known := p.Online().GetOk() + // Peer is still present but offline is acceptable. + return known && !isOnline + } + } + // Peer gone entirely is also acceptable. + return true + }) }) t.Run("reconnect_restores_mesh", func(t *testing.T) { diff --git a/hscontrol/servertest/policy_test.go b/hscontrol/servertest/policy_test.go new file mode 100644 index 00000000..7cd9e649 --- /dev/null +++ b/hscontrol/servertest/policy_test.go @@ -0,0 +1,151 @@ +package servertest_test + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/servertest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/netmap" +) + +// TestPolicyChanges verifies that ACL policy changes propagate +// correctly to all connected nodes, affecting peer visibility +// and packet filters. +func TestPolicyChanges(t *testing.T) { + t.Parallel() + + t.Run("default_allow_all", func(t *testing.T) { + t.Parallel() + // With no explicit policy (database mode), the default + // is to allow all traffic. All nodes should see each other. + h := servertest.NewHarness(t, 3) + servertest.AssertMeshComplete(t, h.Clients()) + }) + + t.Run("explicit_allow_all_policy", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + // Record update counts before policy change. 
+ countBefore := h.Client(0).UpdateCount() + + // Set an allow-all policy explicitly. + h.ChangePolicy(t, []byte(`{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] + }`)) + + // Both clients should receive an update after the policy change. + h.Client(0).WaitForCondition(t, "update after policy", + 10*time.Second, + func(nm *netmap.NetworkMap) bool { + return h.Client(0).UpdateCount() > countBefore + }) + }) + + t.Run("policy_with_allow_all_has_packet_filter", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "pf-user") + + // Set a valid allow-all policy. + changed, err := srv.State().SetPolicy([]byte(`{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] + }`)) + require.NoError(t, err) + + if changed { + changes, err := srv.State().ReloadPolicy() + require.NoError(t, err) + srv.App.Change(changes...) + } + + c := servertest.NewClient(t, srv, "pf-node", servertest.WithUser(user)) + c.WaitForUpdate(t, 15*time.Second) + + nm := c.Netmap() + require.NotNil(t, nm) + + // The netmap should have packet filter rules from the + // allow-all policy. + assert.NotNil(t, nm.PacketFilter, + "PacketFilter should be populated with allow-all rules") + }) + + t.Run("policy_change_triggers_update_on_all_nodes", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 3) + + counts := make([]int, len(h.Clients())) + for i, c := range h.Clients() { + counts[i] = c.UpdateCount() + } + + // Change policy. + h.ChangePolicy(t, []byte(`{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] + }`)) + + // All clients should receive at least one more update. 
+ for i, c := range h.Clients() { + c.WaitForCondition(t, "update after policy change", + 10*time.Second, + func(nm *netmap.NetworkMap) bool { + return c.UpdateCount() > counts[i] + }) + } + }) + + t.Run("multiple_policy_changes", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + // Apply policy twice and verify updates arrive both times. + for round := range 2 { + countBefore := h.Client(0).UpdateCount() + + h.ChangePolicy(t, []byte(`{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] + }`)) + + h.Client(0).WaitForCondition(t, "update after policy change", + 10*time.Second, + func(nm *netmap.NetworkMap) bool { + return h.Client(0).UpdateCount() > countBefore + }) + + t.Logf("round %d: update received", round) + } + }) + + t.Run("policy_with_multiple_users", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user1 := srv.CreateUser(t, "multi-user1") + user2 := srv.CreateUser(t, "multi-user2") + user3 := srv.CreateUser(t, "multi-user3") + + c1 := servertest.NewClient(t, srv, "multi-node1", servertest.WithUser(user1)) + c2 := servertest.NewClient(t, srv, "multi-node2", servertest.WithUser(user2)) + c3 := servertest.NewClient(t, srv, "multi-node3", servertest.WithUser(user3)) + + // With default allow-all, all should see each other. 
+ c1.WaitForPeers(t, 2, 15*time.Second) + c2.WaitForPeers(t, 2, 15*time.Second) + c3.WaitForPeers(t, 2, 15*time.Second) + + servertest.AssertMeshComplete(t, + []*servertest.TestClient{c1, c2, c3}) + }) +} diff --git a/hscontrol/servertest/routes_test.go b/hscontrol/servertest/routes_test.go new file mode 100644 index 00000000..83a67f10 --- /dev/null +++ b/hscontrol/servertest/routes_test.go @@ -0,0 +1,232 @@ +package servertest_test + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/servertest" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +// TestRoutes verifies that route advertisements and approvals +// propagate correctly through the control plane to all peers. +func TestRoutes(t *testing.T) { + t.Parallel() + + t.Run("node_addresses_in_allowed_ips", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + // Each peer's AllowedIPs should contain the peer's addresses. 
+ for _, c := range h.Clients() { + nm := c.Netmap() + require.NotNil(t, nm) + + for _, peer := range nm.Peers { + addrs := make(map[netip.Prefix]bool) + for i := range peer.Addresses().Len() { + addrs[peer.Addresses().At(i)] = true + } + + for i := range peer.AllowedIPs().Len() { + aip := peer.AllowedIPs().At(i) + if addrs[aip] { + delete(addrs, aip) + } + } + + assert.Empty(t, addrs, + "client %s: peer %d AllowedIPs should contain all of Addresses", + c.Name, peer.ID()) + } + } + }) + + t.Run("advertised_routes_in_hostinfo", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "advroute-user") + + routePrefix := netip.MustParsePrefix("192.168.1.0/24") + + c1 := servertest.NewClient(t, srv, "advroute-node1", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "advroute-node2", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + + // Update hostinfo with advertised routes. + c1.Direct().SetHostinfo(&tailcfg.Hostinfo{ + BackendLogID: "servertest-advroute-node1", + Hostname: "advroute-node1", + RoutableIPs: []netip.Prefix{routePrefix}, + }) + + // Send a non-streaming update to push the new hostinfo. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = c1.Direct().SendUpdate(ctx) + + // The observer should eventually see the advertised routes + // in the peer's hostinfo. 
+ c2.WaitForCondition(t, "advertised route in hostinfo", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "advroute-node1" { + for i := range hi.RoutableIPs().Len() { + if hi.RoutableIPs().At(i) == routePrefix { + return true + } + } + } + } + + return false + }) + }) + + t.Run("route_advertise_and_approve", func(t *testing.T) { + t.Parallel() + + srv := servertest.NewServer(t) + user := srv.CreateUser(t, "fullrt-user") + + route := netip.MustParsePrefix("10.0.0.0/24") + + c1 := servertest.NewClient(t, srv, "fullrt-advertiser", + servertest.WithUser(user)) + c2 := servertest.NewClient(t, srv, "fullrt-observer", + servertest.WithUser(user)) + + c1.WaitForPeers(t, 1, 10*time.Second) + c2.WaitForPeers(t, 1, 10*time.Second) + + // Step 1: Advertise the route by updating hostinfo. + c1.Direct().SetHostinfo(&tailcfg.Hostinfo{ + BackendLogID: "servertest-fullrt-advertiser", + Hostname: "fullrt-advertiser", + RoutableIPs: []netip.Prefix{route}, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _ = c1.Direct().SendUpdate(ctx) + + // Wait for the server to process the hostinfo update + // by waiting for observer to see the advertised route. + c2.WaitForCondition(t, "hostinfo update propagated", + 10*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "fullrt-advertiser" { + return hi.RoutableIPs().Len() > 0 + } + } + + return false + }) + + // Step 2: Approve the route on the server. + nodeID := findNodeID(t, srv, "fullrt-advertiser") + + _, routeChange, err := srv.State().SetApprovedRoutes( + nodeID, []netip.Prefix{route}) + require.NoError(t, err) + srv.App.Change(routeChange) + + // Step 3: Observer should see the route in AllowedIPs. 
+ c2.WaitForCondition(t, "approved route in AllowedIPs", + 15*time.Second, + func(nm *netmap.NetworkMap) bool { + for _, p := range nm.Peers { + hi := p.Hostinfo() + if hi.Valid() && hi.Hostname() == "fullrt-advertiser" { + for i := range p.AllowedIPs().Len() { + if p.AllowedIPs().At(i) == route { + return true + } + } + } + } + + return false + }) + }) + + t.Run("allowed_ips_superset_of_addresses", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 3) + + for _, c := range h.Clients() { + nm := c.Netmap() + require.NotNil(t, nm) + + for _, peer := range nm.Peers { + allowedSet := make(map[netip.Prefix]bool) + for i := range peer.AllowedIPs().Len() { + allowedSet[peer.AllowedIPs().At(i)] = true + } + + for i := range peer.Addresses().Len() { + addr := peer.Addresses().At(i) + assert.True(t, allowedSet[addr], + "client %s: peer %d Address %v should be in AllowedIPs", + c.Name, peer.ID(), addr) + } + } + } + }) + + t.Run("addresses_are_in_cgnat_range", func(t *testing.T) { + t.Parallel() + h := servertest.NewHarness(t, 2) + + cgnat := netip.MustParsePrefix("100.64.0.0/10") + ula := netip.MustParsePrefix("fd7a:115c:a1e0::/48") + + for _, c := range h.Clients() { + nm := c.Netmap() + require.NotNil(t, nm) + require.True(t, nm.SelfNode.Valid()) + + for i := range nm.SelfNode.Addresses().Len() { + addr := nm.SelfNode.Addresses().At(i) + inCGNAT := cgnat.Contains(addr.Addr()) + inULA := ula.Contains(addr.Addr()) + assert.True(t, inCGNAT || inULA, + "client %s: address %v should be in CGNAT or ULA range", + c.Name, addr) + } + } + }) +} + +// findNodeID looks up a node's ID by hostname in the server state. +func findNodeID(tb testing.TB, srv *servertest.TestServer, hostname string) types.NodeID { + tb.Helper() + + nodes := srv.State().ListNodes() + for i := range nodes.Len() { + n := nodes.At(i) + if n.Hostname() == hostname { + return n.ID() + } + } + + tb.Fatalf("node %q not found in server state", hostname) + + return 0 +}