From 87b8507ac9d3510ee15074478e235ab62b1ba872 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 14 Mar 2026 14:06:52 +0000 Subject: [PATCH] mapper/batcher: replace connected map with per-node disconnectedAt The Batcher's connected field (*xsync.Map[types.NodeID, *time.Time]) encoded three states via pointer semantics: - nil value: node is connected - non-nil time: node disconnected at that timestamp - key missing: node was never seen This was error-prone (nil meaning 'connected' inverts Go idioms), redundant with b.nodes + hasActiveConnections(), and required keeping two parallel maps in sync. It also contained a bug in RemoveNode where new(time.Now()) was used instead of &now, producing a zero time. Replace the separate connected map with a disconnectedAt field on multiChannelNodeConn (atomic.Pointer[time.Time]), tracked directly on the object that already manages the node's connections. Changes: - Add disconnectedAt field and helpers (markConnected, markDisconnected, isConnected, offlineDuration) to multiChannelNodeConn - Remove the connected field from Batcher - Simplify IsConnected from two map lookups to one - Simplify ConnectedMap and Debug from two-map iteration to one - Rewrite cleanupOfflineNodes to scan b.nodes directly - Remove the markDisconnectedIfNoConns helper - Update all tests and benchmarks Fixes #3141 --- hscontrol/mapper/batcher.go | 144 +++++-------------- hscontrol/mapper/batcher_bench_test.go | 21 +-- hscontrol/mapper/batcher_concurrency_test.go | 113 +++++++-------- hscontrol/mapper/batcher_scale_bench_test.go | 9 +- hscontrol/mapper/batcher_unit_test.go | 11 +- hscontrol/mapper/node_conn.go | 41 ++++++ 6 files changed, 148 insertions(+), 191 deletions(-) diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 57fe2ed5..8caa901d 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -43,10 +43,9 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher { tick: time.NewTicker(batchTime), // The size of this channel is arbitrary chosen, the sizing should be revisited. - workCh: make(chan work, workers*200), - done: make(chan struct{}), - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), + workCh: make(chan work, workers*200), + done: make(chan struct{}), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), } } @@ -200,8 +199,7 @@ type Batcher struct { mapper *mapper workers int - nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] - connected *xsync.Map[types.NodeID, *time.Time] + nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] // Work queue channel workCh chan work @@ -264,7 +262,10 @@ func (b *Batcher) AddNode( if err != nil { nlog.Error().Err(err).Msg("initial map generation failed") nodeConn.removeConnectionByChannel(c) - b.markDisconnectedIfNoConns(id, nodeConn) + + if !nodeConn.hasActiveConnections() { + nodeConn.markDisconnected() + } return fmt.Errorf("generating initial map for node %d: %w", id, err) } @@ -279,13 +280,16 @@ func (b *Batcher) AddNode( nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd Msg("initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) - b.markDisconnectedIfNoConns(id, nodeConn) + + if !nodeConn.hasActiveConnections() { + nodeConn.markDisconnected() + } return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id) } - // Update connection status - b.connected.Store(id, nil) // nil = connected + // Mark the node as connected now that the initial map was sent. + nodeConn.markConnected() // Node will automatically receive updates through the normal flow // The initial full map already contains all current state @@ -328,7 +332,7 @@ func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) boo // No active connections - keep the node entry alive for rapid reconnections // The node will get a fresh full map when it reconnects nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection") - b.connected.Store(id, new(time.Now())) + nodeConn.markDisconnected() return false } @@ -530,8 +534,6 @@ func (b *Batcher) addToBatch(changes ...change.Change) { Uint64(zf.NodeID, removedID.Uint64()). Msg("removed deleted node from batcher") } - - b.connected.Delete(removedID) } } @@ -604,15 +606,13 @@ func (b *Batcher) processBatchedChanges() { // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. // Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node // reconnects between the hasActiveConnections() check and the Delete() call. -// TODO(kradalby): reevaluate if we want to keep this. func (b *Batcher) cleanupOfflineNodes() { - now := time.Now() - var nodesToCleanup []types.NodeID - // Find nodes that have been offline for too long - b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { - if disconnectTime != nil && now.Sub(*disconnectTime) > offlineNodeCleanupThreshold { + // Find nodes that have been offline for too long by scanning b.nodes + // and checking each node's disconnectedAt timestamp. + b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { + if nc != nil && !nc.hasActiveConnections() && nc.offlineDuration() > offlineNodeCleanupThreshold { nodesToCleanup = append(nodesToCleanup, nodeID) } @@ -635,8 +635,7 @@ func (b *Batcher) cleanupOfflineNodes() { // Perform all bookkeeping inside the Compute callback so // that a concurrent AddNode (which calls LoadOrStore on // b.nodes) cannot slip in between the delete and the - // connected/counter updates. - b.connected.Delete(nodeID) + // counter update. b.totalNodes.Add(-1) cleaned++ @@ -656,57 +655,26 @@ func (b *Batcher) cleanupOfflineNodes() { } } -// IsConnected is lock-free read that checks if a node has any active connections. +// IsConnected is a lock-free read that checks if a node is connected. +// A node is considered connected if it has active connections or has +// not been marked as disconnected. func (b *Batcher) IsConnected(id types.NodeID) bool { - // First check if we have active connections for this node - if nodeConn, exists := b.nodes.Load(id); exists && nodeConn != nil { - if nodeConn.hasActiveConnections() { - return true - } - } - - // Check disconnected timestamp with grace period - val, ok := b.connected.Load(id) - if !ok { + nodeConn, exists := b.nodes.Load(id) + if !exists || nodeConn == nil { return false } - // nil means connected - if val == nil { - return true - } - - return false + return nodeConn.isConnected() } -// ConnectedMap returns a lock-free map of all connected nodes. +// ConnectedMap returns a lock-free map of all known nodes and their +// connection status (true = connected, false = disconnected). func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret := xsync.NewMap[types.NodeID, bool]() - // First, add all nodes with active connections - b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { - if nodeConn == nil { - return true - } - - if nodeConn.hasActiveConnections() { - ret.Store(id, true) - } - - return true - }) - - // Then add all entries from the connected map - b.connected.Range(func(id types.NodeID, val *time.Time) bool { - // Only add if not already added as connected above - if _, exists := ret.Load(id); !exists { - if val == nil { - // nil means connected - ret.Store(id, true) - } else { - // timestamp means disconnected - ret.Store(id, false) - } + b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool { + if nc != nil { + ret.Store(id, nc.isConnected()) } return true @@ -715,17 +683,6 @@ func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { return ret } -// markDisconnectedIfNoConns stores a disconnect timestamp in b.connected -// when the node has no remaining active connections. This prevents -// IsConnected from returning a stale true after all connections have been -// removed on an error path (e.g. AddNode failure). -func (b *Batcher) markDisconnectedIfNoConns(id types.NodeID, nc *multiChannelNodeConn) { - if !nc.hasActiveConnections() { - now := time.Now() - b.connected.Store(id, &now) - } -} - // MapResponseFromChange queues work to generate a map response and waits for the result. // This allows synchronous map generation using the same worker pool. func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { @@ -753,45 +710,14 @@ type DebugNodeInfo struct { func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo { result := make(map[types.NodeID]DebugNodeInfo) - // Get all nodes with their connection status using immediate connection logic - // (no grace period) for debug purposes - b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { - if nodeConn == nil { + b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool { + if nc == nil { return true } - activeConnCount := nodeConn.getActiveConnectionCount() - - // Use immediate connection status: if active connections exist, node is connected - // If not, check the connected map for nil (connected) vs timestamp (disconnected) - connected := false - if activeConnCount > 0 { - connected = true - } else { - // Check connected map for immediate status - if val, ok := b.connected.Load(id); ok && val == nil { - connected = true - } - } - result[id] = DebugNodeInfo{ - Connected: connected, - ActiveConnections: activeConnCount, - } - - return true - }) - - // Add all entries from the connected map to capture both connected and disconnected nodes - b.connected.Range(func(id types.NodeID, val *time.Time) bool { - // Only add if not already processed above - if _, exists := result[id]; !exists { - // Use immediate connection status for debug (no grace period) - connected := (val == nil) // nil means connected, timestamp means disconnected - result[id] = DebugNodeInfo{ - Connected: connected, - ActiveConnections: 0, - } + Connected: nc.isConnected(), + ActiveConnections: nc.getActiveConnectionCount(), } return true diff --git a/hscontrol/mapper/batcher_bench_test.go b/hscontrol/mapper/batcher_bench_test.go index e7c8f06a..78f7e693 100644 --- a/hscontrol/mapper/batcher_bench_test.go +++ b/hscontrol/mapper/batcher_bench_test.go @@ -150,12 +150,11 @@ func BenchmarkUpdateSentPeers(b *testing.B) { // helper, it doesn't register cleanup and suppresses logging. func benchBatcher(nodeCount, bufferSize int) (*Batcher, map[types.NodeID]chan *tailcfg.MapResponse) { b := &Batcher{ - tick: time.NewTicker(1 * time.Hour), // never fires during bench - workers: 4, - workCh: make(chan work, 4*200), - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), - done: make(chan struct{}), + tick: time.NewTicker(1 * time.Hour), // never fires during bench + workers: 4, + workCh: make(chan work, 4*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + done: make(chan struct{}), } channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount) @@ -172,7 +171,6 @@ func benchBatcher(nodeCount, bufferSize int) (*Batcher, map[types.NodeID]chan *t entry.lastUsed.Store(time.Now().Unix()) mc.addConnection(entry) b.nodes.Store(id, mc) - b.connected.Store(id, nil) channels[id] = ch } @@ -471,7 +469,7 @@ func BenchmarkConnectedMap(b *testing.B) { for _, nodeCount := range []int{10, 100, 1000} { b.Run(fmt.Sprintf("%dnodes", nodeCount), func(b *testing.B) { - batcher, _ := benchBatcher(nodeCount, 1) + batcher, channels := benchBatcher(nodeCount, 1) defer func() { close(batcher.done) @@ -481,8 +479,11 @@ func BenchmarkConnectedMap(b *testing.B) { // Disconnect 10% of nodes for a realistic mix for i := 1; i <= nodeCount; i++ { if i%10 == 0 { - now := time.Now() - batcher.connected.Store(types.NodeID(i), &now) //nolint:gosec // benchmark + id := types.NodeID(i) //nolint:gosec // benchmark + if mc, ok := batcher.nodes.Load(id); ok { + mc.removeConnectionByChannel(channels[id]) + mc.markDisconnected() + } } } diff --git a/hscontrol/mapper/batcher_concurrency_test.go b/hscontrol/mapper/batcher_concurrency_test.go index 61da0d38..ee628522 100644 --- a/hscontrol/mapper/batcher_concurrency_test.go +++ b/hscontrol/mapper/batcher_concurrency_test.go @@ -47,12 +47,11 @@ func setupLightweightBatcher(t *testing.T, nodeCount, bufferSize int) *lightweig t.Helper() b := &Batcher{ - tick: time.NewTicker(10 * time.Millisecond), - workers: 4, - workCh: make(chan work, 4*200), - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), - done: make(chan struct{}), + tick: time.NewTicker(10 * time.Millisecond), + workers: 4, + workCh: make(chan work, 4*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + done: make(chan struct{}), } channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount) @@ -69,7 +68,6 @@ func setupLightweightBatcher(t *testing.T, nodeCount, bufferSize int) *lightweig entry.lastUsed.Store(time.Now().Unix()) mc.addConnection(entry) b.nodes.Store(id, mc) - b.connected.Store(id, nil) // nil = connected channels[id] = ch } @@ -299,13 +297,10 @@ func TestAddToBatch_NodeRemovalCleanup(t *testing.T) { PeersRemoved: []types.NodeID{removedNode}, }) - // Node should be removed from all maps + // Node should be removed from the nodes map _, exists = lb.b.nodes.Load(removedNode) assert.False(t, exists, "node 3 should be removed from nodes map") - _, exists = lb.b.connected.Load(removedNode) - assert.False(t, exists, "node 3 should be removed from connected map") - pending := getPendingForNode(lb.b, removedNode) assert.Empty(t, pending, "node 3 should have no pending changes") @@ -546,13 +541,13 @@ func TestCleanupOfflineNodes_RemovesOld(t *testing.T) { lb := setupLightweightBatcher(t, 5, 10) defer lb.cleanup() - // Make node 3 appear offline for 20 minutes - oldTime := time.Now().Add(-20 * time.Minute) - lb.b.connected.Store(types.NodeID(3), &oldTime) - // Remove its active connections so it appears truly offline + // Remove node 3's active connections and mark it disconnected 20 minutes ago if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { ch := lb.channels[types.NodeID(3)] mc.removeConnectionByChannel(ch) + + oldTime := time.Now().Add(-20 * time.Minute) + mc.disconnectedAt.Store(&oldTime) } lb.b.cleanupOfflineNodes() @@ -571,13 +566,13 @@ func TestCleanupOfflineNodes_KeepsRecent(t *testing.T) { lb := setupLightweightBatcher(t, 5, 10) defer lb.cleanup() - // Make node 3 appear offline for only 5 minutes (under threshold) - recentTime := time.Now().Add(-5 * time.Minute) - lb.b.connected.Store(types.NodeID(3), &recentTime) - + // Remove node 3's connections and mark it disconnected 5 minutes ago (under threshold) if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { ch := lb.channels[types.NodeID(3)] mc.removeConnectionByChannel(ch) + + recentTime := time.Now().Add(-5 * time.Minute) + mc.disconnectedAt.Store(&recentTime) } lb.b.cleanupOfflineNodes() @@ -593,8 +588,10 @@ func TestCleanupOfflineNodes_KeepsActive(t *testing.T) { defer lb.cleanup() // Set old disconnect time but keep the connection active - oldTime := time.Now().Add(-20 * time.Minute) - lb.b.connected.Store(types.NodeID(3), &oldTime) + if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { + oldTime := time.Now().Add(-20 * time.Minute) + mc.disconnectedAt.Store(&oldTime) + } // Don't remove connection - node still has active connections lb.b.cleanupOfflineNodes() @@ -717,14 +714,12 @@ func TestBatcher_IsConnectedReflectsState(t *testing.T) { // Non-existent node should not be connected assert.False(t, lb.b.IsConnected(types.NodeID(999))) - // Disconnect node 3 (remove connection + set disconnect time) + // Disconnect node 3 (remove connection + mark disconnected) if mc, ok := lb.b.nodes.Load(types.NodeID(3)); ok { mc.removeConnectionByChannel(lb.channels[types.NodeID(3)]) + mc.markDisconnected() } - now := time.Now() - lb.b.connected.Store(types.NodeID(3), &now) - assert.False(t, lb.b.IsConnected(types.NodeID(3)), "node 3 should not be connected after disconnection") @@ -742,11 +737,9 @@ func TestBatcher_ConnectedMapConsistency(t *testing.T) { // Disconnect node 2 if mc, ok := lb.b.nodes.Load(types.NodeID(2)); ok { mc.removeConnectionByChannel(lb.channels[types.NodeID(2)]) + mc.markDisconnected() } - now := time.Now() - lb.b.connected.Store(types.NodeID(2), &now) - cm := lb.b.ConnectedMap() // Connected nodes @@ -789,13 +782,13 @@ func TestBug3_CleanupOfflineNodes_TOCTOU(t *testing.T) { targetNode := types.NodeID(3) - // Make node 3 appear offline for >15 minutes (past cleanup threshold) - oldTime := time.Now().Add(-20 * time.Minute) - lb.b.connected.Store(targetNode, &oldTime) - // Remove its active connections so it appears truly offline + // Remove node 3's active connections and mark it disconnected >15 minutes ago if mc, ok := lb.b.nodes.Load(targetNode); ok { ch := lb.channels[targetNode] mc.removeConnectionByChannel(ch) + + oldTime := time.Now().Add(-20 * time.Minute) + mc.disconnectedAt.Store(&oldTime) } // Verify node 3 has no active connections before we start. @@ -819,7 +812,7 @@ func TestBug3_CleanupOfflineNodes_TOCTOU(t *testing.T) { } entry.lastUsed.Store(time.Now().Unix()) mc.addConnection(entry) - lb.b.connected.Store(targetNode, nil) // nil = connected + mc.markConnected() lb.channels[targetNode] = newCh // Now run cleanup. Node 3 is in the candidates list (old disconnect @@ -840,7 +833,7 @@ func TestBug3_CleanupOfflineNodes_TOCTOU(t *testing.T) { mc.removeConnectionByChannel(newCh) oldTime2 := time.Now().Add(-20 * time.Minute) - lb.b.connected.Store(targetNode, &oldTime2) + mc.disconnectedAt.Store(&oldTime2) var wg sync.WaitGroup @@ -861,7 +854,7 @@ func TestBug3_CleanupOfflineNodes_TOCTOU(t *testing.T) { } reconnEntry.lastUsed.Store(time.Now().Unix()) mc.addConnection(reconnEntry) - lb.b.connected.Store(targetNode, nil) + mc.markConnected() } }) @@ -1028,13 +1021,13 @@ func TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally(t *testing.T targetNode := types.NodeID(3) - // Make node 3 appear offline for >15 minutes - oldTime := time.Now().Add(-20 * time.Minute) - lb.b.connected.Store(targetNode, &oldTime) - + // Remove node 3's connections and mark it disconnected >15 minutes ago if mc, ok := lb.b.nodes.Load(targetNode); ok { ch := lb.channels[targetNode] mc.removeConnectionByChannel(ch) + + oldTime := time.Now().Add(-20 * time.Minute) + mc.disconnectedAt.Store(&oldTime) } // Add pending changes for node 3 before cleanup @@ -1049,13 +1042,10 @@ func TestBug7_CleanupOfflineNodes_PendingChangesCleanedStructurally(t *testing.T // Run cleanup lb.b.cleanupOfflineNodes() - // Node 3 should be removed from nodes and connected + // Node 3 should be removed from the nodes map _, existsInNodes := lb.b.nodes.Load(targetNode) assert.False(t, existsInNodes, "node 3 should be removed from nodes map") - _, existsInConnected := lb.b.connected.Load(targetNode) - assert.False(t, existsInConnected, "node 3 should be removed from connected map") - // Pending changes are structurally gone because the node was deleted. // getPendingForNode returns nil for non-existent nodes. pendingAfter := getPendingForNode(lb.b, targetNode) @@ -1282,12 +1272,11 @@ func TestScale1000_MultiChannelBroadcast(t *testing.T) { // Create nodes with varying connection counts b := &Batcher{ - tick: time.NewTicker(10 * time.Millisecond), - workers: 4, - workCh: make(chan work, 4*200), - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), - done: make(chan struct{}), + tick: time.NewTicker(10 * time.Millisecond), + workers: 4, + workCh: make(chan work, 4*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + done: make(chan struct{}), } defer func() { @@ -1551,16 +1540,17 @@ func TestScale1000_IsConnectedConsistency(t *testing.T) { } }) - // Goroutine modifying connection state + // Goroutine modifying connection state via disconnectedAt on the node conn wg.Go(func() { for i := range 100 { id := types.NodeID(1 + (i % 1000)) //nolint:gosec // test - if i%2 == 0 { - now := time.Now() - lb.b.connected.Store(id, &now) // disconnect - } else { - lb.b.connected.Store(id, nil) // reconnect + if mc, ok := lb.b.nodes.Load(id); ok { + if i%2 == 0 { + mc.markDisconnected() // disconnect + } else { + mc.markConnected() // reconnect + } } } }) @@ -1617,7 +1607,6 @@ func TestScale1000_BroadcastDuringNodeChurn(t *testing.T) { if cycle%2 == 0 { // "Remove" node lb.b.nodes.Delete(id) - lb.b.connected.Delete(id) } else { // "Add" node back mc := newMultiChannelNodeConn(id, nil) @@ -1631,7 +1620,6 @@ func TestScale1000_BroadcastDuringNodeChurn(t *testing.T) { entry.lastUsed.Store(time.Now().Unix()) mc.addConnection(entry) lb.b.nodes.Store(id, mc) - lb.b.connected.Store(id, nil) } }() } @@ -1684,12 +1672,11 @@ func TestScale1000_WorkChannelSaturation(t *testing.T) { // Create batcher with SMALL work channel to force saturation b := &Batcher{ - tick: time.NewTicker(10 * time.Millisecond), - workers: 2, - workCh: make(chan work, 10), // Very small - will saturate - nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), - connected: xsync.NewMap[types.NodeID, *time.Time](), - done: make(chan struct{}), + tick: time.NewTicker(10 * time.Millisecond), + workers: 2, + workCh: make(chan work, 10), // Very small - will saturate + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + done: make(chan struct{}), } defer func() { diff --git a/hscontrol/mapper/batcher_scale_bench_test.go b/hscontrol/mapper/batcher_scale_bench_test.go index e8e6b075..8192d967 100644 --- a/hscontrol/mapper/batcher_scale_bench_test.go +++ b/hscontrol/mapper/batcher_scale_bench_test.go @@ -305,7 +305,7 @@ func BenchmarkScale_ConnectedMap(b *testing.B) { for _, n := range scaleCountsHeavy { b.Run(strconv.Itoa(n), func(b *testing.B) { - batcher, _ := benchBatcher(n, 1) + batcher, channels := benchBatcher(n, 1) defer func() { close(batcher.done) @@ -315,8 +315,11 @@ func BenchmarkScale_ConnectedMap(b *testing.B) { // 10% disconnected for realism for i := 1; i <= n; i++ { if i%10 == 0 { - now := time.Now() - batcher.connected.Store(types.NodeID(i), &now) //nolint:gosec + id := types.NodeID(i) //nolint:gosec + if mc, ok := batcher.nodes.Load(id); ok { + mc.removeConnectionByChannel(channels[id]) + mc.markDisconnected() + } } } diff --git a/hscontrol/mapper/batcher_unit_test.go b/hscontrol/mapper/batcher_unit_test.go index 8e971021..54172522 100644 --- a/hscontrol/mapper/batcher_unit_test.go +++ b/hscontrol/mapper/batcher_unit_test.go @@ -1121,9 +1121,9 @@ func TestBatcher_QueueWorkAfterClose_DoesNotHang(t *testing.T) { } // TestIsConnected_FalseAfterAddNodeFailure is a regression guard for M3. -// Before the fix, AddNode error paths removed the connection but left -// b.connected with its previous value (nil = connected). IsConnected -// would return true for a node with zero active connections. +// Before the fix, AddNode error paths removed the connection but did not +// mark the node as disconnected. IsConnected would return true for a +// node with zero active connections. func TestIsConnected_FalseAfterAddNodeFailure(t *testing.T) { b := NewBatcher(50*time.Millisecond, 2, nil) b.Start() @@ -1132,12 +1132,11 @@ func TestIsConnected_FalseAfterAddNodeFailure(t *testing.T) { id := types.NodeID(42) - // Simulate a previous session leaving the node marked as connected. - b.connected.Store(id, nil) // nil = connected - // Pre-create the node entry so AddNode reuses it, and set up a // multiChannelNodeConn with no mapper so MapResponseFromChange will fail. + // markConnected() simulates a previous session leaving it connected. nc := newMultiChannelNodeConn(id, nil) + nc.markConnected() b.nodes.Store(id, nc) ch := make(chan *tailcfg.MapResponse, 1) diff --git a/hscontrol/mapper/node_conn.go b/hscontrol/mapper/node_conn.go index 1ee9d3e6..87d61544 100644 --- a/hscontrol/mapper/node_conn.go +++ b/hscontrol/mapper/node_conn.go @@ -52,6 +52,12 @@ type multiChannelNodeConn struct { closeOnce sync.Once updateCount atomic.Int64 + // disconnectedAt records when the last connection was removed. + // nil means the node is considered connected (or newly created); + // non-nil means the node disconnected at the stored timestamp. + // Used by cleanupOfflineNodes to evict stale entries. + disconnectedAt atomic.Pointer[time.Time] + // lastSentPeers tracks which peers were last sent to this node. // This enables computing diffs for policy changes instead of sending // full peer lists (which clients interpret as "no change" when empty). @@ -162,6 +168,41 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int { return len(mc.connections) } +// markConnected clears the disconnect timestamp, indicating the node +// has an active connection. +func (mc *multiChannelNodeConn) markConnected() { + mc.disconnectedAt.Store(nil) +} + +// markDisconnected records the current time as the moment the node +// lost its last connection. Used by cleanupOfflineNodes to determine +// how long the node has been offline. +func (mc *multiChannelNodeConn) markDisconnected() { + now := time.Now() + mc.disconnectedAt.Store(&now) +} + +// isConnected returns true if the node has active connections or has +// not been marked as disconnected. +func (mc *multiChannelNodeConn) isConnected() bool { + if mc.hasActiveConnections() { + return true + } + + return mc.disconnectedAt.Load() == nil +} + +// offlineDuration returns how long the node has been disconnected. +// Returns 0 if the node is connected or has never been marked as disconnected. +func (mc *multiChannelNodeConn) offlineDuration() time.Duration { + t := mc.disconnectedAt.Load() + if t == nil { + return 0 + } + + return time.Since(*t) +} + // appendPending appends changes to this node's pending change list. // Thread-safe via pendingMu; does not contend with the connection mutex. func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {