diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 10c84b50..88647c60 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -135,7 +135,6 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo removed := nodeConn.removeConnectionByChannel(c) if !removed { nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid") - return false } // Check if node has any remaining active connections diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 6f3fbccb..58a2158d 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -2801,6 +2801,71 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { } } +func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + t.Run("marks disconnected when removed channel was last active connection", func(t *testing.T) { + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, NORMAL_BUFFER_SIZE) + defer cleanup() + + lfb, ok := unwrapBatcher(testData.Batcher).(*LockFreeBatcher) + require.True(t, ok, "expected LockFreeBatcher") + + nodeID := testData.Nodes[0].n.ID + ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) + require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode") + }, 5*time.Second, 50*time.Millisecond, "waiting for node to be connected") + + nodeConn, exists := lfb.nodes.Load(nodeID) + require.True(t, exists, "node connection should exist") + require.True(t, nodeConn.removeConnectionByChannel(ch), "manual channel removal should succeed") + + removed := lfb.RemoveNode(nodeID, ch) + assert.False(t, removed, "RemoveNode should report no remaining active connections") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.False(c, lfb.IsConnected(nodeID), "node should be disconnected after last connection is gone") + }, 5*time.Second, 50*time.Millisecond, "waiting for node to be disconnected") + + close(ch) + }) + + t.Run("keeps connected when another connection is still active", func(t *testing.T) { + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, NORMAL_BUFFER_SIZE) + defer cleanup() + + lfb, ok := unwrapBatcher(testData.Batcher).(*LockFreeBatcher) + require.True(t, ok, "expected LockFreeBatcher") + + nodeID := testData.Nodes[0].n.ID + ch1 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) + ch2 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) + + require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100))) + require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100))) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode") + }, 5*time.Second, 50*time.Millisecond, "waiting for node to be connected") + + nodeConn, exists := lfb.nodes.Load(nodeID) + require.True(t, exists, "node connection should exist") + require.True(t, nodeConn.removeConnectionByChannel(ch1), "manual channel removal should succeed") + + removed := lfb.RemoveNode(nodeID, ch1) + assert.True(t, removed, "RemoveNode should report node still has active connections") + assert.True(t, lfb.IsConnected(nodeID), "node should still be connected while another connection exists") + assert.Equal(t, 1, nodeConn.getActiveConnectionCount(), "exactly one active connection should remain") + + close(ch1) + }) + }) + } +} + // unwrapBatcher extracts the underlying batcher from wrapper types. func unwrapBatcher(b Batcher) Batcher { if wrapper, ok := b.(*testBatcherWrapper); ok {