mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-19 16:21:23 +01:00
- TestBatcher_CloseBeforeStart_DoesNotHang: verifies Close() before Start() returns promptly now that done is initialized in NewBatcher. - TestBatcher_QueueWorkAfterClose_DoesNotHang: verifies queueWork returns via the done channel after Close(), even without Start(). - TestIsConnected_FalseAfterAddNodeFailure: verifies IsConnected returns false after AddNode fails and removes the last connection. - TestRemoveConnectionAtIndex_NilsTrailingSlot: verifies the backing array slot is nil-ed after removal to avoid retaining pointers. Updates #2545
1187 lines
34 KiB
Go
1187 lines
34 KiB
Go
package mapper
|
|
|
|
// Unit tests for batcher components that do NOT require database setup.
|
|
// These tests exercise connectionEntry, multiChannelNodeConn, computePeerDiff,
|
|
// updateSentPeers, generateMapResponse branching, and handleNodeChange in isolation.
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"tailscale.com/tailcfg"
|
|
)
|
|
|
|
// ============================================================================
|
|
// Mock Infrastructure
|
|
// ============================================================================
|
|
|
|
// mockNodeConnection implements nodeConnection for isolated unit testing
|
|
// of generateMapResponse and handleNodeChange without a real database.
|
|
// mockNodeConnection implements nodeConnection for isolated unit testing
// of generateMapResponse and handleNodeChange without a real database.
type mockNodeConnection struct {
	// id is the node this mock connection represents.
	id types.NodeID
	// ver is the capability version reported by version().
	ver tailcfg.CapabilityVersion

	// sendFn allows injecting custom send behavior.
	// If nil, sends are recorded and succeed.
	sendFn func(*tailcfg.MapResponse) error

	// sent records all successful sends for assertion.
	sent []*tailcfg.MapResponse
	// mu guards sent against concurrent appends/reads.
	mu sync.Mutex

	// Peer tracking: set of peer IDs last sent to this connection.
	peers *xsync.Map[tailcfg.NodeID, struct{}]
}
|
|
|
|
func newMockNodeConnection(id types.NodeID) *mockNodeConnection {
|
|
return &mockNodeConnection{
|
|
id: id,
|
|
ver: tailcfg.CapabilityVersion(100),
|
|
peers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
|
}
|
|
}
|
|
|
|
// withSendError configures the mock to return the given error on send.
|
|
func (m *mockNodeConnection) withSendError(err error) *mockNodeConnection {
|
|
m.sendFn = func(_ *tailcfg.MapResponse) error { return err }
|
|
return m
|
|
}
|
|
|
|
// nodeID returns the node identifier this mock connection represents.
func (m *mockNodeConnection) nodeID() types.NodeID { return m.id }

// version returns the mock's configured capability version.
func (m *mockNodeConnection) version() tailcfg.CapabilityVersion { return m.ver }
|
|
|
func (m *mockNodeConnection) send(data *tailcfg.MapResponse) error {
|
|
if m.sendFn != nil {
|
|
return m.sendFn(data)
|
|
}
|
|
|
|
m.mu.Lock()
|
|
m.sent = append(m.sent, data)
|
|
m.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *mockNodeConnection) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
|
|
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
|
|
for _, id := range currentPeers {
|
|
currentSet[id] = struct{}{}
|
|
}
|
|
|
|
var removed []tailcfg.NodeID
|
|
|
|
m.peers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
|
|
if _, exists := currentSet[id]; !exists {
|
|
removed = append(removed, id)
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return removed
|
|
}
|
|
|
|
func (m *mockNodeConnection) updateSentPeers(resp *tailcfg.MapResponse) {
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
if resp.Peers != nil {
|
|
m.peers.Clear()
|
|
|
|
for _, peer := range resp.Peers {
|
|
m.peers.Store(peer.ID, struct{}{})
|
|
}
|
|
}
|
|
|
|
for _, peer := range resp.PeersChanged {
|
|
m.peers.Store(peer.ID, struct{}{})
|
|
}
|
|
|
|
for _, id := range resp.PeersRemoved {
|
|
m.peers.Delete(id)
|
|
}
|
|
}
|
|
|
|
// getSent returns a thread-safe copy of all sent responses.
|
|
func (m *mockNodeConnection) getSent() []*tailcfg.MapResponse {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
return append([]*tailcfg.MapResponse{}, m.sent...)
|
|
}
|
|
|
|
// ============================================================================
|
|
// Test Helpers
|
|
// ============================================================================
|
|
|
|
// testMapResponse creates a minimal valid MapResponse for testing.
|
|
func testMapResponse() *tailcfg.MapResponse {
|
|
now := time.Now()
|
|
|
|
return &tailcfg.MapResponse{
|
|
ControlTime: &now,
|
|
}
|
|
}
|
|
|
|
// testMapResponseWithPeers creates a MapResponse with the given peer IDs.
|
|
func testMapResponseWithPeers(peerIDs ...tailcfg.NodeID) *tailcfg.MapResponse {
|
|
resp := testMapResponse()
|
|
|
|
resp.Peers = make([]*tailcfg.Node, len(peerIDs))
|
|
for i, id := range peerIDs {
|
|
resp.Peers[i] = &tailcfg.Node{ID: id}
|
|
}
|
|
|
|
return resp
|
|
}
|
|
|
|
// ids is a convenience for creating a slice of tailcfg.NodeID.
// Note: it returns the variadic backing slice directly, without copying.
func ids(nodeIDs ...tailcfg.NodeID) []tailcfg.NodeID {
	return nodeIDs
}
|
|
|
|
// expectReceive asserts that a message arrives on the channel within 100ms.
|
|
func expectReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, msg string) *tailcfg.MapResponse {
|
|
t.Helper()
|
|
|
|
const timeout = 100 * time.Millisecond
|
|
|
|
select {
|
|
case data := <-ch:
|
|
return data
|
|
case <-time.After(timeout):
|
|
t.Fatalf("expected to receive on channel within %v: %s", timeout, msg)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// expectNoReceive asserts that no message arrives within timeout.
|
|
func expectNoReceive(t *testing.T, ch <-chan *tailcfg.MapResponse, timeout time.Duration, msg string) {
|
|
t.Helper()
|
|
|
|
select {
|
|
case data := <-ch:
|
|
t.Fatalf("expected no receive but got %+v: %s", data, msg)
|
|
case <-time.After(timeout):
|
|
// Expected
|
|
}
|
|
}
|
|
|
|
// makeConnectionEntry creates a connectionEntry with the given channel.
|
|
func makeConnectionEntry(id string, ch chan<- *tailcfg.MapResponse) *connectionEntry {
|
|
entry := &connectionEntry{
|
|
id: id,
|
|
c: ch,
|
|
version: tailcfg.CapabilityVersion(100),
|
|
created: time.Now(),
|
|
}
|
|
entry.lastUsed.Store(time.Now().Unix())
|
|
|
|
return entry
|
|
}
|
|
|
|
// ============================================================================
|
|
// connectionEntry.send() Tests
|
|
// ============================================================================
|
|
|
|
func TestConnectionEntry_SendSuccess(t *testing.T) {
|
|
ch := make(chan *tailcfg.MapResponse, 1)
|
|
entry := makeConnectionEntry("test-conn", ch)
|
|
data := testMapResponse()
|
|
|
|
beforeSend := time.Now().Unix()
|
|
err := entry.send(data)
|
|
|
|
require.NoError(t, err)
|
|
assert.GreaterOrEqual(t, entry.lastUsed.Load(), beforeSend,
|
|
"lastUsed should be updated after successful send")
|
|
|
|
// Verify data was actually sent
|
|
received := expectReceive(t, ch, "data should be on channel")
|
|
assert.Equal(t, data, received)
|
|
}
|
|
|
|
func TestConnectionEntry_SendNilData(t *testing.T) {
|
|
ch := make(chan *tailcfg.MapResponse, 1)
|
|
entry := makeConnectionEntry("test-conn", ch)
|
|
|
|
err := entry.send(nil)
|
|
|
|
require.NoError(t, err, "nil data should return nil error")
|
|
expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent to channel")
|
|
}
|
|
|
|
// TestConnectionEntry_SendTimeout verifies that a send to a blocked
// (unbuffered, readerless) channel returns ErrConnectionSendTimeout after
// roughly the internal 50ms send deadline rather than blocking forever.
func TestConnectionEntry_SendTimeout(t *testing.T) {
	// Unbuffered channel with no reader = always blocks
	ch := make(chan *tailcfg.MapResponse)
	entry := makeConnectionEntry("test-conn", ch)
	data := testMapResponse()

	start := time.Now()
	err := entry.send(data)
	elapsed := time.Since(start)

	require.ErrorIs(t, err, ErrConnectionSendTimeout)
	// 40ms lower bound leaves slack for timer granularity under load.
	assert.GreaterOrEqual(t, elapsed, 40*time.Millisecond,
		"should wait approximately 50ms before timeout")
}
|
|
|
|
func TestConnectionEntry_SendClosed(t *testing.T) {
|
|
ch := make(chan *tailcfg.MapResponse, 1)
|
|
entry := makeConnectionEntry("test-conn", ch)
|
|
|
|
// Mark as closed before sending
|
|
entry.closed.Store(true)
|
|
|
|
err := entry.send(testMapResponse())
|
|
|
|
require.ErrorIs(t, err, errConnectionClosed)
|
|
expectNoReceive(t, ch, 10*time.Millisecond,
|
|
"closed entry should not send data to channel")
|
|
}
|
|
|
|
// TestConnectionEntry_SendUpdatesLastUsed verifies that a successful send
// advances lastUsed past a stale (one-hour-old) timestamp.
func TestConnectionEntry_SendUpdatesLastUsed(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("test-conn", ch)

	// Set lastUsed to a past time
	pastTime := time.Now().Add(-1 * time.Hour).Unix()
	entry.lastUsed.Store(pastTime)

	err := entry.send(testMapResponse())
	require.NoError(t, err)

	assert.Greater(t, entry.lastUsed.Load(), pastTime,
		"lastUsed should be updated to current time after send")
}
|
|
|
|
// ============================================================================
|
|
// multiChannelNodeConn.send() Tests
|
|
// ============================================================================
|
|
|
|
// TestMultiChannelSend_AllSuccess verifies that a send fans out to every
// connection: all remain active and each channel receives the payload.
func TestMultiChannelSend_AllSuccess(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// Create 3 buffered channels (all will succeed)
	channels := make([]chan *tailcfg.MapResponse, 3)
	for i := range channels {
		channels[i] = make(chan *tailcfg.MapResponse, 1)
		mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i]))
	}

	data := testMapResponse()
	err := mc.send(data)

	require.NoError(t, err)
	assert.Equal(t, 3, mc.getActiveConnectionCount(),
		"all connections should remain active after success")

	// Verify all channels received the data
	for i, ch := range channels {
		received := expectReceive(t, ch,
			fmt.Sprintf("channel %d should receive data", i))
		assert.Equal(t, data, received)
	}
}
|
|
|
|
// TestMultiChannelSend_PartialFailure verifies that when one of several
// connections times out, send still succeeds, the failed connection is
// pruned, and the healthy connections receive the data.
func TestMultiChannelSend_PartialFailure(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// 2 buffered channels (will succeed) + 1 unbuffered (will timeout)
	goodCh1 := make(chan *tailcfg.MapResponse, 1)
	goodCh2 := make(chan *tailcfg.MapResponse, 1)
	badCh := make(chan *tailcfg.MapResponse) // unbuffered, no reader

	mc.addConnection(makeConnectionEntry("good-1", goodCh1))
	mc.addConnection(makeConnectionEntry("bad", badCh))
	mc.addConnection(makeConnectionEntry("good-2", goodCh2))

	err := mc.send(testMapResponse())

	require.NoError(t, err, "should succeed if at least one connection works")
	assert.Equal(t, 2, mc.getActiveConnectionCount(),
		"failed connection should be removed")

	// Good channels should have received data
	expectReceive(t, goodCh1, "good-1 should receive")
	expectReceive(t, goodCh2, "good-2 should receive")
}
|
|
|
|
// TestMultiChannelSend_AllFail verifies that when every connection times
// out, send reports an error and all failed connections are pruned.
func TestMultiChannelSend_AllFail(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// All unbuffered channels with no readers
	for i := range 3 {
		ch := make(chan *tailcfg.MapResponse) // unbuffered
		mc.addConnection(makeConnectionEntry(fmt.Sprintf("bad-%d", i), ch))
	}

	err := mc.send(testMapResponse())

	require.Error(t, err, "should return error when all connections fail")
	assert.Equal(t, 0, mc.getActiveConnectionCount(),
		"all failed connections should be removed")
}
|
|
|
|
func TestMultiChannelSend_ZeroConnections(t *testing.T) {
|
|
mc := newMultiChannelNodeConn(1, nil)
|
|
|
|
err := mc.send(testMapResponse())
|
|
|
|
require.NoError(t, err,
|
|
"sending to node with 0 connections should succeed silently (rapid reconnection scenario)")
|
|
}
|
|
|
|
func TestMultiChannelSend_NilData(t *testing.T) {
|
|
mc := newMultiChannelNodeConn(1, nil)
|
|
ch := make(chan *tailcfg.MapResponse, 1)
|
|
mc.addConnection(makeConnectionEntry("conn", ch))
|
|
|
|
err := mc.send(nil)
|
|
|
|
require.NoError(t, err, "nil data should return nil immediately")
|
|
expectNoReceive(t, ch, 10*time.Millisecond, "nil data should not be sent")
|
|
}
|
|
|
|
// TestMultiChannelSend_FailedConnectionRemoved verifies that a timed-out
// connection is removed on first send and that subsequent sends keep
// working with the remaining healthy connection.
func TestMultiChannelSend_FailedConnectionRemoved(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	goodCh := make(chan *tailcfg.MapResponse, 10) // large buffer
	badCh := make(chan *tailcfg.MapResponse)      // unbuffered, will timeout

	mc.addConnection(makeConnectionEntry("good", goodCh))
	mc.addConnection(makeConnectionEntry("bad", badCh))

	assert.Equal(t, 2, mc.getActiveConnectionCount())

	// First send: bad connection removed
	err := mc.send(testMapResponse())
	require.NoError(t, err)
	assert.Equal(t, 1, mc.getActiveConnectionCount())

	// Second send: only good connection remains, should succeed
	err = mc.send(testMapResponse())
	require.NoError(t, err)
	assert.Equal(t, 1, mc.getActiveConnectionCount())
}
|
|
|
|
// TestMultiChannelSend_UpdateCount verifies the updateCount counter
// increments by one for every send call.
func TestMultiChannelSend_UpdateCount(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	ch := make(chan *tailcfg.MapResponse, 10)
	mc.addConnection(makeConnectionEntry("conn", ch))

	assert.Equal(t, int64(0), mc.updateCount.Load())

	_ = mc.send(testMapResponse())
	assert.Equal(t, int64(1), mc.updateCount.Load())

	_ = mc.send(testMapResponse())
	assert.Equal(t, int64(2), mc.updateCount.Load())
}
|
|
|
|
// ============================================================================
|
|
// multiChannelNodeConn.close() Tests
|
|
// ============================================================================
|
|
|
|
// TestMultiChannelClose_MarksEntriesClosed verifies close() flips the
// closed flag on every registered connection entry.
func TestMultiChannelClose_MarksEntriesClosed(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	entries := make([]*connectionEntry, 3)
	for i := range entries {
		ch := make(chan *tailcfg.MapResponse, 1)
		entries[i] = makeConnectionEntry(fmt.Sprintf("conn-%d", i), ch)
		mc.addConnection(entries[i])
	}

	mc.close()

	for i, entry := range entries {
		assert.True(t, entry.closed.Load(),
			"entry %d should be marked as closed", i)
	}
}
|
|
|
|
// TestMultiChannelClose_PreventsSendPanic verifies that sending on an
// entry after close() yields errConnectionClosed instead of panicking
// with a send-on-closed-channel.
func TestMultiChannelClose_PreventsSendPanic(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("conn", ch)
	mc.addConnection(entry)

	mc.close()

	// After close, connectionEntry.send should return errConnectionClosed
	// (not panic on send to closed channel)
	err := entry.send(testMapResponse())
	require.ErrorIs(t, err, errConnectionClosed,
		"send after close should return errConnectionClosed, not panic")
}
|
|
|
|
// ============================================================================
|
|
// multiChannelNodeConn connection management Tests
|
|
// ============================================================================
|
|
|
|
// TestMultiChannelNodeConn_AddRemoveConnections exercises add/remove
// bookkeeping: counts track additions, removal is keyed by channel
// identity, removing an unknown channel is a no-op, and emptying the set
// flips hasActiveConnections to false.
func TestMultiChannelNodeConn_AddRemoveConnections(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	ch1 := make(chan *tailcfg.MapResponse, 1)
	ch2 := make(chan *tailcfg.MapResponse, 1)
	ch3 := make(chan *tailcfg.MapResponse, 1)

	// Add connections
	mc.addConnection(makeConnectionEntry("c1", ch1))
	assert.Equal(t, 1, mc.getActiveConnectionCount())
	assert.True(t, mc.hasActiveConnections())

	mc.addConnection(makeConnectionEntry("c2", ch2))
	mc.addConnection(makeConnectionEntry("c3", ch3))
	assert.Equal(t, 3, mc.getActiveConnectionCount())

	// Remove by channel pointer
	assert.True(t, mc.removeConnectionByChannel(ch2))
	assert.Equal(t, 2, mc.getActiveConnectionCount())

	// Remove non-existent channel
	nonExistentCh := make(chan *tailcfg.MapResponse)
	assert.False(t, mc.removeConnectionByChannel(nonExistentCh))
	assert.Equal(t, 2, mc.getActiveConnectionCount())

	// Remove remaining
	assert.True(t, mc.removeConnectionByChannel(ch1))
	assert.True(t, mc.removeConnectionByChannel(ch3))
	assert.Equal(t, 0, mc.getActiveConnectionCount())
	assert.False(t, mc.hasActiveConnections())
}
|
|
|
|
// TestMultiChannelNodeConn_Version verifies version() reports 0 with no
// connections and the entry's capability version once one is added.
func TestMultiChannelNodeConn_Version(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// No connections - version should be 0
	assert.Equal(t, tailcfg.CapabilityVersion(0), mc.version())

	// Add connection with version 100
	ch := make(chan *tailcfg.MapResponse, 1)
	entry := makeConnectionEntry("conn", ch)
	entry.version = tailcfg.CapabilityVersion(100)
	mc.addConnection(entry)

	assert.Equal(t, tailcfg.CapabilityVersion(100), mc.version())
}
|
|
|
|
// ============================================================================
|
|
// computePeerDiff Tests
|
|
// ============================================================================
|
|
|
|
// TestComputePeerDiff is a table-driven test of multiChannelNodeConn's
// removed-peer calculation: for each case the tracked (previously sent)
// set is seeded into lastSentPeers and the diff against the current set
// must yield exactly the expected removed IDs, in any order.
func TestComputePeerDiff(t *testing.T) {
	tests := []struct {
		name        string
		tracked     []tailcfg.NodeID // peers previously sent to client
		current     []tailcfg.NodeID // peers visible now
		wantRemoved []tailcfg.NodeID // expected removed peers
	}{
		{
			name:        "no_changes",
			tracked:     ids(1, 2, 3),
			current:     ids(1, 2, 3),
			wantRemoved: nil,
		},
		{
			name:        "one_removed",
			tracked:     ids(1, 2, 3),
			current:     ids(1, 3),
			wantRemoved: ids(2),
		},
		{
			name:        "multiple_removed",
			tracked:     ids(1, 2, 3, 4, 5),
			current:     ids(2, 4),
			wantRemoved: ids(1, 3, 5),
		},
		{
			name:        "all_removed",
			tracked:     ids(1, 2, 3),
			current:     nil,
			wantRemoved: ids(1, 2, 3),
		},
		{
			name:        "peers_added_no_removal",
			tracked:     ids(1),
			current:     ids(1, 2, 3),
			wantRemoved: nil,
		},
		{
			name:        "empty_tracked",
			tracked:     nil,
			current:     ids(1, 2, 3),
			wantRemoved: nil,
		},
		{
			name:        "both_empty",
			tracked:     nil,
			current:     nil,
			wantRemoved: nil,
		},
		{
			name:        "disjoint_sets",
			tracked:     ids(1, 2, 3),
			current:     ids(4, 5, 6),
			wantRemoved: ids(1, 2, 3),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mc := newMultiChannelNodeConn(1, nil)

			// Populate tracked peers
			for _, id := range tt.tracked {
				mc.lastSentPeers.Store(id, struct{}{})
			}

			got := mc.computePeerDiff(tt.current)

			// ElementsMatch: order of removed IDs is not guaranteed.
			assert.ElementsMatch(t, tt.wantRemoved, got,
				"removed peers should match expected")
		})
	}
}
|
|
|
|
// ============================================================================
|
|
// updateSentPeers Tests
|
|
// ============================================================================
|
|
|
|
// TestUpdateSentPeers covers the peer-tracking update rules of
// multiChannelNodeConn: a non-nil Peers list replaces the whole set
// (including the empty-but-non-nil case), PeersChanged adds, PeersRemoved
// deletes, a nil response is a no-op, and sequences of full + incremental
// updates compose correctly.
func TestUpdateSentPeers(t *testing.T) {
	t.Run("full_peer_list_replaces_all", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		// Pre-populate with old peers
		mc.lastSentPeers.Store(tailcfg.NodeID(100), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(200), struct{}{})

		// Send full peer list
		mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3))

		// Old peers should be gone
		_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(100))
		assert.False(t, exists, "old peer 100 should be cleared")

		// New peers should be tracked
		for _, id := range ids(1, 2, 3) {
			_, exists := mc.lastSentPeers.Load(id)
			assert.True(t, exists, "peer %d should be tracked", id)
		}
	})

	t.Run("incremental_add_via_PeersChanged", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})

		resp := testMapResponse()
		resp.PeersChanged = []*tailcfg.Node{{ID: 2}, {ID: 3}}
		mc.updateSentPeers(resp)

		// All three should be tracked
		for _, id := range ids(1, 2, 3) {
			_, exists := mc.lastSentPeers.Load(id)
			assert.True(t, exists, "peer %d should be tracked", id)
		}
	})

	t.Run("incremental_remove_via_PeersRemoved", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(3), struct{}{})

		resp := testMapResponse()
		resp.PeersRemoved = ids(2)
		mc.updateSentPeers(resp)

		_, exists1 := mc.lastSentPeers.Load(tailcfg.NodeID(1))
		_, exists2 := mc.lastSentPeers.Load(tailcfg.NodeID(2))
		_, exists3 := mc.lastSentPeers.Load(tailcfg.NodeID(3))

		assert.True(t, exists1, "peer 1 should remain")
		assert.False(t, exists2, "peer 2 should be removed")
		assert.True(t, exists3, "peer 3 should remain")
	})

	t.Run("nil_response_is_noop", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})

		mc.updateSentPeers(nil)

		_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(1))
		assert.True(t, exists, "nil response should not change tracked peers")
	})

	t.Run("full_then_incremental_sequence", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)

		// Step 1: Full peer list
		mc.updateSentPeers(testMapResponseWithPeers(1, 2, 3))

		// Step 2: Add peer 4
		resp := testMapResponse()
		resp.PeersChanged = []*tailcfg.Node{{ID: 4}}
		mc.updateSentPeers(resp)

		// Step 3: Remove peer 2
		resp2 := testMapResponse()
		resp2.PeersRemoved = ids(2)
		mc.updateSentPeers(resp2)

		// Should have 1, 3, 4
		for _, id := range ids(1, 3, 4) {
			_, exists := mc.lastSentPeers.Load(id)
			assert.True(t, exists, "peer %d should be tracked", id)
		}

		_, exists := mc.lastSentPeers.Load(tailcfg.NodeID(2))
		assert.False(t, exists, "peer 2 should have been removed")
	})

	t.Run("empty_full_peer_list_clears_all", func(t *testing.T) {
		mc := newMultiChannelNodeConn(1, nil)
		mc.lastSentPeers.Store(tailcfg.NodeID(1), struct{}{})
		mc.lastSentPeers.Store(tailcfg.NodeID(2), struct{}{})

		// Empty Peers slice (not nil) means "no peers"
		resp := testMapResponse()
		resp.Peers = []*tailcfg.Node{} // empty, not nil
		mc.updateSentPeers(resp)

		count := 0

		mc.lastSentPeers.Range(func(_ tailcfg.NodeID, _ struct{}) bool {
			count++
			return true
		})
		assert.Equal(t, 0, count, "empty peer list should clear all tracking")
	})
}
|
|
|
|
// ============================================================================
|
|
// generateMapResponse Tests (branching logic only, no DB needed)
|
|
// ============================================================================
|
|
|
|
func TestGenerateMapResponse_EmptyChange(t *testing.T) {
|
|
mc := newMockNodeConnection(1)
|
|
|
|
resp, err := generateMapResponse(mc, nil, change.Change{})
|
|
|
|
require.NoError(t, err)
|
|
assert.Nil(t, resp, "empty change should return nil response")
|
|
}
|
|
|
|
func TestGenerateMapResponse_InvalidNodeID(t *testing.T) {
|
|
mc := newMockNodeConnection(0) // Invalid ID
|
|
|
|
resp, err := generateMapResponse(mc, &mapper{}, change.DERPMap())
|
|
|
|
require.ErrorIs(t, err, ErrInvalidNodeID)
|
|
assert.Nil(t, resp)
|
|
}
|
|
|
|
func TestGenerateMapResponse_NilMapper(t *testing.T) {
|
|
mc := newMockNodeConnection(1)
|
|
|
|
resp, err := generateMapResponse(mc, nil, change.DERPMap())
|
|
|
|
require.ErrorIs(t, err, ErrMapperNil)
|
|
assert.Nil(t, resp)
|
|
}
|
|
|
|
// TestGenerateMapResponse_SelfOnlyOtherNode verifies a self-only change
// targeted at a different node is filtered out (nil response, no error).
func TestGenerateMapResponse_SelfOnlyOtherNode(t *testing.T) {
	mc := newMockNodeConnection(1)

	// SelfUpdate targeted at node 99 should be skipped for node 1
	ch := change.SelfUpdate(99)
	resp, err := generateMapResponse(mc, &mapper{}, ch)

	require.NoError(t, err)
	assert.Nil(t, resp,
		"self-only change targeted at different node should return nil")
}
|
|
|
|
// TestGenerateMapResponse_SelfOnlySameNode checks the routing predicates
// of a SelfUpdate change without invoking the mapper (which would need
// real state): the change is non-empty, self-only, and addressed to the
// target node but not to others.
func TestGenerateMapResponse_SelfOnlySameNode(t *testing.T) {
	// SelfUpdate targeted at node 1: IsSelfOnly()=true and TargetNode==nodeID
	// This should NOT be short-circuited - it should attempt to generate.
	// We verify the routing logic by checking that the change is not empty
	// and not filtered out (unlike SelfOnlyOtherNode above).
	ch := change.SelfUpdate(1)
	assert.False(t, ch.IsEmpty(), "SelfUpdate should not be empty")
	assert.True(t, ch.IsSelfOnly(), "SelfUpdate should be self-only")
	assert.True(t, ch.ShouldSendToNode(1), "should be sent to target node")
	assert.False(t, ch.ShouldSendToNode(2), "should NOT be sent to other nodes")
}
|
|
|
|
// ============================================================================
|
|
// handleNodeChange Tests
|
|
// ============================================================================
|
|
|
|
func TestHandleNodeChange_NilConnection(t *testing.T) {
|
|
err := handleNodeChange(nil, nil, change.DERPMap())
|
|
|
|
assert.ErrorIs(t, err, ErrNodeConnectionNil)
|
|
}
|
|
|
|
func TestHandleNodeChange_EmptyChange(t *testing.T) {
|
|
mc := newMockNodeConnection(1)
|
|
|
|
err := handleNodeChange(mc, nil, change.Change{})
|
|
|
|
require.NoError(t, err, "empty change should not send anything")
|
|
assert.Empty(t, mc.getSent(), "no data should be sent for empty change")
|
|
}
|
|
|
|
// errConnectionBroken is a sentinel injected into the mock's sendFn.
var errConnectionBroken = errors.New("connection broken")

// TestHandleNodeChange_SendError exercises the nil-data path of
// handleNodeChange with a send-failing mock; the real send-error path
// cannot be reached here because it needs a mapper with database state.
func TestHandleNodeChange_SendError(t *testing.T) {
	mc := newMockNodeConnection(1).withSendError(errConnectionBroken)

	// Need a real mapper for this test - we can't easily mock it.
	// Instead, test that when generateMapResponse returns nil data,
	// no send occurs. The send error path requires a valid MapResponse
	// which requires a mapper with state.
	// So we test the nil-data path here.
	err := handleNodeChange(mc, nil, change.Change{})
	assert.NoError(t, err, "empty change produces nil data, no send needed")
}
|
|
|
|
func TestHandleNodeChange_NilDataNoSend(t *testing.T) {
|
|
mc := newMockNodeConnection(1)
|
|
|
|
// SelfUpdate targeted at different node produces nil data
|
|
ch := change.SelfUpdate(99)
|
|
err := handleNodeChange(mc, &mapper{}, ch)
|
|
|
|
require.NoError(t, err, "nil data should not cause error")
|
|
assert.Empty(t, mc.getSent(), "nil data should not trigger send")
|
|
}
|
|
|
|
// ============================================================================
|
|
// connectionEntry concurrent safety Tests
|
|
// ============================================================================
|
|
|
|
// TestConnectionEntry_ConcurrentSends verifies that 50 parallel sends to
// a sufficiently buffered channel all succeed and all messages land on
// the channel (uses WaitGroup.Go, Go 1.25+).
func TestConnectionEntry_ConcurrentSends(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 100)
	entry := makeConnectionEntry("concurrent", ch)

	var (
		wg           sync.WaitGroup
		successCount atomic.Int64
	)

	// 50 goroutines sending concurrently

	for range 50 {
		wg.Go(func() {
			err := entry.send(testMapResponse())
			if err == nil {
				successCount.Add(1)
			}
		})
	}

	wg.Wait()

	assert.Equal(t, int64(50), successCount.Load(),
		"all sends to buffered channel should succeed")

	// Drain and count
	count := 0

	for range len(ch) {
		<-ch

		count++
	}

	assert.Equal(t, 50, count, "all 50 messages should be on channel")
}
|
|
|
|
// TestConnectionEntry_ConcurrentSendAndClose races 20 sender goroutines
// against a goroutine flipping the closed flag midway, asserting that no
// sender panics (closing must degrade to errors, not crashes).
func TestConnectionEntry_ConcurrentSendAndClose(t *testing.T) {
	ch := make(chan *tailcfg.MapResponse, 100)
	entry := makeConnectionEntry("race", ch)

	var (
		wg       sync.WaitGroup
		panicked atomic.Bool
	)

	// Goroutines sending rapidly

	for range 20 {
		wg.Go(func() {
			defer func() {
				// recover only works in a deferred func in the same goroutine.
				if r := recover(); r != nil {
					panicked.Store(true)
				}
			}()

			for range 10 {
				_ = entry.send(testMapResponse())
			}
		})
	}

	// Close midway through

	wg.Go(func() {
		time.Sleep(1 * time.Millisecond) //nolint:forbidigo // concurrency test coordination
		entry.closed.Store(true)
	})

	wg.Wait()

	assert.False(t, panicked.Load(),
		"concurrent send and close should not panic")
}
|
|
|
|
// ============================================================================
|
|
// multiChannelNodeConn concurrent Tests
|
|
// ============================================================================
|
|
|
|
// TestMultiChannelSend_ConcurrentAddAndSend races connection additions
// against sends on the same multiChannelNodeConn, asserting neither
// goroutine panics (the internal mutex must cover both paths).
func TestMultiChannelSend_ConcurrentAddAndSend(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	// Start with one connection
	ch1 := make(chan *tailcfg.MapResponse, 100)
	mc.addConnection(makeConnectionEntry("initial", ch1))

	var (
		wg       sync.WaitGroup
		panicked atomic.Bool
	)

	// Goroutine adding connections

	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()

		for i := range 10 {
			ch := make(chan *tailcfg.MapResponse, 100)
			mc.addConnection(makeConnectionEntry(fmt.Sprintf("added-%d", i), ch))
		}
	})

	// Goroutine sending data

	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()

		for range 20 {
			_ = mc.send(testMapResponse())
		}
	})

	wg.Wait()

	assert.False(t, panicked.Load(),
		"concurrent add and send should not panic (mutex protects both)")
}
|
|
|
|
// TestMultiChannelSend_ConcurrentRemoveAndSend races connection removals
// against sends on the same multiChannelNodeConn, asserting neither
// goroutine panics.
func TestMultiChannelSend_ConcurrentRemoveAndSend(t *testing.T) {
	mc := newMultiChannelNodeConn(1, nil)

	channels := make([]chan *tailcfg.MapResponse, 10)
	for i := range channels {
		channels[i] = make(chan *tailcfg.MapResponse, 100)
		mc.addConnection(makeConnectionEntry(fmt.Sprintf("conn-%d", i), channels[i]))
	}

	var (
		wg       sync.WaitGroup
		panicked atomic.Bool
	)

	// Goroutine removing connections

	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()

		for _, ch := range channels {
			mc.removeConnectionByChannel(ch)
		}
	})

	// Goroutine sending data concurrently

	wg.Go(func() {
		defer func() {
			if r := recover(); r != nil {
				panicked.Store(true)
			}
		}()

		for range 20 {
			_ = mc.send(testMapResponse())
		}
	})

	wg.Wait()

	assert.False(t, panicked.Load(),
		"concurrent remove and send should not panic")
}
|
|
|
|
// ============================================================================
|
|
// Regression tests for H1 (timer leak) and H3 (lifecycle)
|
|
// ============================================================================
|
|
|
|
// TestConnectionEntry_SendFastPath_TimerStopped is a regression guard for H1.
// Before the fix, connectionEntry.send used time.After(50ms) which leaked a
// timer into the runtime heap on every call even when the channel send
// succeeded immediately. The fix switched to time.NewTimer + defer Stop().
//
// This test sends many messages on a buffered (non-blocking) channel and
// checks that the number of live goroutines stays bounded, which would
// grow without bound under the old time.After approach at high call rates.
func TestConnectionEntry_SendFastPath_TimerStopped(t *testing.T) {
	const sends = 5000

	// Buffer sized to hold every message so no send ever blocks.
	ch := make(chan *tailcfg.MapResponse, sends)

	entry := &connectionEntry{
		id:      "timer-leak-test",
		c:       ch,
		version: 100,
		created: time.Now(),
	}

	resp := testMapResponse()

	for range sends {
		err := entry.send(resp)
		require.NoError(t, err)
	}

	// Drain the channel so we aren't holding references.
	for range sends {
		<-ch
	}

	// Force a GC + timer cleanup pass.
	runtime.GC()

	// If timers were leaking we'd see a goroutine count much higher
	// than baseline. With 5000 leaked timers the count would be
	// noticeably elevated. We just check it's reasonable.
	numGR := runtime.NumGoroutine()
	assert.Less(t, numGR, 200,
		"goroutine count after %d fast-path sends should be bounded; got %d (possible timer leak)", sends, numGR)
}
|
|
|
|
// TestBatcher_CloseWaitsForWorkers is a regression guard for H3.
// Before the fix, Close() would tear down node connections while workers
// were potentially still running, risking sends on closed channels.
// The fix added sync.WaitGroup tracking so Close() blocks until all
// worker goroutines exit.
func TestBatcher_CloseWaitsForWorkers(t *testing.T) {
	b := NewBatcher(50*time.Millisecond, 4, nil)

	goroutinesBefore := runtime.NumGoroutine()

	b.Start()

	// Give workers time to start.
	time.Sleep(20 * time.Millisecond) //nolint:forbidigo // test timing

	goroutinesDuring := runtime.NumGoroutine()

	// We expect at least 5 new goroutines: 1 doWork + 4 workers.
	assert.GreaterOrEqual(t, goroutinesDuring-goroutinesBefore, 5,
		"expected doWork + 4 workers to be running")

	// Close should block until all workers have exited.
	b.Close()

	// After Close returns, goroutines should have dropped back.
	// Allow a small margin for runtime goroutines.
	goroutinesAfter := runtime.NumGoroutine()
	assert.InDelta(t, goroutinesBefore, goroutinesAfter, 3,
		"goroutines should return to baseline after Close(); before=%d after=%d",
		goroutinesBefore, goroutinesAfter)
}
|
|
|
|
// TestBatcher_CloseThenStartIsNoop verifies the lifecycle contract:
|
|
// once a Batcher has been started, calling Start() again is a no-op
|
|
// (the started flag prevents double-start).
|
|
func TestBatcher_CloseThenStartIsNoop(t *testing.T) {
|
|
b := NewBatcher(50*time.Millisecond, 2, nil)
|
|
|
|
b.Start()
|
|
b.Close()
|
|
|
|
goroutinesBefore := runtime.NumGoroutine()
|
|
|
|
// Second Start should be a no-op because started is already true.
|
|
b.Start()
|
|
|
|
// Allow a moment for any hypothetical goroutine to appear.
|
|
time.Sleep(10 * time.Millisecond) //nolint:forbidigo // test timing
|
|
|
|
goroutinesAfter := runtime.NumGoroutine()
|
|
|
|
assert.InDelta(t, goroutinesBefore, goroutinesAfter, 1,
|
|
"Start() after Close() should not spawn new goroutines; before=%d after=%d",
|
|
goroutinesBefore, goroutinesAfter)
|
|
}
|
|
|
|
// TestBatcher_CloseStopsTicker verifies that Close() stops the internal
|
|
// ticker, preventing resource leaks.
|
|
func TestBatcher_CloseStopsTicker(t *testing.T) {
|
|
b := NewBatcher(10*time.Millisecond, 1, nil)
|
|
|
|
b.Start()
|
|
b.Close()
|
|
|
|
// After Close, the ticker should be stopped. Reading from a stopped
|
|
// ticker's channel should not deliver any values.
|
|
select {
|
|
case <-b.tick.C:
|
|
t.Fatal("ticker fired after Close(); ticker.Stop() was not called")
|
|
case <-time.After(50 * time.Millisecond): //nolint:forbidigo // test timing
|
|
// Expected: no tick received.
|
|
}
|
|
}
|
|
|
|
// ============================================================================
// Regression tests for M1, M3, M7
// ============================================================================
// TestBatcher_CloseBeforeStart_DoesNotHang is a regression guard for M1.
|
|
// Before the fix, done was nil until Start() was called. queueWork and
|
|
// MapResponseFromChange select on done, so a nil channel would block
|
|
// forever when workCh was full. With done initialized in NewBatcher,
|
|
// Close() can be called safely before Start().
|
|
func TestBatcher_CloseBeforeStart_DoesNotHang(t *testing.T) {
|
|
b := NewBatcher(50*time.Millisecond, 2, nil)
|
|
|
|
// Close without Start must not panic or hang.
|
|
done := make(chan struct{})
|
|
|
|
go func() {
|
|
b.Close()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// Success: Close returned promptly.
|
|
case <-time.After(2 * time.Second): //nolint:forbidigo // test timing
|
|
t.Fatal("Close() before Start() hung; done channel was likely nil")
|
|
}
|
|
}
|
|
|
|
// TestBatcher_QueueWorkAfterClose_DoesNotHang verifies that queueWork
|
|
// returns immediately via the done channel when the batcher is closed,
|
|
// even without Start() having been called.
|
|
func TestBatcher_QueueWorkAfterClose_DoesNotHang(t *testing.T) {
|
|
b := NewBatcher(50*time.Millisecond, 1, nil)
|
|
b.Close()
|
|
|
|
done := make(chan struct{})
|
|
|
|
go func() {
|
|
// queueWork selects on done; with done closed this must return.
|
|
b.queueWork(work{})
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// Success
|
|
case <-time.After(2 * time.Second): //nolint:forbidigo // test timing
|
|
t.Fatal("queueWork hung after Close(); done channel select not working")
|
|
}
|
|
}
|
|
|
|
// TestIsConnected_FalseAfterAddNodeFailure is a regression guard for M3.
|
|
// Before the fix, AddNode error paths removed the connection but left
|
|
// b.connected with its previous value (nil = connected). IsConnected
|
|
// would return true for a node with zero active connections.
|
|
func TestIsConnected_FalseAfterAddNodeFailure(t *testing.T) {
|
|
b := NewBatcher(50*time.Millisecond, 2, nil)
|
|
b.Start()
|
|
|
|
defer b.Close()
|
|
|
|
id := types.NodeID(42)
|
|
|
|
// Simulate a previous session leaving the node marked as connected.
|
|
b.connected.Store(id, nil) // nil = connected
|
|
|
|
// Pre-create the node entry so AddNode reuses it, and set up a
|
|
// multiChannelNodeConn with no mapper so MapResponseFromChange will fail.
|
|
nc := newMultiChannelNodeConn(id, nil)
|
|
b.nodes.Store(id, nc)
|
|
|
|
ch := make(chan *tailcfg.MapResponse, 1)
|
|
|
|
err := b.AddNode(id, ch, 100, func() {})
|
|
require.Error(t, err, "AddNode should fail with nil mapper")
|
|
|
|
// After failure, the node should NOT be reported as connected.
|
|
assert.False(t, b.IsConnected(id),
|
|
"IsConnected should return false after AddNode failure with no remaining connections")
|
|
}
|
|
|
|
// TestRemoveConnectionAtIndex_NilsTrailingSlot is a regression guard for M7.
|
|
// Before the fix, removeConnectionAtIndexLocked used append(s[:i], s[i+1:]...)
|
|
// which left a stale pointer in the backing array's last slot. The fix
|
|
// uses copy + explicit nil of the trailing element.
|
|
func TestRemoveConnectionAtIndex_NilsTrailingSlot(t *testing.T) {
|
|
mc := newMultiChannelNodeConn(1, nil)
|
|
|
|
// Manually add three entries under the lock.
|
|
entries := make([]*connectionEntry, 3)
|
|
for i := range entries {
|
|
entries[i] = &connectionEntry{id: fmt.Sprintf("conn-%d", i), c: make(chan<- *tailcfg.MapResponse)}
|
|
}
|
|
|
|
mc.mutex.Lock()
|
|
mc.connections = append(mc.connections, entries...)
|
|
|
|
// Remove the middle entry (index 1).
|
|
removed := mc.removeConnectionAtIndexLocked(1, false)
|
|
require.Equal(t, entries[1], removed)
|
|
|
|
// After removal, len should be 2 and the backing array slot at
|
|
// index 2 (the old len-1) should be nil.
|
|
require.Len(t, mc.connections, 2)
|
|
assert.Equal(t, entries[0], mc.connections[0])
|
|
assert.Equal(t, entries[2], mc.connections[1])
|
|
|
|
// Check the backing array directly: the slot just past the new
|
|
// length must be nil to avoid retaining the pointer.
|
|
backing := mc.connections[:3]
|
|
assert.Nil(t, backing[2],
|
|
"trailing slot in backing array should be nil after removal")
|
|
|
|
mc.mutex.Unlock()
|
|
}
|