all: fix golangci-lint issues (#3064)

Kristoffer Dalby
2026-02-06 21:45:32 +01:00
committed by GitHub
parent bfb6fd80df
commit ce580f8245
131 changed files with 3131 additions and 1560 deletions

View File

@@ -16,6 +16,14 @@ import (
"tailscale.com/tailcfg"
)
// Mapper errors.
var (
ErrInvalidNodeID = errors.New("invalid nodeID")
ErrMapperNil = errors.New("mapper is nil")
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
ErrNodeNotFoundMapper = errors.New("node not found")
)
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
@@ -81,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
}
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
}
if mapper == nil {
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
}
// Handle self-only responses
@@ -136,7 +144,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
if nc == nil {
return errors.New("nodeConnection is nil")
return ErrNodeConnectionNil
}
nodeID := nc.nodeID()
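This first file shows the core err113 remediation used throughout the commit: one-off errors.New/fmt.Errorf strings become package-level sentinels, wrapped with %w so call sites can add context without breaking identity checks. A minimal, runnable sketch of the pattern (the hypothetical validate caller is not from this commit):

package main

import (
	"errors"
	"fmt"
)

// Package-level sentinel: a fixed error value callers can compare against.
var ErrInvalidNodeID = errors.New("invalid nodeID")

// validate adds call-site context with %w while keeping the sentinel
// reachable through the error chain.
func validate(nodeID uint64) error {
	if nodeID == 0 {
		return fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
	}

	return nil
}

func main() {
	err := validate(0)
	fmt.Println(errors.Is(err, ErrInvalidNodeID)) // true, despite the wrapping
}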

View File

@@ -2,6 +2,7 @@ package mapper
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sync"
@@ -18,7 +19,13 @@ import (
"tailscale.com/types/ptr"
)
var errConnectionClosed = errors.New("connection channel already closed")
// LockFreeBatcher errors.
var (
errConnectionClosed = errors.New("connection channel already closed")
ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
ErrBatcherShuttingDown = errors.New("batcher shutting down")
ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
)
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
@@ -81,6 +88,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
if err != nil {
nlog.Error().Err(err).Msg("initial map generation failed")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("generating initial map for node %d: %w", id, err)
}
@@ -90,11 +98,12 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
case c <- initialMap:
// Success
case <-time.After(5 * time.Second): //nolint:mnd
nlog.Error().Err(errors.New("timeout")).Msg("initial map send timeout") //nolint:err113
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
Msg("initial map send timed out because channel was blocked or receiver not ready")
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
Msg("initial map send timed out because channel was blocked or receiver not ready")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("sending initial map to node %d: timeout", id)
return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
}
// Update connection status
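The guarded send above is the usual non-blocking channel idiom: select races the send against time.After, so a client that stopped reading cannot wedge AddNode forever, and the timeout now surfaces as the ErrInitialMapSendTimeout sentinel instead of an inline error. A stripped-down sketch of the same shape (hypothetical names, shorter timeout for illustration):

package main

import (
	"errors"
	"fmt"
	"time"
)

var ErrSendTimeout = errors.New("send timeout")

// trySend attempts a channel send but gives up after the timeout,
// so a receiver that stopped reading cannot block the sender forever.
func trySend(ch chan<- string, msg string, timeout time.Duration) error {
	select {
	case ch <- msg:
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("%w after %v", ErrSendTimeout, timeout)
	}
}

func main() {
	ch := make(chan string) // unbuffered, and nobody is reading
	err := trySend(ch, "initial map", 50*time.Millisecond)
	fmt.Println(err) // send timeout after 50ms
}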
@@ -135,6 +144,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
nlog.Debug().Caller().
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("node connection removed but keeping online, other connections remain")
return true // Node still has active connections
}
@@ -219,10 +229,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
// This is used for synchronous map generation.
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
@@ -235,7 +247,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
nc.updateSentPeers(result.mapResponse)
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
b.workErrors.Add(1)
wlog.Error().Err(result.err).
@@ -402,6 +414,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
}
}
}
return true
})
@@ -454,6 +467,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
if nodeConn.hasActiveConnections() {
ret.Store(id, true)
}
return true
})
@@ -469,6 +483,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret.Store(id, false)
}
}
return true
})
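These hunks add spacing before the return true that closes each Range callback; the boolean itself is the xsync.Map Range contract: true continues iteration, false stops it early. A small sketch of that contract (assuming github.com/puzpuzpuz/xsync v3's MapOf; newer major versions rename the type to Map, as used in this file):

package main

import (
	"fmt"

	"github.com/puzpuzpuz/xsync/v3"
)

func main() {
	m := xsync.NewMapOf[int, string]()
	m.Store(1, "alpha")
	m.Store(2, "beta")

	// Range keeps iterating while the callback returns true;
	// returning false stops the walk early.
	m.Range(func(k int, v string) bool {
		fmt.Println(k, v)

		return true
	})
}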
@@ -488,7 +503,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
case result := <-resultCh:
return result.mapResponse, result.err
case <-b.done:
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
}
}
@@ -523,8 +538,9 @@ type multiChannelNodeConn struct {
// generateConnectionID generates a unique connection identifier.
func generateConnectionID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return fmt.Sprintf("%x", bytes)
_, _ = rand.Read(bytes)
return hex.EncodeToString(bytes)
}
// newMultiChannelNodeConn creates a new multi-channel node connection.
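Two small fixes land in generateConnectionID: the rand.Read error is now explicitly discarded to satisfy errcheck (on current Go, crypto/rand.Read is documented never to fail), and hex.EncodeToString replaces fmt.Sprintf("%x", ...), which the perfsprint linter flags as a slower, reflection-based route to the same string. The outputs are identical:

package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
)

func main() {
	buf := make([]byte, 8)
	_, _ = rand.Read(buf) // error deliberately discarded; Read always fills buf on modern Go

	// Both render the same lowercase hex; EncodeToString skips fmt's
	// reflection-driven formatting entirely.
	fmt.Println(hex.EncodeToString(buf) == fmt.Sprintf("%x", buf)) // true
}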
@@ -557,7 +573,9 @@ func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
mc.mutex.Lock()
mutexWaitDur := time.Since(mutexWaitStart)
defer mc.mutex.Unlock()
mc.connections = append(mc.connections, entry)
@@ -579,9 +597,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
Int("remaining_connections", len(mc.connections)).
Msg("successfully removed connection")
return true
}
}
return false
}
@@ -615,6 +635,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
// This is not an error - the node will receive a full map when it reconnects
mc.log.Debug().Caller().
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
return nil // Return success instead of error
}
@@ -623,7 +644,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Msg("send: broadcasting to all connections")
var lastErr error
successCount := 0
var failedConnections []int // Track failed connections for removal
// Send to all connections
@@ -632,8 +655,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
Msg("send: attempting to send to connection")
if err := conn.send(data); err != nil {
err := conn.send(data)
if err != nil {
lastErr = err
failedConnections = append(failedConnections, i)
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
@@ -695,7 +720,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
case <-time.After(50 * time.Millisecond):
// Connection is likely stale - client isn't reading from channel
// This catches the case where Docker containers are killed but channels remain open
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
}
}
@@ -805,6 +830,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
Connected: connected,
ActiveConnections: activeConnCount,
}
return true
})
@@ -819,6 +845,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
ActiveConnections: 0,
}
}
return true
})

View File

@@ -35,6 +35,7 @@ type batcherTestCase struct {
// that would normally be sent by poll.go in production.
type testBatcherWrapper struct {
Batcher
state *state.State
}
@@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
}
// Finally remove from the real batcher
removed := t.Batcher.RemoveNode(id, c)
if !removed {
return false
}
return true
return t.Batcher.RemoveNode(id, c)
}
// wrapBatcherForTest wraps a batcher with test-specific behavior.
@@ -129,8 +125,6 @@ const (
SMALL_BUFFER_SIZE = 3
TINY_BUFFER_SIZE = 1 // For maximum contention
LARGE_BUFFER_SIZE = 200
reservedResponseHeaderSize = 4
)
// TestData contains all test entities created for a test scenario.
@@ -241,8 +235,8 @@ func setupBatcherWithTestData(
}
derpMap, err := derp.GetDERPMap(cfg.DERP)
assert.NoError(t, err)
assert.NotNil(t, derpMap)
require.NoError(t, err)
require.NotNil(t, derpMap)
state.SetDERPMap(derpMap)
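Switching these checks from assert to require matters because the DERP map is used immediately afterwards: assert records a failure and keeps executing (so a nil derpMap would panic on the next line), while require aborts the test at the failing check. A small illustration with testify's two packages (hypothetical loadConfig; running it fails at the require line by design):

package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func loadConfig() (*string, error) { return nil, nil } // stand-in failure mode

func TestAssertVsRequire(t *testing.T) {
	cfg, err := loadConfig()

	// assert marks the test failed but continues, so a nil cfg
	// dereferenced below would panic instead of failing cleanly:
	assert.NotNil(t, cfg)

	// require stops the test immediately on failure, making any
	// dereference afterwards safe.
	require.NotNil(t, cfg)

	_ = *cfg // only reached when cfg is non-nil
}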
@@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
}
// getStats returns a copy of the statistics for a node.
//
//nolint:unused
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
ut.mu.RLock()
defer ut.mu.RUnlock()
@@ -386,16 +382,14 @@ type UpdateInfo struct {
}
// parseUpdateAndAnalyze parses an update and returns detailed information.
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
info := UpdateInfo{
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
return UpdateInfo{
PeerCount: len(resp.Peers),
PatchCount: len(resp.PeersChangedPatch),
IsFull: len(resp.Peers) > 0,
IsPatch: len(resp.PeersChangedPatch) > 0,
IsDERP: resp.DERPMap != nil,
}
return info, nil
}
// start begins consuming updates from the node's channel and tracking stats.
@@ -417,7 +411,8 @@ func (n *node) start() {
atomic.AddInt64(&n.updateCount, 1)
// Parse update and track detailed stats
if info, err := parseUpdateAndAnalyze(data); err == nil {
info := parseUpdateAndAnalyze(data)
{
// Track update types
if info.IsFull {
atomic.AddInt64(&n.fullCount, 1)
@@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
testNode.start()
// Connect the node to the batcher
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
// Wait for connection to be established
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
for i := range allNodes {
node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullUpdate())
@@ -676,6 +671,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
connectedCount := 0
for i := range allNodes {
node := &allNodes[i]
@@ -693,6 +689,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
t.Logf("✅ All nodes achieved full connectivity!")
totalTime := time.Since(startTime)
// Disconnect all nodes
@@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) {
defer cleanup()
batcher := testData.Batcher
tn := testData.Nodes[0]
tn2 := testData.Nodes[1]
tn := &testData.Nodes[0]
tn2 := &testData.Nodes[1]
// Test AddNode with real node ID
batcher.AddNode(tn.n.ID, tn.ch, 100)
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
if !batcher.IsConnected(tn.n.ID) {
t.Error("Node should be connected after AddNode")
@@ -842,10 +839,10 @@ func TestBatcherBasicOperations(t *testing.T) {
}
// Drain any initial messages from first node
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
drainChannelTimeout(tn.ch, 100*time.Millisecond)
// Add the second node and verify update message
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
assert.True(t, batcher.IsConnected(tn2.n.ID))
// First node should get an update that second node has connected.
@@ -911,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) {
}
}
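The tn := &testData.Nodes[0] rewrites here, and the many for i := range ...; node := &nodes[i] loops later in this file, share one motivation: ranging by value copies each element, so state changes land on the copy and large structs are duplicated every iteration, while indexing hands out a pointer into the slice. A minimal demonstration:

package main

import "fmt"

type node struct{ hits int }

func main() {
	nodes := make([]node, 2)

	// range copies each element: mutating the copy is lost.
	for _, n := range nodes {
		n.hits++
	}
	fmt.Println(nodes[0].hits) // 0

	// indexing yields a pointer into the slice: mutations persist.
	for i := range nodes {
		n := &nodes[i]
		n.hits++
	}
	fmt.Println(nodes[0].hits) // 1
}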
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
count := 0
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) {
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case data := <-ch:
count++
// Optional: add debug output if needed
_ = data
case <-ch:
// Drain message
case <-timer.C:
return
}
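The trimmed drainChannelTimeout keeps the standard bounded-drain shape: a single time.Timer caps the whole loop, and the select discards messages until the deadline fires. A self-contained generic sketch of the same pattern:

package main

import (
	"fmt"
	"time"
)

// drain discards everything on ch until timeout elapses, then returns.
// The timer bounds the whole loop, not each individual receive.
func drain[T any](ch <-chan T, timeout time.Duration) {
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	for {
		select {
		case <-ch:
			// discard
		case <-timer.C:
			return
		}
	}
}

func main() {
	ch := make(chan int, 3)
	ch <- 1
	ch <- 2

	drain(ch, 50*time.Millisecond)
	fmt.Println("drained, remaining:", len(ch)) // 0
}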
@@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
testNodes := testData.Nodes
ch := make(chan *tailcfg.MapResponse, 10)
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
// Track update content for validation
var receivedUpdates []*tailcfg.MapResponse
@@ -1131,6 +1124,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// even when real node updates are being processed, ensuring no race conditions
// occur during channel replacement with actual workload.
func XTestBatcherChannelClosingRace(t *testing.T) {
t.Helper()
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
// Create test environment with real database and nodes
@@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
defer cleanup()
batcher := testData.Batcher
testNode := testData.Nodes[0]
testNode := &testData.Nodes[0]
var (
channelIssues int
@@ -1154,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
ch1 := make(chan *tailcfg.MapResponse, 1)
wg.Go(func() {
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
})
// Add real work during connection chaos
@@ -1167,7 +1162,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
wg.Go(func() {
runtime.Gosched() // Yield to introduce timing variability
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
})
// Remove second connection
@@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
defer cleanup()
batcher := testData.Batcher
testNode := testData.Nodes[0]
testNode := &testData.Nodes[0]
var (
panics int
@@ -1258,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, 5)
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPMap())
// Consumer goroutine to validate data and detect channel issues
@@ -1308,6 +1304,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
for range i % 3 {
runtime.Gosched() // Introduce timing variability
}
batcher.RemoveNode(testNode.n.ID, ch)
// Yield to allow workers to process and close channels
@@ -1350,6 +1347,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// real node data. The test validates that stable clients continue to function
// normally and receive proper updates despite the connection churn from other clients,
// ensuring system stability under concurrent load.
//
//nolint:gocyclo // complex concurrent test scenario
func TestBatcherConcurrentClients(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrent client test in short mode")
@@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
for _, node := range stableNodes {
for i := range stableNodes {
node := &stableNodes[i]
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
stableChannels[node.n.ID] = ch
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
// Monitor updates for each stable client
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@@ -1391,6 +1391,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Channel was closed, exit gracefully
return
}
if valid, reason := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@@ -1427,7 +1428,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Connection churn cycles - rapidly connect/disconnect to test concurrency safety
for i := range numCycles {
for _, node := range churningNodes {
for j := range churningNodes {
node := &churningNodes[j]
wg.Add(2)
// Connect churning node
@@ -1448,10 +1451,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
// Consume updates to prevent blocking
go func() {
@@ -1462,6 +1467,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Channel was closed, exit gracefully
return
}
if valid, _ := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@@ -1494,6 +1500,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for range i % 5 {
runtime.Gosched() // Introduce timing variability
}
churningChannelsMutex.Lock()
ch, exists := churningChannels[nodeID]
@@ -1519,7 +1526,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
if i%7 == 0 && len(allNodes) > 0 {
// Node-specific changes using real nodes
node := allNodes[i%len(allNodes)]
node := &allNodes[i%len(allNodes)]
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
@@ -1567,7 +1574,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
for _, node := range stableNodes {
for i := range stableNodes {
node := &stableNodes[i]
if stats, exists := allStats[node.n.ID]; exists {
stableUpdateCount += stats.TotalUpdates
t.Logf("Stable node %d: %d updates",
@@ -1580,7 +1588,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
}
}
for _, node := range churningNodes {
for i := range churningNodes {
node := &churningNodes[i]
if stats, exists := allStats[node.n.ID]; exists {
churningUpdateCount += stats.TotalUpdates
}
@@ -1605,7 +1614,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
}
// Verify all stable clients are still functional
for _, node := range stableNodes {
for i := range stableNodes {
node := &stableNodes[i]
if !batcher.IsConnected(node.n.ID) {
t.Errorf("Stable node %d lost connection during racing", node.n.ID)
}
@@ -1623,6 +1633,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
// It validates that the system remains stable with no deadlocks, panics, or
// missed updates under sustained high load. The test uses real node data to
// generate authentic update scenarios and tracks comprehensive statistics.
//
//nolint:gocyclo,thelper // complex scalability test scenario
func XTestBatcherScalability(t *testing.T) {
if testing.Short() {
t.Skip("Skipping scalability test in short mode")
@@ -1651,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
description string
}
var testCases []testCase
testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes))
// Generate all combinations of the test matrix
for _, nodeCount := range nodes {
@@ -1762,7 +1774,8 @@ func XTestBatcherScalability(t *testing.T) {
for i := range testNodes {
node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true
@@ -1824,7 +1837,8 @@ func XTestBatcherScalability(t *testing.T) {
}
// Connection/disconnection cycles for subset of nodes
for i, node := range chaosNodes {
for i := range chaosNodes {
node := &chaosNodes[i]
// Only add work if this is connection chaos or mixed
if tc.chaosType == "connection" || tc.chaosType == "mixed" {
wg.Add(2)
@@ -1878,6 +1892,7 @@ func XTestBatcherScalability(t *testing.T) {
channel,
tailcfg.CapabilityVersion(100),
)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = true
@@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
t.Logf("Created %d nodes in database", len(allNodes))
// Connect nodes one at a time and wait for each to be connected
for i, node := range allNodes {
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
for i := range allNodes {
node := &allNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Wait for node to be connected
@@ -2157,7 +2173,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
// Check how many peers each node should see
for i, node := range allNodes {
for i := range allNodes {
node := &allNodes[i]
peers := testData.State.ListPeers(node.n.ID)
t.Logf("Node %d should see %d peers from state", i, peers.Len())
}
@@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
for i, node := range allNodes {
for i := range allNodes {
node := &allNodes[i]
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node %d: %v", i, err)
@@ -2302,16 +2322,21 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
for i, node := range allNodes {
for i := range allNodes {
node := &allNodes[i]
removed := batcher.RemoveNode(node.n.ID, node.ch)
t.Logf("Node %d RemoveNode result: %t", i, removed)
}
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i, node := range allNodes {
for i := range allNodes {
node := &allNodes[i]
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
@@ -2334,7 +2359,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
debugInfo := debugBatcher.Debug()
disconnectedCount := 0
for i, node := range allNodes {
for i := range allNodes {
node := &allNodes[i]
if info, exists := debugInfo[node.n.ID]; exists {
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
@@ -2342,11 +2368,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
if infoMap, ok := info.(map[string]any); ok {
if connected, ok := infoMap["connected"].(bool); ok && !connected {
disconnectedCount++
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
}
}
} else {
disconnectedCount++
t.Logf("Node %d missing from debug info entirely", i)
}
@@ -2381,6 +2409,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
case update := <-newChannels[i]:
if update != nil {
receivedCount++
t.Logf("Node %d received update successfully", i)
}
case <-timeout:
@@ -2399,6 +2428,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
}
}
//nolint:gocyclo // complex multi-connection test scenario
func TestBatcherMultiConnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
@@ -2406,13 +2436,14 @@ func TestBatcherMultiConnection(t *testing.T) {
defer cleanup()
batcher := testData.Batcher
node1 := testData.Nodes[0]
node2 := testData.Nodes[1]
node1 := &testData.Nodes[0]
node2 := &testData.Nodes[1]
t.Logf("=== MULTI-CONNECTION TEST ===")
// Phase 1: Connect first node with initial connection
t.Logf("Phase 1: Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node1: %v", err)
@@ -2432,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add second connection for node1: %v", err)
@@ -2443,7 +2476,9 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add third connection for node1: %v", err)
@@ -2454,6 +2489,7 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
@@ -2461,6 +2497,7 @@ func TestBatcherMultiConnection(t *testing.T) {
if info, exists := debugInfo[node1.n.ID]; exists {
t.Logf("Node1 debug info: %+v", info)
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 3 {
@@ -2469,6 +2506,7 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
}
}
if connected, ok := infoMap["connected"].(bool); ok && !connected {
t.Errorf("Node1 should show as connected with 3 active connections")
}

View File

@@ -1,7 +1,6 @@
package mapper
import (
"errors"
"net/netip"
"sort"
"time"
@@ -36,6 +35,7 @@ const (
// NewMapResponseBuilder creates a new builder with basic fields set.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now()
return &MapResponseBuilder{
resp: &tailcfg.MapResponse{
KeepAlive: false,
@@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFoundMapper)
return b
}
@@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
}
return b
}
@@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFoundMapper)
return b
}
@@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFoundMapper)
return b
}
@@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFoundMapper)
return b
}
@@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFoundMapper)
return b
}
@@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
return nil, errors.New("node not found")
return nil, ErrNodeNotFoundMapper
}
// Get unreduced matchers for peer relationship determination.
@@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
// WithPeersRemoved adds removed peer IDs.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
var tailscaleIDs []tailcfg.NodeID
tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs))
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}
b.resp.PeersRemoved = tailscaleIDs
return b
}
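tailscaleIDs is now created with make(..., 0, len(removedIDs)), the prealloc-style fix: when the final length is known up front, sizing the capacity means append never has to grow and copy the backing array. The difference in allocation behavior, sketched:

package main

import "fmt"

func main() {
	ids := []int{1, 2, 3, 4}

	// nil slice: append allocates and reallocates as it grows.
	var grown []int
	for _, id := range ids {
		grown = append(grown, id*10)
	}

	// preallocated: one allocation, appends just write into place.
	sized := make([]int, 0, len(ids))
	for _, id := range ids {
		sized = append(sized, id*10)
	}

	fmt.Println(grown, cap(sized) == len(ids)) // [10 20 30 40] true
}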
// Build finalizes the response and returns marshaled bytes
// Build finalizes the response and returns marshaled bytes.
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)
}
if debugDumpMapResponsePath != "" {
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
}
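Build is the single drain point for the errors the With* methods accumulated via addError, folding them into one value with multierr.New. A compact sketch of that error-collecting builder shape (hypothetical builder; tailscale.com/util/multierr assumed, matching the call shown above):

package main

import (
	"errors"
	"fmt"

	"tailscale.com/util/multierr"
)

type builder struct {
	value string
	errs  []error
}

// Each step records failures instead of returning them, so calls chain.
func (b *builder) withPart(part string) *builder {
	if part == "" {
		b.errs = append(b.errs, errors.New("empty part"))
		return b
	}
	b.value += part

	return b
}

// Build surfaces everything that went wrong as one combined error.
func (b *builder) Build() (string, error) {
	if len(b.errs) > 0 {
		return "", multierr.New(b.errs...)
	}

	return b.value, nil
}

func main() {
	_, err := new(builder).withPart("a").withPart("").withPart("").Build()
	fmt.Println(err) // reports both empty-part failures together
}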

View File

@@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
// Build should return a multierr
data, err := result.Build()
assert.Nil(t, data)
assert.Error(t, err)
require.Nil(t, data)
require.Error(t, err)
// The error should contain information about multiple errors
assert.Contains(t, err.Error(), "multiple errors")

View File

@@ -24,7 +24,6 @@ import (
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
mapperIDLength = 8
debugMapResponsePerm = 0o755
)
@@ -50,6 +49,7 @@ type mapper struct {
created time.Time
}
//nolint:unused
type patch struct {
timestamp time.Time
change *tailcfg.PeerChange
@@ -60,7 +60,6 @@ func newMapper(
state *state.State,
) *mapper {
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &mapper{
state: state,
cfg: cfg,
@@ -76,6 +75,7 @@ func generateUserProfiles(
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.UserView)
ids := make([]uint, 0, len(userMap))
user := node.Owner()
if !user.Valid() {
log.Error().
@@ -84,14 +84,17 @@ func generateUserProfiles(
return nil
}
userID := user.Model().ID
userMap[userID] = &user
ids = append(ids, userID)
for _, peer := range peers.All() {
peerUser := peer.Owner()
if !peerUser.Valid() {
continue
}
peerUserID := peerUser.Model().ID
userMap[peerUserID] = &peerUser
ids = append(ids, peerUserID)
@@ -99,7 +102,9 @@ func generateUserProfiles(
slices.Sort(ids)
ids = slices.Compact(ids)
var profiles []tailcfg.UserProfile
for _, id := range ids {
if userMap[id] != nil {
profiles = append(profiles, userMap[id].TailscaleUserProfile())
@@ -149,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
}
// fullMapResponse returns a MapResponse for the given node.
//
//nolint:unused
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
@@ -316,6 +323,7 @@ func writeDebugMapResponse(
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
@@ -329,6 +337,7 @@ func writeDebugMapResponse(
)
log.Trace().Msgf("writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
@@ -337,7 +346,7 @@ func writeDebugMapResponse(
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
if debugDumpMapResponsePath == "" {
return nil, nil
return nil, nil //nolint:nilnil // intentional: no data when debug path not set
}
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
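The nolint:nilnil here is documentation rather than a fix: (nil, nil) means "debug dumping is disabled, so there is nothing to return", which is exactly the ambiguous contract the nilnil linter exists to question. Callers must check both returns, as a tiny sketch shows:

package main

import "fmt"

// lookup returns (nil, nil) when the feature is disabled: no data,
// but not an error either. Callers must check both returns.
func lookup(enabled bool) (*string, error) {
	if !enabled {
		return nil, nil //nolint:nilnil // intentional: disabled means no data
	}
	v := "payload"

	return &v, nil
}

func main() {
	v, err := lookup(false)
	if err == nil && v == nil {
		fmt.Println("feature disabled; nothing to read")
	}
}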
@@ -350,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
}
result := make(map[types.NodeID][]tailcfg.MapResponse)
for _, node := range nodes {
if !node.IsDir() {
continue
@@ -385,6 +395,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
}
var resp tailcfg.MapResponse
err = json.Unmarshal(body, &resp)
if err != nil {
log.Error().Err(err).Msgf("unmarshalling file %s", file.Name())

View File

@@ -3,14 +3,10 @@ package mapper
import (
"fmt"
"net/netip"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
@@ -81,90 +77,3 @@ func TestDNSConfigMapResponse(t *testing.T) {
})
}
}
// mockState is a mock implementation that provides the required methods.
type mockState struct {
polMan policy.PolicyManager
derpMap *tailcfg.DERPMap
primary *routes.PrimaryRoutes
nodes types.Nodes
peers types.Nodes
}
func (m *mockState) DERPMap() *tailcfg.DERPMap {
return m.derpMap
}
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
if m.polMan == nil {
return tailcfg.FilterAllowAll, nil
}
return m.polMan.Filter()
}
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
if m.polMan == nil {
return nil, nil
}
return m.polMan.SSHPolicy(node)
}
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
if m.polMan == nil {
return false
}
return m.polMan.NodeCanHaveTag(node, tag)
}
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
if m.primary == nil {
return nil
}
return m.primary.PrimaryRoutes(nodeID)
}
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
if len(peerIDs) > 0 {
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
if slices.Contains(peerIDs, peer.ID) {
filtered = append(filtered, peer)
}
}
return filtered, nil
}
// Return all peers except the node itself
var filtered types.Nodes
for _, peer := range m.peers {
if peer.ID != nodeID {
filtered = append(filtered, peer)
}
}
return filtered, nil
}
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
if len(nodeIDs) > 0 {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
if slices.Contains(nodeIDs, node.ID) {
filtered = append(filtered, node)
}
}
return filtered, nil
}
return m.nodes, nil
}
func Test_fullMapResponse(t *testing.T) {
t.Skip("Test needs to be refactored for new state-based architecture")
// TODO: Refactor this test to work with the new state-based mapper
// The test architecture needs to be updated to work with the state interface
// instead of the old direct dependency injection pattern
}

View File

@@ -19,6 +19,7 @@ import (
func TestTailNode(t *testing.T) {
mustNK := func(str string) key.NodePublic {
var k key.NodePublic
_ = k.UnmarshalText([]byte(str))
return k
@@ -26,6 +27,7 @@ func TestTailNode(t *testing.T) {
mustDK := func(str string) key.DiscoPublic {
var k key.DiscoPublic
_ = k.UnmarshalText([]byte(str))
return k
@@ -33,6 +35,7 @@ func TestTailNode(t *testing.T) {
mustMK := func(str string) key.MachinePublic {
var k key.MachinePublic
_ = k.UnmarshalText([]byte(str))
return k
@@ -255,7 +258,7 @@ func TestNodeExpiry(t *testing.T) {
},
{
name: "localtime",
exp: tp(time.Time{}.Local()),
exp: tp(time.Time{}.Local()), //nolint:gosmopolitan
wantTimeZero: true,
},
}
@@ -284,7 +287,9 @@ func TestNodeExpiry(t *testing.T) {
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)
}
var deseri tailcfg.Node
err = json.Unmarshal(seri, &deseri)
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)