Files
headscale/hscontrol/mapper/batcher_unit_test.go
Kristoffer Dalby 3587225a88 mapper: fix phantom updateSentPeers on disconnected nodes
When send() is called on a node with zero active connections
(disconnected but kept for rapid reconnection), it returns nil
(success). handleNodeChange then calls updateSentPeers, recording
peers as delivered when no client received the data.

This corrupts lastSentPeers: future computePeerDiff calculations
produce wrong results because they compare against phantom state.
After reconnection, the node's initial map resets lastSentPeers,
but any changes processed during the disconnect window leave
stale entries that cause asymmetric peer visibility.

Return errNoActiveConnections from send() when there are no
connections. handleNodeChange treats this as a no-op (the change
was generated but not deliverable) and skips updateSentPeers,
keeping lastSentPeers consistent with what clients actually
received.
2026-04-10 13:18:56 +01:00

1187 lines
34 KiB
Go

package mapper
// Unit tests for batcher components that do NOT require database setup.
// These tests exercise connectionEntry, multiChannelNodeConn, computePeerDiff,
// updateSentPeers, generateMapResponse branching, and handleNodeChange in isolation.
import (
"errors"
"fmt"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
// ============================================================================
// Mock Infrastructure
// ============================================================================
// mockNodeConnection implements nodeConnection for isolated unit testing
// of generateMapResponse and handleNodeChange without a real database.
type mockNodeConnection struct {
	// id is the node this connection belongs to.
	id types.NodeID
	// ver is the capability version reported by version().
	ver tailcfg.CapabilityVersion
	// sendFn allows injecting custom send behavior.
	// If nil, sends are recorded and succeed.
	sendFn func(*tailcfg.MapResponse) error
	// sent records all successful sends for assertion.
	sent []*tailcfg.MapResponse
	// mu guards sent against concurrent send/getSent calls.
	mu sync.Mutex
	// peers tracks peer IDs most recently delivered to this mock,
	// mirroring the real connection's lastSentPeers bookkeeping.
	peers *xsync.Map[tailcfg.NodeID, struct{}]
}
// newMockNodeConnection constructs a mock connection for the given node
// ID with a default capability version of 100 and empty peer tracking.
func newMockNodeConnection(id types.NodeID) *mockNodeConnection {
	conn := &mockNodeConnection{}
	conn.id = id
	conn.ver = tailcfg.CapabilityVersion(100)
	conn.peers = xsync.NewMap[tailcfg.NodeID, struct{}]()

	return conn
}
// withSendError configures the mock to return the given error on every
// send. It returns the receiver so it can be chained at construction.
func (m *mockNodeConnection) withSendError(err error) *mockNodeConnection {
	failing := func(_ *tailcfg.MapResponse) error {
		return err
	}
	m.sendFn = failing

	return m
}
// nodeID returns the ID of the node this mock connection belongs to.
func (m *mockNodeConnection) nodeID() types.NodeID { return m.id }

// version returns the mock's configured capability version.
func (m *mockNodeConnection) version() tailcfg.CapabilityVersion { return m.ver }
// send delivers data through the injected sendFn when one is set;
// otherwise it records the response under the mutex and reports success.
func (m *mockNodeConnection) send(data *tailcfg.MapResponse) error {
	if fn := m.sendFn; fn != nil {
		return fn(data)
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	m.sent = append(m.sent, data)

	return nil
}
// computePeerDiff returns the IDs of peers that were previously tracked
// as sent but are absent from currentPeers (peers the client must drop).
func (m *mockNodeConnection) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
	visible := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
	for _, peerID := range currentPeers {
		visible[peerID] = struct{}{}
	}

	var gone []tailcfg.NodeID
	m.peers.Range(func(peerID tailcfg.NodeID, _ struct{}) bool {
		if _, stillVisible := visible[peerID]; !stillVisible {
			gone = append(gone, peerID)
		}

		return true
	})

	return gone
}
// updateSentPeers mirrors the real connection's tracking semantics:
// a non-nil Peers field is a full replacement (even when empty), while
// PeersChanged and PeersRemoved apply incremental adds and deletes.
func (m *mockNodeConnection) updateSentPeers(resp *tailcfg.MapResponse) {
	if resp == nil {
		return
	}

	// A full peer list resets all tracking before repopulating.
	if resp.Peers != nil {
		m.peers.Clear()
		for _, p := range resp.Peers {
			m.peers.Store(p.ID, struct{}{})
		}
	}

	for _, p := range resp.PeersChanged {
		m.peers.Store(p.ID, struct{}{})
	}

	for _, removedID := range resp.PeersRemoved {
		m.peers.Delete(removedID)
	}
}
// getSent returns a thread-safe snapshot copy of all recorded responses.
func (m *mockNodeConnection) getSent() []*tailcfg.MapResponse {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]*tailcfg.MapResponse, len(m.sent))
	copy(snapshot, m.sent)

	return snapshot
}
// ============================================================================
// Test Helpers
// ============================================================================
// testMapResponse creates a minimal valid MapResponse for testing,
// carrying only a ControlTime stamped to now.
func testMapResponse() *tailcfg.MapResponse {
	ts := time.Now()
	resp := &tailcfg.MapResponse{}
	resp.ControlTime = &ts

	return resp
}
// testMapResponseWithPeers creates a MapResponse whose full Peers list
// contains exactly the given peer IDs, in order.
func testMapResponseWithPeers(peerIDs ...tailcfg.NodeID) *tailcfg.MapResponse {
	resp := testMapResponse()
	resp.Peers = make([]*tailcfg.Node, 0, len(peerIDs))
	for _, id := range peerIDs {
		resp.Peers = append(resp.Peers, &tailcfg.Node{ID: id})
	}

	return resp
}
// ids is a convenience for creating a slice of tailcfg.NodeID from a
// variadic argument list. Calling it with no arguments yields nil,
// which matches the nil `wantRemoved` expectations in table tests.
func ids(nodeIDs ...tailcfg.NodeID) []tailcfg.NodeID {
	return nodeIDs
}
// expectReceive asserts that a message arrives on the channel within
// 100ms and returns it; otherwise it fails the test fatally with msg.
func expectReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, msg string) *tailcfg.MapResponse {
	t.Helper()

	const deadline = 100 * time.Millisecond

	select {
	case got := <-ch:
		return got
	case <-time.After(deadline):
		t.Fatalf("expected to receive on channel within %v: %s", deadline, msg)

		return nil
	}
}
// expectNoReceive asserts that nothing arrives on the channel within the
// given timeout; receiving anything fails the test fatally with msg.
func expectNoReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, timeout time.Duration, msg string) {
	t.Helper()

	select {
	case <-time.After(timeout):
		// Expected: the channel stayed quiet for the whole window.
	case got := <-ch:
		t.Fatalf("expected no receive but got %+v: %s", got, msg)
	}
}
// makeConnectionEntry creates a connectionEntry with the given id and
// channel, capability version 100, and created/lastUsed stamped to now.
func makeConnectionEntry(id string, ch chan<- *tailcfg.MapResponse) *connectionEntry {
	now := time.Now()
	entry := &connectionEntry{
		id:      id,
		c:       ch,
		version: tailcfg.CapabilityVersion(100),
		created: now,
	}
	entry.lastUsed.Store(now.Unix())

	return entry
}
// ============================================================================
// connectionEntry.send() Tests
// ============================================================================
// TestConnectionEntry_SendSuccess verifies that a send to a buffered
// channel succeeds, refreshes lastUsed, and delivers the exact response.
func TestConnectionEntry_SendSuccess(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("test-conn", ch)
	resp := testMapResponse()

	lowerBound := time.Now().Unix()
	require.NoError(t, entry.send(resp))

	assert.GreaterOrEqual(t, entry.lastUsed.Load(), lowerBound,
		"lastUsed should be updated after successful send")

	// Verify data was actually sent
	got := expectReceive(t, ch, "data should be on channel")
	assert.Equal(t, resp, got)
}
// TestConnectionEntry_SendNilData verifies nil data is a successful
// no-op that places nothing on the channel.
func TestConnectionEntry_SendNilData(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("test-conn", ch)

	require.NoError(t, entry.send(nil), "nil data should return nil error")
	expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent to channel")
}
// TestConnectionEntry_SendTimeout verifies that a permanently blocked
// channel makes send fail with ErrConnectionSendTimeout after roughly
// its internal 50ms deadline.
func TestConnectionEntry_SendTimeout(t *testing.T) {
	// Unbuffered channel with no reader = always blocks
	blocked := make(chan *tailcfg.MapResponse)
	entry := makeConnectionEntry("test-conn", blocked)

	began := time.Now()
	err := entry.send(testMapResponse())
	waited := time.Since(began)

	require.ErrorIs(t, err, ErrConnectionSendTimeout)
	assert.GreaterOrEqual(t, waited, 40*time.Millisecond,
		"should wait approximately 50ms before timeout")
}
// TestConnectionEntry_SendClosed verifies a closed entry rejects sends
// with errConnectionClosed and never writes to the channel.
func TestConnectionEntry_SendClosed(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("test-conn", ch)

	// Mark as closed before sending
	entry.closed.Store(true)

	require.ErrorIs(t, entry.send(testMapResponse()), errConnectionClosed)
	expectNoReceive(t, ch, 10*time.Millisecond,
		"closed entry should not send data to channel")
}
// TestConnectionEntry_SendUpdatesLastUsed verifies that a successful
// send advances lastUsed past a deliberately stale stored value.
func TestConnectionEntry_SendUpdatesLastUsed(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("test-conn", ch)

	// Set lastUsed to a past time
	staleStamp := time.Now().Add(-1 * time.Hour).Unix()
	entry.lastUsed.Store(staleStamp)

	require.NoError(t, entry.send(testMapResponse()))
	assert.Greater(t, entry.lastUsed.Load(), staleStamp,
		"lastUsed should be updated to current time after send")
}
// ============================================================================
// multiChannelNodeConn.send() Tests
// ============================================================================
// TestMultiChannelSend_AllSuccess verifies that with three healthy
// connections a send succeeds, keeps every connection active, and fans
// the same response out to all channels.
func TestMultiChannelSend_AllSuccess(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// Create 3 buffered channels (all will succeed)
	const numConns = 3
	chans := make([]chan *tailcfg.MapResponse, 0, numConns)
	for i := 0; i < numConns; i++ {
		c := make(chan *tailcfg.MapResponse, 1)
		chans = append(chans, c)
		mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), c))
	}

	resp := testMapResponse()
	require.NoError(t, mc.send(resp))
	assert.Equal(t, numConns, mc.getActiveConnectionCount(),
		"all connections should remain active after success")

	// Verify all channels received the data
	for i, c := range chans {
		got := expectReceive(t, c, fmt.Sprintf("channel %d should receive data", i))
		assert.Equal(t, resp, got)
	}
}
// TestMultiChannelSend_PartialFailure verifies that one stuck connection
// does not fail the overall send: the stuck connection is pruned while
// the healthy ones still receive the data.
func TestMultiChannelSend_PartialFailure(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// 2 buffered channels (will succeed) + 1 unbuffered (will timeout)
	okA := make(chan *tailcfg.MapResponse, 1)
	okB := make(chan *tailcfg.MapResponse, 1)
	stuck := make(chan *tailcfg.MapResponse) // unbuffered, no reader

	mc.addConnection(makeConnectionEntry("good-1", okA))
	mc.addConnection(makeConnectionEntry("bad", stuck))
	mc.addConnection(makeConnectionEntry("good-2", okB))

	require.NoError(t, mc.send(testMapResponse()),
		"should succeed if at least one connection works")
	assert.Equal(t, 2, mc.getActiveConnectionCount(),
		"failed connection should be removed")

	// Good channels should have received data
	expectReceive(t, okA, "good-1 should receive")
	expectReceive(t, okB, "good-2 should receive")
}
// TestMultiChannelSend_AllFail verifies that when every connection is
// stuck the send reports an error and all connections are pruned.
func TestMultiChannelSend_AllFail(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// All unbuffered channels with no readers
	for i := 0; i < 3; i++ {
		mc.addConnection(makeConnectionEntry(
			fmt.Sprintf("bad-%d", i),
			make(chan *tailcfg.MapResponse), // unbuffered
		))
	}

	require.Error(t, mc.send(testMapResponse()),
		"should return error when all connections fail")
	assert.Equal(t, 0, mc.getActiveConnectionCount(),
		"all failed connections should be removed")
}
// TestMultiChannelSend_ZeroConnections pins the disconnected-node
// contract: send must surface errNoActiveConnections rather than a
// phantom success, so callers never record undelivered peer state.
func TestMultiChannelSend_ZeroConnections(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	err := mc.send(testMapResponse())

	require.ErrorIs(t, err, errNoActiveConnections,
		"sending to node with 0 connections should return errNoActiveConnections "+
			"so callers skip updateSentPeers (prevents phantom peer state)")
}
// TestMultiChannelSend_NilData verifies nil data short-circuits to
// success without writing anything to a connected channel.
func TestMultiChannelSend_NilData(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	c := make(chan *tailcfg.MapResponse, 1)
	mc.addConnection(makeConnectionEntry("conn", c))

	require.NoError(t, mc.send(nil), "nil data should return nil immediately")
	expectNoReceive(t, c, 10*time.Millisecond, "nil data should not be sent")
}
// TestMultiChannelSend_FailedConnectionRemoved verifies a timed-out
// connection is dropped on the first send and that subsequent sends keep
// succeeding on the surviving connection.
func TestMultiChannelSend_FailedConnectionRemoved(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	healthy := make(chan *tailcfg.MapResponse, 10) // large buffer
	stuck := make(chan *tailcfg.MapResponse)       // unbuffered, will timeout
	mc.addConnection(makeConnectionEntry("good", healthy))
	mc.addConnection(makeConnectionEntry("bad", stuck))
	assert.Equal(t, 2, mc.getActiveConnectionCount())

	// First send: bad connection removed
	require.NoError(t, mc.send(testMapResponse()))
	assert.Equal(t, 1, mc.getActiveConnectionCount())

	// Second send: only good connection remains, should succeed
	require.NoError(t, mc.send(testMapResponse()))
	assert.Equal(t, 1, mc.getActiveConnectionCount())
}
// TestMultiChannelSend_UpdateCount verifies updateCount increments by
// exactly one per send.
func TestMultiChannelSend_UpdateCount(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	mc.addConnection(makeConnectionEntry("conn", make(chan *tailcfg.MapResponse, 10)))

	assert.Equal(t, int64(0), mc.updateCount.Load())

	_ = mc.send(testMapResponse())
	assert.Equal(t, int64(1), mc.updateCount.Load())

	_ = mc.send(testMapResponse())
	assert.Equal(t, int64(2), mc.updateCount.Load())
}
// ============================================================================
// multiChannelNodeConn.close() Tests
// ============================================================================
// TestMultiChannelClose_MarksEntriesClosed verifies close() flips the
// closed flag on every registered connection entry.
func TestMultiChannelClose_MarksEntriesClosed(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	var tracked []*connectionEntry
	for i := 0; i < 3; i++ {
		e := makeConnectionEntry(fmt.Sprintf("conn-%d", i), make(chan *tailcfg.MapResponse, 1))
		tracked = append(tracked, e)
		mc.addConnection(e)
	}

	mc.close()

	for i, e := range tracked {
		assert.True(t, e.closed.Load(),
			"entry %d should be marked as closed", i)
	}
}
// TestMultiChannelClose_PreventsSendPanic verifies that sending on an
// entry after close() yields errConnectionClosed instead of panicking
// on a send to a closed channel.
func TestMultiChannelClose_PreventsSendPanic(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	entry := makeConnectionEntry("conn", make(chan *tailcfg.MapResponse, 1))
	mc.addConnection(entry)

	mc.close()

	// After close, connectionEntry.send should return errConnectionClosed
	// (not panic on send to closed channel)
	require.ErrorIs(t, entry.send(testMapResponse()), errConnectionClosed,
		"send after close should return errConnectionClosed, not panic")
}
// ============================================================================
// multiChannelNodeConn connection management Tests
// ============================================================================
// TestMultiChannelNodeConn_AddRemoveConnections exercises connection
// bookkeeping: counts, hasActiveConnections, and removal by channel —
// including a miss on a channel that was never registered.
func TestMultiChannelNodeConn_AddRemoveConnections(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	first := make(chan *tailcfg.MapResponse, 1)
	second := make(chan *tailcfg.MapResponse, 1)
	third := make(chan *tailcfg.MapResponse, 1)

	// Add connections
	mc.addConnection(makeConnectionEntry("c1", first))
	assert.Equal(t, 1, mc.getActiveConnectionCount())
	assert.True(t, mc.hasActiveConnections())

	mc.addConnection(makeConnectionEntry("c2", second))
	mc.addConnection(makeConnectionEntry("c3", third))
	assert.Equal(t, 3, mc.getActiveConnectionCount())

	// Remove by channel pointer
	assert.True(t, mc.removeConnectionByChannel(second))
	assert.Equal(t, 2, mc.getActiveConnectionCount())

	// Remove non-existent channel
	unknown := make(chan *tailcfg.MapResponse)
	assert.False(t, mc.removeConnectionByChannel(unknown))
	assert.Equal(t, 2, mc.getActiveConnectionCount())

	// Remove remaining
	assert.True(t, mc.removeConnectionByChannel(first))
	assert.True(t, mc.removeConnectionByChannel(third))
	assert.Equal(t, 0, mc.getActiveConnectionCount())
	assert.False(t, mc.hasActiveConnections())
}
// TestMultiChannelNodeConn_Version verifies version() reports zero when
// no connections exist and the entry's capability version once added.
func TestMultiChannelNodeConn_Version(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// No connections - version should be 0
	assert.Equal(t, tailcfg.CapabilityVersion(0), mc.version())

	// Add connection with version 100
	entry := makeConnectionEntry("conn", make(chan *tailcfg.MapResponse, 1))
	entry.version = tailcfg.CapabilityVersion(100)
	mc.addConnection(entry)

	assert.Equal(t, tailcfg.CapabilityVersion(100), mc.version())
}
// ============================================================================
// computePeerDiff Tests
// ============================================================================
// TestComputePeerDiff is a table-driven test of multiChannelNodeConn's
// computePeerDiff: given the peers previously sent to the client
// (tracked) and the peers visible now (current), it must return exactly
// the tracked peers that are no longer visible.
func TestComputePeerDiff(t *testing.T) {
	tests := []struct {
		name        string
		tracked     []tailcfg.NodeID // peers previously sent to client
		current     []tailcfg.NodeID // peers visible now
		wantRemoved []tailcfg.NodeID // expected removed peers
	}{
		{
			name:        "no_changes",
			tracked:     ids(1, 2, 3),
			current:     ids(1, 2, 3),
			wantRemoved: nil,
		},
		{
			name:        "one_removed",
			tracked:     ids(1, 2, 3),
			current:     ids(1, 3),
			wantRemoved: ids(2),
		},
		{
			name:        "multiple_removed",
			tracked:     ids(1, 2, 3, 4, 5),
			current:     ids(2, 4),
			wantRemoved: ids(1, 3, 5),
		},
		{
			name:        "all_removed",
			tracked:     ids(1, 2, 3),
			current:     nil,
			wantRemoved: ids(1, 2, 3),
		},
		{
			name:        "peers_added_no_removal",
			tracked:     ids(1),
			current:     ids(1, 2, 3),
			wantRemoved: nil,
		},
		{
			name:        "empty_tracked",
			tracked:     nil,
			current:     ids(1, 2, 3),
			wantRemoved: nil,
		},
		{
			name:        "both_empty",
			tracked:     nil,
			current:     nil,
			wantRemoved: nil,
		},
		{
			name:        "disjoint_sets",
			tracked:     ids(1, 2, 3),
			current:     ids(4, 5, 6),
			wantRemoved: ids(1, 2, 3),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mc := newMultiChannelNodeConn(1, nil)
			// Populate tracked peers
			for _, id := range tt.tracked {
				mc.lastSentPeers.Store(id, struct{}{})
			}
			got := mc.computePeerDiff(tt.current)
			// Map iteration order is unspecified, so compare as sets
			// rather than ordered slices.
			assert.ElementsMatch(t, tt.wantRemoved, got,
				"removed peers should match expected")
		})
	}
}
// ============================================================================
// updateSentPeers Tests
// ============================================================================
// TestUpdateSentPeers exercises multiChannelNodeConn.updateSentPeers
// semantics: a non-nil Peers field is a full replacement (even when
// empty), PeersChanged/PeersRemoved apply incrementally, and a nil
// response leaves tracking untouched.
func TestUpdateSentPeers(t *testing.T) {
	t.Run("full_peer_list_replaces_all", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		// Pre-populate with old peers
		mc.lastSentPeers.Store(tailcfg.NodeID(100), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(200), struct{}{})
		// Send full peer list
		mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3))
		// Old peers should be gone
		_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(100))
		assert.False(t, exists, "old peer 100 should be cleared")
		// New peers should be tracked
		for _, id := range ids(1, 2, 3) {
			_, exists := mc.lastSentPeers.Load(id)
			assert.True(t, exists, "peer %d should be tracked", id)
		}
	})
	t.Run("incremental_add_via_PeersChanged", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
		resp := testMapResponse()
		resp.PeersChanged = []*tailcfg.Node{{ID: 2}, {ID: 3}}
		mc.updateSentPeers(resp)
		// All three should be tracked: existing peer 1 plus changed 2, 3
		for _, id := range ids(1, 2, 3) {
			_, exists := mc.lastSentPeers.Load(id)
			assert.True(t, exists, "peer %d should be tracked", id)
		}
	})
	t.Run("incremental_remove_via_PeersRemoved", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(3), struct{}{})
		resp := testMapResponse()
		resp.PeersRemoved = ids(2)
		mc.updateSentPeers(resp)
		_, exists1 := mc.lastSentPeers.Load(tailcfg.NodeID(1))
		_, exists2 := mc.lastSentPeers.Load(tailcfg.NodeID(2))
		_, exists3 := mc.lastSentPeers.Load(tailcfg.NodeID(3))
		assert.True(t, exists1, "peer 1 should remain")
		assert.False(t, exists2, "peer 2 should be removed")
		assert.True(t, exists3, "peer 3 should remain")
	})
	t.Run("nil_response_is_noop", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
		mc.updateSentPeers(nil)
		_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(1))
		assert.True(t, exists, "nil response should not change tracked peers")
	})
	t.Run("full_then_incremental_sequence", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		// Step 1: Full peer list
		mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3))
		// Step 2: Add peer 4
		resp := testMapResponse()
		resp.PeersChanged = []*tailcfg.Node{{ID: 4}}
		mc.updateSentPeers(resp)
		// Step 3: Remove peer 2
		resp2 := testMapResponse()
		resp2.PeersRemoved = ids(2)
		mc.updateSentPeers(resp2)
		// Should have 1, 3, 4
		for _, id := range ids(1, 3, 4) {
			_, exists := mc.lastSentPeers.Load(id)
			assert.True(t, exists, "peer %d should be tracked", id)
		}
		_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(2))
		assert.False(t, exists, "peer 2 should have been removed")
	})
	t.Run("empty_full_peer_list_clears_all", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{})
		// Empty Peers slice (not nil) means "no peers"
		resp := testMapResponse()
		resp.Peers = []*tailcfg.Node{} // empty, not nil
		mc.updateSentPeers(resp)
		count := 0
		mc.lastSentPeers.Range(func(_ tailcfg.NodeID, _ struct{}) bool {
			count++
			return true
		})
		assert.Equal(t, 0, count, "empty peer list should clear all tracking")
	})
}
// ============================================================================
// generateMapResponse Tests (branching logic only, no DB needed)
// ============================================================================
// TestGenerateMapResponse_EmptyChange verifies an empty change set
// yields neither a response nor an error.
func TestGenerateMapResponse_EmptyChange(t *testing.T) {
	conn := newMockNodeConnection(1)

	resp, err := generateMapResponse(conn, nil, change.Change{})

	require.NoError(t, err)
	assert.Nil(t, resp, "empty change should return nil response")
}
// TestGenerateMapResponse_InvalidNodeID verifies that node ID 0 is
// rejected with ErrInvalidNodeID and no response is produced.
func TestGenerateMapResponse_InvalidNodeID(t *testing.T) {
	conn := newMockNodeConnection(0) // Invalid ID

	resp, err := generateMapResponse(conn, &mapper{}, change.DERPMap())

	require.ErrorIs(t, err, ErrInvalidNodeID)
	assert.Nil(t, resp)
}
// TestGenerateMapResponse_NilMapper verifies a nil mapper is rejected
// with ErrMapperNil and no response is produced.
func TestGenerateMapResponse_NilMapper(t *testing.T) {
	conn := newMockNodeConnection(1)

	resp, err := generateMapResponse(conn, nil, change.DERPMap())

	require.ErrorIs(t, err, ErrMapperNil)
	assert.Nil(t, resp)
}
// TestGenerateMapResponse_SelfOnlyOtherNode verifies a self-only update
// aimed at a different node is filtered out (nil response, no error).
func TestGenerateMapResponse_SelfOnlyOtherNode(t *testing.T) {
	conn := newMockNodeConnection(1)

	// SelfUpdate targeted at node 99 should be skipped for node 1
	selfChange := change.SelfUpdate(99)
	resp, err := generateMapResponse(conn, &mapper{}, selfChange)

	require.NoError(t, err)
	assert.Nil(t, resp,
		"self-only change targeted at different node should return nil")
}
// TestGenerateMapResponse_SelfOnlySameNode checks the routing predicates
// of a SelfUpdate aimed at node 1: non-empty, self-only, and addressed
// exclusively to its target node (so it is not short-circuited there).
func TestGenerateMapResponse_SelfOnlySameNode(t *testing.T) {
	// SelfUpdate targeted at node 1: IsSelfOnly()=true and TargetNode==nodeID
	// This should NOT be short-circuited - it should attempt to generate.
	// We verify the routing logic by checking that the change is not empty
	// and not filtered out (unlike SelfOnlyOtherNode above).
	selfChange := change.SelfUpdate(1)

	assert.False(t, selfChange.IsEmpty(), "SelfUpdate should not be empty")
	assert.True(t, selfChange.IsSelfOnly(), "SelfUpdate should be self-only")
	assert.True(t, selfChange.ShouldSendToNode(1), "should be sent to target node")
	assert.False(t, selfChange.ShouldSendToNode(2), "should NOT be sent to other nodes")
}
// ============================================================================
// handleNodeChange Tests
// ============================================================================
// TestHandleNodeChange_NilConnection verifies a nil connection is
// rejected with ErrNodeConnectionNil.
func TestHandleNodeChange_NilConnection(t *testing.T) {
	err := handleNodeChange(nil, nil, change.DERPMap())

	assert.ErrorIs(t, err, ErrNodeConnectionNil)
}
// TestHandleNodeChange_EmptyChange verifies that an empty change results
// in no error and no transmitted data.
func TestHandleNodeChange_EmptyChange(t *testing.T) {
	conn := newMockNodeConnection(1)

	require.NoError(t, handleNodeChange(conn, nil, change.Change{}),
		"empty change should not send anything")
	assert.Empty(t, conn.getSent(), "no data should be sent for empty change")
}
// errConnectionBroken is a sentinel used to simulate a transport failure
// when a mock is configured via withSendError.
var errConnectionBroken = errors.New("connection broken")

