package mapper // Unit tests for batcher components that do NOT require database setup. // These tests exercise connectionEntry, multiChannelNodeConn, computePeerDiff, // updateSentPeers, generateMapResponse branching, and handleNodeChange in isolation. import ( "errors" "fmt" "runtime" "sync" "sync/atomic" "testing" "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" "github.com/puzpuzpuz/xsync/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" ) // ============================================================================ // Mock Infrastructure // ============================================================================ // mockNodeConnection implements nodeConnection for isolated unit testing // of generateMapResponse and handleNodeChange without a real database. type mockNodeConnection struct { id types.NodeID ver tailcfg.CapabilityVersion // sendFn allows injecting custom send behavior. // If nil, sends are recorded and succeed. sendFn func(*tailcfg.MapResponse) error // sent records all successful sends for assertion. sent []*tailcfg.MapResponse mu sync.Mutex // Peer tracking peers *xsync.Map[tailcfg.NodeID, struct{}] } func newMockNodeConnection(id types.NodeID) *mockNodeConnection { return &mockNodeConnection{ id: id, ver: tailcfg.CapabilityVersion(100), peers: xsync.NewMap[tailcfg.NodeID, struct{}](), } } // withSendError configures the mock to return the given error on send. func (m *mockNodeConnection) withSendError(err error) *mockNodeConnection { m.sendFn = func(_ *tailcfg.MapResponse) error { return err } return m } func (m *mockNodeConnection) nodeID() types.NodeID { return m.id } func (m *mockNodeConnection) version() tailcfg.CapabilityVersion { return m.ver } func (m *mockNodeConnection) send(data *tailcfg.MapResponse) error { if m.sendFn != nil { return m.sendFn(data) } m.mu.Lock() m.sent = append(m.sent, data) m.mu.Unlock() return nil } func (m *mockNodeConnection) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID { currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers)) for _, id := range currentPeers { currentSet[id] = struct{}{} } var removed []tailcfg.NodeID m.peers.Range(func(id tailcfg.NodeID, _ struct{}) bool { if _, exists := currentSet[id]; !exists { removed = append(removed, id) } return true }) return removed } func (m *mockNodeConnection) updateSentPeers(resp *tailcfg.MapResponse) { if resp == nil { return } if resp.Peers != nil { m.peers.Clear() for _, peer := range resp.Peers { m.peers.Store(peer.ID, struct{}{}) } } for _, peer := range resp.PeersChanged { m.peers.Store(peer.ID, struct{}{}) } for _, id := range resp.PeersRemoved { m.peers.Delete(id) } } // getSent returns a thread-safe copy of all sent responses. func (m *mockNodeConnection) getSent() []*tailcfg.MapResponse { m.mu.Lock() defer m.mu.Unlock() return append([]*tailcfg.MapResponse{}, m.sent...) } // ============================================================================ // Test Helpers // ============================================================================ // testMapResponse creates a minimal valid MapResponse for testing. func testMapResponse() *tailcfg.MapResponse { now := time.Now() return &tailcfg.MapResponse{ ControlTime: &now, } } // testMapResponseWithPeers creates a MapResponse with the given peer IDs. func testMapResponseWithPeers(peerIDs ...tailcfg.NodeID) *tailcfg.MapResponse { resp := testMapResponse() resp.Peers = make([]*tailcfg.Node, len(peerIDs)) for i, id := range peerIDs { resp.Peers[i] = &tailcfg.Node{ID: id} } return resp } // ids is a convenience for creating a slice of tailcfg.NodeID. func ids(nodeIDs ...tailcfg.NodeID) []tailcfg.NodeID { return nodeIDs } // expectReceive asserts that a message arrives on the channel within 100ms. func expectReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, msg string) *tailcfg.MapResponse { t.Helper() const timeout = 100 * time.Millisecond select { case data := <-ch: return data case <-time.After(timeout): t.Fatalf("expected to receive on channel within %v: %s", timeout, msg) return nil } } // expectNoReceive asserts that no message arrives within timeout. func expectNoReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, timeout time.Duration, msg string) { t.Helper() select { case data := <-ch: t.Fatalf("expected no receive but got %+v: %s", data, msg) case <-time.After(timeout): // Expected } } // makeConnectionEntry creates a connectionEntry with the given channel. func makeConnectionEntry(id string, ch chan<- *tailcfg.MapResponse) *connectionEntry { entry := &connectionEntry{ id: id, c: ch, version: tailcfg.CapabilityVersion(100), created: time.Now(), } entry.lastUsed.Store(time.Now().Unix()) return entry } // ============================================================================ // connectionEntry.send() Tests // ============================================================================ func TestConnectionEntry_SendSuccess(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 1) entry := makeConnectionEntry("test-conn", ch) data := testMapResponse() beforeSend := time.Now().Unix() err := entry.send(data) require.NoError(t, err) assert.GreaterOrEqual(t, entry.lastUsed.Load(), beforeSend, "lastUsed should be updated after successful send") // Verify data was actually sent received := expectReceive(t, ch, "data should be on channel") assert.Equal(t, data, received) } func TestConnectionEntry_SendNilData(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 1) entry := makeConnectionEntry("test-conn", ch) err := entry.send(nil) require.NoError(t, err, "nil data should return nil error") expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent to channel") } func TestConnectionEntry_SendTimeout(t *testing.T) { // Unbuffered channel with no reader = always blocks ch := make(chan *tailcfg.MapResponse) entry := makeConnectionEntry("test-conn", ch) data := testMapResponse() start := time.Now() err := entry.send(data) elapsed := time.Since(start) require.ErrorIs(t, err, ErrConnectionSendTimeout) assert.GreaterOrEqual(t, elapsed, 40*time.Millisecond, "should wait approximately 50ms before timeout") } func TestConnectionEntry_SendClosed(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 1) entry := makeConnectionEntry("test-conn", ch) // Mark as closed before sending entry.closed.Store(true) err := entry.send(testMapResponse()) require.ErrorIs(t, err, errConnectionClosed) expectNoReceive(t, ch, 10*time.Millisecond, "closed entry should not send data to channel") } func TestConnectionEntry_SendUpdatesLastUsed(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 1) entry := makeConnectionEntry("test-conn", ch) // Set lastUsed to a past time pastTime := time.Now().Add(-1 * time.Hour).Unix() entry.lastUsed.Store(pastTime) err := entry.send(testMapResponse()) require.NoError(t, err) assert.Greater(t, entry.lastUsed.Load(), pastTime, "lastUsed should be updated to current time after send") } // ============================================================================ // multiChannelNodeConn.send() Tests // ============================================================================ func TestMultiChannelSend_AllSuccess(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // Create 3 buffered channels (all will succeed) channels := make([]chan *tailcfg.MapResponse, 3) for i := range channels { channels[i] = make(chan *tailcfg.MapResponse, 1) mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i])) } data := testMapResponse() err := mc.send(data) require.NoError(t, err) assert.Equal(t, 3, mc.getActiveConnectionCount(), "all connections should remain active after success") // Verify all channels received the data for i, ch := range channels { received := expectReceive(t, ch, fmt.Sprintf("channel %d should receive data", i)) assert.Equal(t, data, received) } } func TestMultiChannelSend_PartialFailure(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // 2 buffered channels (will succeed) + 1 unbuffered (will timeout) goodCh1 := make(chan *tailcfg.MapResponse, 1) goodCh2 := make(chan *tailcfg.MapResponse, 1) badCh := make(chan *tailcfg.MapResponse) // unbuffered, no reader mc.addConnection(makeConnectionEntry("good-1", goodCh1)) mc.addConnection(makeConnectionEntry("bad", badCh)) mc.addConnection(makeConnectionEntry("good-2", goodCh2)) err := mc.send(testMapResponse()) require.NoError(t, err, "should succeed if at least one connection works") assert.Equal(t, 2, mc.getActiveConnectionCount(), "failed connection should be removed") // Good channels should have received data expectReceive(t, goodCh1, "good-1 should receive") expectReceive(t, goodCh2, "good-2 should receive") } func TestMultiChannelSend_AllFail(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // All unbuffered channels with no readers for i := range 3 { ch := make(chan *tailcfg.MapResponse) // unbuffered mc.addConnection(makeConnectionEntry(fmt.Sprintf("bad-%d", i), ch)) } err := mc.send(testMapResponse()) require.Error(t, err, "should return error when all connections fail") assert.Equal(t, 0, mc.getActiveConnectionCount(), "all failed connections should be removed") } func TestMultiChannelSend_ZeroConnections(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) err := mc.send(testMapResponse()) require.NoError(t, err, "sending to node with 0 connections should succeed silently (rapid reconnection scenario)") } func TestMultiChannelSend_NilData(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) ch := make(chan *tailcfg.MapResponse, 1) mc.addConnection(makeConnectionEntry("conn", ch)) err := mc.send(nil) require.NoError(t, err, "nil data should return nil immediately") expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent") } func TestMultiChannelSend_FailedConnectionRemoved(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) goodCh := make(chan *tailcfg.MapResponse, 10) // large buffer badCh := make(chan *tailcfg.MapResponse) // unbuffered, will timeout mc.addConnection(makeConnectionEntry("good", goodCh)) mc.addConnection(makeConnectionEntry("bad", badCh)) assert.Equal(t, 2, mc.getActiveConnectionCount()) // First send: bad connection removed err := mc.send(testMapResponse()) require.NoError(t, err) assert.Equal(t, 1, mc.getActiveConnectionCount()) // Second send: only good connection remains, should succeed err = mc.send(testMapResponse()) require.NoError(t, err) assert.Equal(t, 1, mc.getActiveConnectionCount()) } func TestMultiChannelSend_UpdateCount(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) ch := make(chan *tailcfg.MapResponse, 10) mc.addConnection(makeConnectionEntry("conn", ch)) assert.Equal(t, int64(0), mc.updateCount.Load()) _ = mc.send(testMapResponse()) assert.Equal(t, int64(1), mc.updateCount.Load()) _ = mc.send(testMapResponse()) assert.Equal(t, int64(2), mc.updateCount.Load()) } // ============================================================================ // multiChannelNodeConn.close() Tests // ============================================================================ func TestMultiChannelClose_MarksEntriesClosed(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) entries := make([]*connectionEntry, 3) for i := range entries { ch := make(chan *tailcfg.MapResponse, 1) entries[i] = makeConnectionEntry(fmt.Sprintf("conn-%d", i), ch) mc.addConnection(entries[i]) } mc.close() for i, entry := range entries { assert.True(t, entry.closed.Load(), "entry %d should be marked as closed", i) } } func TestMultiChannelClose_PreventsSendPanic(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) ch := make(chan *tailcfg.MapResponse, 1) entry := makeConnectionEntry("conn", ch) mc.addConnection(entry) mc.close() // After close, connectionEntry.send should return errConnectionClosed // (not panic on send to closed channel) err := entry.send(testMapResponse()) require.ErrorIs(t, err, errConnectionClosed, "send after close should return errConnectionClosed, not panic") } // ============================================================================ // multiChannelNodeConn connection management Tests // ============================================================================ func TestMultiChannelNodeConn_AddRemoveConnections(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) ch1 := make(chan *tailcfg.MapResponse, 1) ch2 := make(chan *tailcfg.MapResponse, 1) ch3 := make(chan *tailcfg.MapResponse, 1) // Add connections mc.addConnection(makeConnectionEntry("c1", ch1)) assert.Equal(t, 1, mc.getActiveConnectionCount()) assert.True(t, mc.hasActiveConnections()) mc.addConnection(makeConnectionEntry("c2", ch2)) mc.addConnection(makeConnectionEntry("c3", ch3)) assert.Equal(t, 3, mc.getActiveConnectionCount()) // Remove by channel pointer assert.True(t, mc.removeConnectionByChannel(ch2)) assert.Equal(t, 2, mc.getActiveConnectionCount()) // Remove non-existent channel nonExistentCh := make(chan *tailcfg.MapResponse) assert.False(t, mc.removeConnectionByChannel(nonExistentCh)) assert.Equal(t, 2, mc.getActiveConnectionCount()) // Remove remaining assert.True(t, mc.removeConnectionByChannel(ch1)) assert.True(t, mc.removeConnectionByChannel(ch3)) assert.Equal(t, 0, mc.getActiveConnectionCount()) assert.False(t, mc.hasActiveConnections()) } func TestMultiChannelNodeConn_Version(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // No connections - version should be 0 assert.Equal(t, tailcfg.CapabilityVersion(0), mc.version()) // Add connection with version 100 ch := make(chan *tailcfg.MapResponse, 1) entry := makeConnectionEntry("conn", ch) entry.version = tailcfg.CapabilityVersion(100) mc.addConnection(entry) assert.Equal(t, tailcfg.CapabilityVersion(100), mc.version()) } // ============================================================================ // computePeerDiff Tests // ============================================================================ func TestComputePeerDiff(t *testing.T) { tests := []struct { name string tracked []tailcfg.NodeID // peers previously sent to client current []tailcfg.NodeID // peers visible now wantRemoved []tailcfg.NodeID // expected removed peers }{ { name: "no_changes", tracked: ids(1, 2, 3), current: ids(1, 2, 3), wantRemoved: nil, }, { name: "one_removed", tracked: ids(1, 2, 3), current: ids(1, 3), wantRemoved: ids(2), }, { name: "multiple_removed", tracked: ids(1, 2, 3, 4, 5), current: ids(2, 4), wantRemoved: ids(1, 3, 5), }, { name: "all_removed", tracked: ids(1, 2, 3), current: nil, wantRemoved: ids(1, 2, 3), }, { name: "peers_added_no_removal", tracked: ids(1), current: ids(1, 2, 3), wantRemoved: nil, }, { name: "empty_tracked", tracked: nil, current: ids(1, 2, 3), wantRemoved: nil, }, { name: "both_empty", tracked: nil, current: nil, wantRemoved: nil, }, { name: "disjoint_sets", tracked: ids(1, 2, 3), current: ids(4, 5, 6), wantRemoved: ids(1, 2, 3), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // Populate tracked peers for _, id := range tt.tracked { mc.lastSentPeers.Store(id, struct{}{}) } got := mc.computePeerDiff(tt.current) assert.ElementsMatch(t, tt.wantRemoved, got, "removed peers should match expected") }) } } // ============================================================================ // updateSentPeers Tests // ============================================================================ func TestUpdateSentPeers(t *testing.T) { t.Run("full_peer_list_replaces_all", func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // Pre-populate with old peers mc.lastSentPeers.Store(tailcfg.NodeID(100), struct{}{}) mc.lastSentPeers.Store(tailcfg.NodeID(200), struct{}{}) // Send full peer list mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3)) // Old peers should be gone _, exists := mc.lastSentPeers.Load(tailcfg.NodeID(100)) assert.False(t, exists, "old peer 100 should be cleared") // New peers should be tracked for _, id := range ids(1, 2, 3) { _, exists := mc.lastSentPeers.Load(id) assert.True(t, exists, "peer %d should be tracked", id) } }) t.Run("incremental_add_via_PeersChanged", func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{}) resp := testMapResponse() resp.PeersChanged = []*tailcfg.Node{{ID: 2}, {ID: 3}} mc.updateSentPeers(resp) // All three should be tracked for _, id := range ids(1, 2, 3) { _, exists := mc.lastSentPeers.Load(id) assert.True(t, exists, "peer %d should be tracked", id) } }) t.Run("incremental_remove_via_PeersRemoved", func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{}) mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{}) mc.lastSentPeers.Store(tailcfg.NodeID(3), struct{}{}) resp := testMapResponse() resp.PeersRemoved = ids(2) mc.updateSentPeers(resp) _, exists1 := mc.lastSentPeers.Load(tailcfg.NodeID(1)) _, exists2 := mc.lastSentPeers.Load(tailcfg.NodeID(2)) _, exists3 := mc.lastSentPeers.Load(tailcfg.NodeID(3)) assert.True(t, exists1, "peer 1 should remain") assert.False(t, exists2, "peer 2 should be removed") assert.True(t, exists3, "peer 3 should remain") }) t.Run("nil_response_is_noop", func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{}) mc.updateSentPeers(nil) _, exists := mc.lastSentPeers.Load(tailcfg.NodeID(1)) assert.True(t, exists, "nil response should not change tracked peers") }) t.Run("full_then_incremental_sequence", func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // Step 1: Full peer list mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3)) // Step 2: Add peer 4 resp := testMapResponse() resp.PeersChanged = []*tailcfg.Node{{ID: 4}} mc.updateSentPeers(resp) // Step 3: Remove peer 2 resp2 := testMapResponse() resp2.PeersRemoved = ids(2) mc.updateSentPeers(resp2) // Should have 1, 3, 4 for _, id := range ids(1, 3, 4) { _, exists := mc.lastSentPeers.Load(id) assert.True(t, exists, "peer %d should be tracked", id) } _, exists := mc.lastSentPeers.Load(tailcfg.NodeID(2)) assert.False(t, exists, "peer 2 should have been removed") }) t.Run("empty_full_peer_list_clears_all", func(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{}) mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{}) // Empty Peers slice (not nil) means "no peers" resp := testMapResponse() resp.Peers = []*tailcfg.Node{} // empty, not nil mc.updateSentPeers(resp) count := 0 mc.lastSentPeers.Range(func(_ tailcfg.NodeID, _ struct{}) bool { count++ return true }) assert.Equal(t, 0, count, "empty peer list should clear all tracking") }) } // ============================================================================ // generateMapResponse Tests (branching logic only, no DB needed) // ============================================================================ func TestGenerateMapResponse_EmptyChange(t *testing.T) { mc := newMockNodeConnection(1) resp, err := generateMapResponse(mc, nil, change.Change{}) require.NoError(t, err) assert.Nil(t, resp, "empty change should return nil response") } func TestGenerateMapResponse_InvalidNodeID(t *testing.T) { mc := newMockNodeConnection(0) // Invalid ID resp, err := generateMapResponse(mc, &mapper{}, change.DERPMap()) require.ErrorIs(t, err, ErrInvalidNodeID) assert.Nil(t, resp) } func TestGenerateMapResponse_NilMapper(t *testing.T) { mc := newMockNodeConnection(1) resp, err := generateMapResponse(mc, nil, change.DERPMap()) require.ErrorIs(t, err, ErrMapperNil) assert.Nil(t, resp) } func TestGenerateMapResponse_SelfOnlyOtherNode(t *testing.T) { mc := newMockNodeConnection(1) // SelfUpdate targeted at node 99 should be skipped for node 1 ch := change.SelfUpdate(99) resp, err := generateMapResponse(mc, &mapper{}, ch) require.NoError(t, err) assert.Nil(t, resp, "self-only change targeted at different node should return nil") } func TestGenerateMapResponse_SelfOnlySameNode(t *testing.T) { // SelfUpdate targeted at node 1: IsSelfOnly()=true and TargetNode==nodeID // This should NOT be short-circuited - it should attempt to generate. // We verify the routing logic by checking that the change is not empty // and not filtered out (unlike SelfOnlyOtherNode above). ch := change.SelfUpdate(1) assert.False(t, ch.IsEmpty(), "SelfUpdate should not be empty") assert.True(t, ch.IsSelfOnly(), "SelfUpdate should be self-only") assert.True(t, ch.ShouldSendToNode(1), "should be sent to target node") assert.False(t, ch.ShouldSendToNode(2), "should NOT be sent to other nodes") } // ============================================================================ // handleNodeChange Tests // ============================================================================ func TestHandleNodeChange_NilConnection(t *testing.T) { err := handleNodeChange(nil, nil, change.DERPMap()) assert.ErrorIs(t, err, ErrNodeConnectionNil) } func TestHandleNodeChange_EmptyChange(t *testing.T) { mc := newMockNodeConnection(1) err := handleNodeChange(mc, nil, change.Change{}) require.NoError(t, err, "empty change should not send anything") assert.Empty(t, mc.getSent(), "no data should be sent for empty change") } var errConnectionBroken = errors.New("connection broken") func TestHandleNodeChange_SendError(t *testing.T) { mc := newMockNodeConnection(1).withSendError(errConnectionBroken) // Need a real mapper for this test - we can't easily mock it. // Instead, test that when generateMapResponse returns nil data, // no send occurs. The send error path requires a valid MapResponse // which requires a mapper with state. // So we test the nil-data path here. err := handleNodeChange(mc, nil, change.Change{}) assert.NoError(t, err, "empty change produces nil data, no send needed") } func TestHandleNodeChange_NilDataNoSend(t *testing.T) { mc := newMockNodeConnection(1) // SelfUpdate targeted at different node produces nil data ch := change.SelfUpdate(99) err := handleNodeChange(mc, &mapper{}, ch) require.NoError(t, err, "nil data should not cause error") assert.Empty(t, mc.getSent(), "nil data should not trigger send") } // ============================================================================ // connectionEntry concurrent safety Tests // ============================================================================ func TestConnectionEntry_ConcurrentSends(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 100) entry := makeConnectionEntry("concurrent", ch) var ( wg sync.WaitGroup successCount atomic.Int64 ) // 50 goroutines sending concurrently for range 50 { wg.Go(func() { err := entry.send(testMapResponse()) if err == nil { successCount.Add(1) } }) } wg.Wait() assert.Equal(t, int64(50), successCount.Load(), "all sends to buffered channel should succeed") // Drain and count count := 0 for range len(ch) { <-ch count++ } assert.Equal(t, 50, count, "all 50 messages should be on channel") } func TestConnectionEntry_ConcurrentSendAndClose(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 100) entry := makeConnectionEntry("race", ch) var ( wg sync.WaitGroup panicked atomic.Bool ) // Goroutines sending rapidly for range 20 { wg.Go(func() { defer func() { if r := recover(); r != nil { panicked.Store(true) } }() for range 10 { _ = entry.send(testMapResponse()) } }) } // Close midway through wg.Go(func() { time.Sleep(1 * time.Millisecond) //nolint:forbidigo // concurrency test coordination entry.closed.Store(true) }) wg.Wait() assert.False(t, panicked.Load(), "concurrent send and close should not panic") } // ============================================================================ // multiChannelNodeConn concurrent Tests // ============================================================================ func TestMultiChannelSend_ConcurrentAddAndSend(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // Start with one connection ch1 := make(chan *tailcfg.MapResponse, 100) mc.addConnection(makeConnectionEntry("initial", ch1)) var ( wg sync.WaitGroup panicked atomic.Bool ) // Goroutine adding connections wg.Go(func() { defer func() { if r := recover(); r != nil { panicked.Store(true) } }() for i := range 10 { ch := make(chan *tailcfg.MapResponse, 100) mc.addConnection(makeConnectionEntry(fmt.Sprintf("added-%d", i), ch)) } }) // Goroutine sending data wg.Go(func() { defer func() { if r := recover(); r != nil { panicked.Store(true) } }() for range 20 { _ = mc.send(testMapResponse()) } }) wg.Wait() assert.False(t, panicked.Load(), "concurrent add and send should not panic (mutex protects both)") } func TestMultiChannelSend_ConcurrentRemoveAndSend(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) channels := make([]chan *tailcfg.MapResponse, 10) for i := range channels { channels[i] = make(chan *tailcfg.MapResponse, 100) mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i])) } var ( wg sync.WaitGroup panicked atomic.Bool ) // Goroutine removing connections wg.Go(func() { defer func() { if r := recover(); r != nil { panicked.Store(true) } }() for _, ch := range channels { mc.removeConnectionByChannel(ch) } }) // Goroutine sending data concurrently wg.Go(func() { defer func() { if r := recover(); r != nil { panicked.Store(true) } }() for range 20 { _ = mc.send(testMapResponse()) } }) wg.Wait() assert.False(t, panicked.Load(), "concurrent remove and send should not panic") } // ============================================================================ // Regression tests for H1 (timer leak) and H3 (lifecycle) // ============================================================================ // TestConnectionEntry_SendFastPath_TimerStopped is a regression guard for H1. // Before the fix, connectionEntry.send used time.After(50ms) which leaked a // timer into the runtime heap on every call even when the channel send // succeeded immediately. The fix switched to time.NewTimer + defer Stop(). // // This test sends many messages on a buffered (non-blocking) channel and // checks that the number of live goroutines stays bounded, which would // grow without bound under the old time.After approach at high call rates. func TestConnectionEntry_SendFastPath_TimerStopped(t *testing.T) { const sends = 5000 ch := make(chan *tailcfg.MapResponse, sends) entry := &connectionEntry{ id: "timer-leak-test", c: ch, version: 100, created: time.Now(), } resp := testMapResponse() for range sends { err := entry.send(resp) require.NoError(t, err) } // Drain the channel so we aren't holding references. for range sends { <-ch } // Force a GC + timer cleanup pass. runtime.GC() // If timers were leaking we'd see a goroutine count much higher // than baseline. With 5000 leaked timers the count would be // noticeably elevated. We just check it's reasonable. numGR := runtime.NumGoroutine() assert.Less(t, numGR, 200, "goroutine count after %d fast-path sends should be bounded; got %d (possible timer leak)", sends, numGR) } // TestBatcher_CloseWaitsForWorkers is a regression guard for H3. // Before the fix, Close() would tear down node connections while workers // were potentially still running, risking sends on closed channels. // The fix added sync.WaitGroup tracking so Close() blocks until all // worker goroutines exit. func TestBatcher_CloseWaitsForWorkers(t *testing.T) { b := NewBatcher(50*time.Millisecond, 4, nil) goroutinesBefore := runtime.NumGoroutine() b.Start() // Give workers time to start. time.Sleep(20 * time.Millisecond) //nolint:forbidigo // test timing goroutinesDuring := runtime.NumGoroutine() // We expect at least 5 new goroutines: 1 doWork + 4 workers. assert.GreaterOrEqual(t, goroutinesDuring-goroutinesBefore, 5, "expected doWork + 4 workers to be running") // Close should block until all workers have exited. b.Close() // After Close returns, goroutines should have dropped back. // Allow a small margin for runtime goroutines. goroutinesAfter := runtime.NumGoroutine() assert.InDelta(t, goroutinesBefore, goroutinesAfter, 3, "goroutines should return to baseline after Close(); before=%d after=%d", goroutinesBefore, goroutinesAfter) } // TestBatcher_CloseThenStartIsNoop verifies the lifecycle contract: // once a Batcher has been started, calling Start() again is a no-op // (the started flag prevents double-start). func TestBatcher_CloseThenStartIsNoop(t *testing.T) { b := NewBatcher(50*time.Millisecond, 2, nil) b.Start() b.Close() goroutinesBefore := runtime.NumGoroutine() // Second Start should be a no-op because started is already true. b.Start() // Allow a moment for any hypothetical goroutine to appear. time.Sleep(10 * time.Millisecond) //nolint:forbidigo // test timing goroutinesAfter := runtime.NumGoroutine() assert.InDelta(t, goroutinesBefore, goroutinesAfter, 1, "Start() after Close() should not spawn new goroutines; before=%d after=%d", goroutinesBefore, goroutinesAfter) } // TestBatcher_CloseStopsTicker verifies that Close() stops the internal // ticker, preventing resource leaks. func TestBatcher_CloseStopsTicker(t *testing.T) { b := NewBatcher(10*time.Millisecond, 1, nil) b.Start() b.Close() // After Close, the ticker should be stopped. Reading from a stopped // ticker's channel should not deliver any values. select { case <-b.tick.C: t.Fatal("ticker fired after Close(); ticker.Stop() was not called") case <-time.After(50 * time.Millisecond): //nolint:forbidigo // test timing // Expected: no tick received. } } // ============================================================================ // Regression tests for M1, M3, M7 // ============================================================================ // TestBatcher_CloseBeforeStart_DoesNotHang is a regression guard for M1. // Before the fix, done was nil until Start() was called. queueWork and // MapResponseFromChange select on done, so a nil channel would block // forever when workCh was full. With done initialized in NewBatcher, // Close() can be called safely before Start(). func TestBatcher_CloseBeforeStart_DoesNotHang(t *testing.T) { b := NewBatcher(50*time.Millisecond, 2, nil) // Close without Start must not panic or hang. done := make(chan struct{}) go func() { b.Close() close(done) }() select { case <-done: // Success: Close returned promptly. case <-time.After(2 * time.Second): //nolint:forbidigo // test timing t.Fatal("Close() before Start() hung; done channel was likely nil") } } // TestBatcher_QueueWorkAfterClose_DoesNotHang verifies that queueWork // returns immediately via the done channel when the batcher is closed, // even without Start() having been called. func TestBatcher_QueueWorkAfterClose_DoesNotHang(t *testing.T) { b := NewBatcher(50*time.Millisecond, 1, nil) b.Close() done := make(chan struct{}) go func() { // queueWork selects on done; with done closed this must return. b.queueWork(work{}) close(done) }() select { case <-done: // Success case <-time.After(2 * time.Second): //nolint:forbidigo // test timing t.Fatal("queueWork hung after Close(); done channel select not working") } } // TestIsConnected_FalseAfterAddNodeFailure is a regression guard for M3. // Before the fix, AddNode error paths removed the connection but left // b.connected with its previous value (nil = connected). IsConnected // would return true for a node with zero active connections. func TestIsConnected_FalseAfterAddNodeFailure(t *testing.T) { b := NewBatcher(50*time.Millisecond, 2, nil) b.Start() defer b.Close() id := types.NodeID(42) // Simulate a previous session leaving the node marked as connected. b.connected.Store(id, nil) // nil = connected // Pre-create the node entry so AddNode reuses it, and set up a // multiChannelNodeConn with no mapper so MapResponseFromChange will fail. nc := newMultiChannelNodeConn(id, nil) b.nodes.Store(id, nc) ch := make(chan *tailcfg.MapResponse, 1) err := b.AddNode(id, ch, 100, func() {}) require.Error(t, err, "AddNode should fail with nil mapper") // After failure, the node should NOT be reported as connected. assert.False(t, b.IsConnected(id), "IsConnected should return false after AddNode failure with no remaining connections") } // TestRemoveConnectionAtIndex_NilsTrailingSlot is a regression guard for M7. // Before the fix, removeConnectionAtIndexLocked used append(s[:i], s[i+1:]...) // which left a stale pointer in the backing array's last slot. The fix // uses copy + explicit nil of the trailing element. func TestRemoveConnectionAtIndex_NilsTrailingSlot(t *testing.T) { mc := newMultiChannelNodeConn(1, nil) // Manually add three entries under the lock. entries := make([]*connectionEntry, 3) for i := range entries { entries[i] = &connectionEntry{id: fmt.Sprintf("conn-%d", i), c: make(chan<- *tailcfg.MapResponse)} } mc.mutex.Lock() mc.connections = append(mc.connections, entries...) // Remove the middle entry (index 1). removed := mc.removeConnectionAtIndexLocked(1, false) require.Equal(t, entries[1], removed) // After removal, len should be 2 and the backing array slot at // index 2 (the old len-1) should be nil. require.Len(t, mc.connections, 2) assert.Equal(t, entries[0], mc.connections[0]) assert.Equal(t, entries[2], mc.connections[1]) // Check the backing array directly: the slot just past the new // length must be nil to avoid retaining the pointer. backing := mc.connections[:3] assert.Nil(t, backing[2], "trailing slot in backing array should be nil after removal") mc.mutex.Unlock() }