// TestHandleNodeChange_SendError covers the nil-data path of
// handleNodeChange; see the inline comments for why the actual
// send-error branch cannot be exercised without a stateful mapper.
func TestHandleNodeChange_SendError(t *testing.T) {
	mc := newMockNodeConnection(1).withSendError(errConnectionBroken)
	// Need a real mapper for this test - we can't easily mock it.
	// Instead, test that when generateMapResponse returns nil data,
	// no send occurs. The send error path requires a valid MapResponse
	// which requires a mapper with state.
	// So we test the nil-data path here.
	err := handleNodeChange(mc, nil, change.Change{})
	assert.NoError(t, err, "empty change produces nil data, no send needed")
}
// TestHandleNodeChange_NilDataNoSend verifies that a change filtered
// down to nil data neither errors nor triggers a send.
func TestHandleNodeChange_NilDataNoSend(t *testing.T) {
	conn := newMockNodeConnection(1)

	// SelfUpdate targeted at different node produces nil data
	err := handleNodeChange(conn, &mapper{}, change.SelfUpdate(99))

	require.NoError(t, err, "nil data should not cause error")
	assert.Empty(t, conn.getSent(), "nil data should not trigger send")
}
// ============================================================================
// connectionEntry concurrent safety Tests
// ============================================================================
// TestConnectionEntry_ConcurrentSends drives 50 concurrent sends at one
// entry backed by a 100-slot buffered channel; every send must succeed
// and every message must land on the channel.
func TestConnectionEntry_ConcurrentSends(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 100)
	entry := makeConnectionEntry("concurrent", ch)
	var (
		wg           sync.WaitGroup
		successCount atomic.Int64
	)
	// 50 goroutines sending concurrently
	for range 50 {
		wg.Go(func() {
			err := entry.send(testMapResponse())
			if err == nil {
				successCount.Add(1)
			}
		})
	}
	wg.Wait()
	assert.Equal(t, int64(50), successCount.Load(),
		"all sends to buffered channel should succeed")
	// Drain and count; after Wait all messages are already buffered,
	// so len(ch) is a stable snapshot of what was delivered.
	count := 0
	for range len(ch) {
		<-ch
		count++
	}
	assert.Equal(t, 50, count, "all 50 messages should be on channel")
}
// TestConnectionEntry_ConcurrentSendAndClose races 20 sender goroutines
// (10 sends each) against a goroutine flipping the closed flag midway.
// Invariant under test: no code path may panic — sends after close must
// fail gracefully rather than write to a dead channel.
func TestConnectionEntry_ConcurrentSendAndClose(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 100)
	entry := makeConnectionEntry("race", ch)
	var (
		wg       sync.WaitGroup
		panicked atomic.Bool
	)
	// Goroutines sending rapidly
	for range 20 {
		wg.Go(func() {
			defer func() {
				if r := recover(); r != nil {
					panicked.Store(true)
				}
			}()
			for range 10 {
				_ = entry.send(testMapResponse())
			}
		})
	}
	// Close midway through
	wg.Go(func() {
		time.Sleep(1 * time.Millisecond) //nolint:forbidigo // concurrency test coordination
		entry.closed.Store(true)
	})
	wg.Wait()
	assert.False(t, panicked.Load(),
		"concurrent send and close should not panic")
}
// ============================================================================
// multiChannelNodeConn concurrent Tests
// ============================================================================
// TestMultiChannelSend_ConcurrentAddAndSend races connection additions
// against sends on the same multiChannelNodeConn. The mutex inside the
// connection must make both operations safe; any panic fails the test.
func TestMultiChannelSend_ConcurrentAddAndSend(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	// Start with one connection
	ch1 := make(chan *tailcfg.MapResponse, 100)
	mc.addConnection(makeConnectionEntry("initial", ch1))
	var (
		wg       sync.WaitGroup
		panicked atomic.Bool
	)
	// Goroutine adding connections
	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()
		for i := range 10 {
			ch := make(chan *tailcfg.MapResponse, 100)
			mc.addConnection(makeConnectionEntry(fmt.Sprintf("added-%d", i), ch))
		}
	})
	// Goroutine sending data
	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()
		for range 20 {
			_ = mc.send(testMapResponse())
		}
	})
	wg.Wait()
	assert.False(t, panicked.Load(),
		"concurrent add and send should not panic (mutex protects both)")
}
// TestMultiChannelSend_ConcurrentRemoveAndSend races removal of ten
// connections against concurrent sends. Both paths must be serialized
// by the connection's mutex; any panic fails the test.
func TestMultiChannelSend_ConcurrentRemoveAndSend(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	channels := make([]chan *tailcfg.MapResponse, 10)
	for i := range channels {
		channels[i] = make(chan *tailcfg.MapResponse, 100)
		mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i]))
	}
	var (
		wg       sync.WaitGroup
		panicked atomic.Bool
	)
	// Goroutine removing connections
	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()
		for _, ch := range channels {
			mc.removeConnectionByChannel(ch)
		}
	})
	// Goroutine sending data concurrently
	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()
		for range 20 {
			_ = mc.send(testMapResponse())
		}
	})
	wg.Wait()
	assert.False(t, panicked.Load(),
		"concurrent remove and send should not panic")
}
// ============================================================================
// Regression tests for H1 (timer leak) and H3 (lifecycle)
// ============================================================================
// TestConnectionEntry_SendFastPath_TimerStopped is a regression guard for H1.
// Before the fix, connectionEntry.send used time.After(50ms) which leaked a
// timer into the runtime heap on every call even when the channel send
// succeeded immediately. The fix switched to time.NewTimer + defer Stop().
//
// This test sends many messages on a buffered (non-blocking) channel and
// checks that the number of live goroutines stays bounded, which would
// grow without bound under the old time.After approach at high call rates.
func TestConnectionEntry_SendFastPath_TimerStopped(t *testing.T) {
	const sends = 5000
	ch := make(chan *tailcfg.MapResponse, sends)
	// Built directly (not via makeConnectionEntry) so lastUsed starts
	// at its zero value; the test only cares about send succeeding.
	entry := &connectionEntry{
		id:      "timer-leak-test",
		c:       ch,
		version: 100,
		created: time.Now(),
	}
	resp := testMapResponse()
	for range sends {
		err := entry.send(resp)
		require.NoError(t, err)
	}
	// Drain the channel so we aren't holding references.
	for range sends {
		<-ch
	}
	// Force a GC + timer cleanup pass.
	runtime.GC()
	// If timers were leaking we'd see a goroutine count much higher
	// than baseline. With 5000 leaked timers the count would be
	// noticeably elevated. We just check it's reasonable.
	numGR := runtime.NumGoroutine()
	assert.Less(t, numGR, 200,
		"goroutine count after %d fast-path sends should be bounded; got %d (possible timer leak)", sends, numGR)
}
// TestBatcher_CloseWaitsForWorkers is a regression guard for H3.
// Before the fix, Close() would tear down node connections while workers
// were potentially still running, risking sends on closed channels.
// The fix added sync.WaitGroup tracking so Close() blocks until all
// worker goroutines exit.
func TestBatcher_CloseWaitsForWorkers(t *testing.T) {
	b := NewBatcher(50*time.Millisecond, 4, nil)
	goroutinesBefore := runtime.NumGoroutine()
	b.Start()
	// Give workers time to start.
	time.Sleep(20 * time.Millisecond) //nolint:forbidigo // test timing
	goroutinesDuring := runtime.NumGoroutine()
	// We expect at least 5 new goroutines: 1 doWork + 4 workers.
	assert.GreaterOrEqual(t, goroutinesDuring-goroutinesBefore, 5,
		"expected doWork + 4 workers to be running")
	// Close should block until all workers have exited.
	b.Close()
	// After Close returns, goroutines should have dropped back.
	// Allow a small margin for runtime goroutines.
	goroutinesAfter := runtime.NumGoroutine()
	assert.InDelta(t, goroutinesBefore, goroutinesAfter, 3,
		"goroutines should return to baseline after Close(); before=%d after=%d",
		goroutinesBefore, goroutinesAfter)
}
// TestBatcher_CloseThenStartIsNoop verifies the lifecycle contract:
// once a Batcher has been started, calling Start() again is a no-op
// (the started flag prevents double-start). Verified indirectly by
// observing that no new goroutines appear after the second Start().
func TestBatcher_CloseThenStartIsNoop(t *testing.T) {
	b := NewBatcher(50*time.Millisecond, 2, nil)
	b.Start()
	b.Close()
	goroutinesBefore := runtime.NumGoroutine()
	// Second Start should be a no-op because started is already true.
	b.Start()
	// Allow a moment for any hypothetical goroutine to appear.
	time.Sleep(10 * time.Millisecond) //nolint:forbidigo // test timing
	goroutinesAfter := runtime.NumGoroutine()
	assert.InDelta(t, goroutinesBefore, goroutinesAfter, 1,
		"Start() after Close() should not spawn new goroutines; before=%d after=%d",
		goroutinesBefore, goroutinesAfter)
}
// TestBatcher_CloseStopsTicker verifies that Close() stops the internal
// ticker, preventing resource leaks. A stopped ticker's channel is never
// closed, so absence of a value within 50ms is the success signal.
func TestBatcher_CloseStopsTicker(t *testing.T) {
	b := NewBatcher(10*time.Millisecond, 1, nil)
	b.Start()
	b.Close()
	// After Close, the ticker should be stopped. Reading from a stopped
	// ticker's channel should not deliver any values.
	select {
	case <-b.tick.C:
		t.Fatal("ticker fired after Close(); ticker.Stop() was not called")
	case <-time.After(50 * time.Millisecond): //nolint:forbidigo // test timing
		// Expected: no tick received.
	}
}
}
// ============================================================================
// Regression tests for M1, M3, M7
// ============================================================================
// TestBatcher_CloseBeforeStart_DoesNotHang is a regression guard for M1.
// Before the fix, done was nil until Start() was called; queueWork and
// MapResponseFromChange select on done, so a nil channel could block
// forever. With done initialized in NewBatcher, Close() is safe to call
// before Start() and must return promptly.
func TestBatcher_CloseBeforeStart_DoesNotHang(t *testing.T) {
	b := NewBatcher(50*time.Millisecond, 2, nil)

	// Close without Start must not panic or hang.
	closed := make(chan struct{})
	go func() {
		defer close(closed)
		b.Close()
	}()

	select {
	case <-closed:
		// Success: Close returned promptly.
	case <-time.After(2 * time.Second): //nolint:forbidigo // test timing
		t.Fatal("Close() before Start() hung; done channel was likely nil")
	}
}
// TestBatcher_QueueWorkAfterClose_DoesNotHang verifies that queueWork
// returns immediately via the closed done channel once the batcher is
// closed, even when Start() was never called.
func TestBatcher_QueueWorkAfterClose_DoesNotHang(t *testing.T) {
	b := NewBatcher(50*time.Millisecond, 1, nil)
	b.Close()

	returned := make(chan struct{})
	go func() {
		defer close(returned)
		// queueWork selects on done; with done closed this must return.
		b.queueWork(work{})
	}()

	select {
	case <-returned:
		// Success
	case <-time.After(2 * time.Second): //nolint:forbidigo // test timing
		t.Fatal("queueWork hung after Close(); done channel select not working")
	}
}
// TestIsConnected_FalseAfterAddNodeFailure is a regression guard for M3.
// Before the fix, AddNode error paths removed the connection but did not
// mark the node as disconnected. IsConnected would return true for a
// node with zero active connections.
func TestIsConnected_FalseAfterAddNodeFailure(t *testing.T) {
	b := NewBatcher(50*time.Millisecond, 2, nil)
	b.Start()
	defer b.Close()
	id := types.NodeID(42)
	// Pre-create the node entry so AddNode reuses it, and set up a
	// multiChannelNodeConn with no mapper so MapResponseFromChange will fail.
	// markConnected() simulates a previous session leaving it connected.
	nc := newMultiChannelNodeConn(id, nil)
	nc.markConnected()
	b.nodes.Store(id, nc)
	ch := make(chan *tailcfg.MapResponse, 1)
	err := b.AddNode(id, ch, 100, func() {})
	require.Error(t, err, "AddNode should fail with nil mapper")
	// After failure, the node should NOT be reported as connected.
	assert.False(t, b.IsConnected(id),
		"IsConnected should return false after AddNode failure with no remaining connections")
}
// TestRemoveConnectionAtIndex_NilsTrailingSlot is a regression guard for M7.
// Before the fix, removeConnectionAtIndexLocked used append(s[:i], s[i+1:]...)
// which left a stale pointer in the backing array's last slot. The fix
// uses copy + explicit nil of the trailing element.
func TestRemoveConnectionAtIndex_NilsTrailingSlot(t *testing.T) {
mc := newMultiChannelNodeConn(1, nil)
// Manually add three entries under the lock.
entries := make([]*connectionEntry, 3)
for i := range entries {
entries[i] = &connectionEntry{id: fmt.Sprintf("conn-%d", i), c: make(chan<- *tailcfg.MapResponse)}
}
mc.mutex.Lock()
mc.connections = append(mc.connections, entries...)
// Remove the middle entry (index 1).
removed := mc.removeConnectionAtIndexLocked(1, false)
require.Equal(t, entries[1], removed)
// After removal, len should be 2 and the backing array slot at
// index 2 (the old len-1) should be nil.
require.Len(t, mc.connections, 2)
assert.Equal(t, entries[0], mc.connections[0])
assert.Equal(t, entries[2], mc.connections[1])
// Check the backing array directly: the slot just past the new
// length must be nil to avoid retaining the pointer.
backing := mc.connections[:3]
assert.Nil(t, backing[2],
"trailing slot in backing array should be nil after removal")
mc.mutex.Unlock()
}