diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 1f092a9c..5b2adddc 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -36,7 +36,7 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher type Batcher interface { Start() Close() - AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 239a6504..4d35c274 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -54,7 +54,13 @@ type LockFreeBatcher struct { // AddNode registers a new node connection with the batcher and sends an initial map response. // It creates or updates the node's connection data, validates the initial map generation, // and notifies other nodes that this node has come online. -func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { +// The stop function tears down the owning session if this connection is later declared stale. +func (b *LockFreeBatcher) AddNode( + id types.NodeID, + c chan<- *tailcfg.MapResponse, + version tailcfg.CapabilityVersion, + stop func(), +) error { addNodeStart := time.Now() nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() @@ -68,6 +74,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse c: c, version: version, created: now, + stop: stop, } // Initialize last used timestamp newEntry.lastUsed.Store(now.Unix()) @@ -511,6 +518,7 @@ type connectionEntry struct { c chan<- *tailcfg.MapResponse version tailcfg.CapabilityVersion created time.Time + stop func() lastUsed atomic.Int64 // Unix timestamp of last successful send closed atomic.Bool // Indicates if this connection has been closed } @@ -556,27 +564,29 @@ func (mc *multiChannelNodeConn) close() { defer mc.mutex.Unlock() for _, conn := range mc.connections { - mc.closeConnection(conn) + mc.stopConnection(conn) } } -// closeConnection closes connection channel at most once, even if multiple cleanup -// paths race to tear the same session down. -func (mc *multiChannelNodeConn) closeConnection(conn *connectionEntry) { +// stopConnection marks a connection as closed and tears down the owning session +// at most once, even if multiple cleanup paths race to remove it. +func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) { if conn.closed.CompareAndSwap(false, true) { - close(conn.c) + if conn.stop != nil { + conn.stop() + } } } // removeConnectionAtIndexLocked removes the active connection at index. -// If closeChannel is true, it also closes that session's map-response channel. +// If stopConnection is true, it also stops that session. // Caller must hold mc.mutex. -func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, closeChannel bool) *connectionEntry { +func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry { conn := mc.connections[i] mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) - if closeChannel { - mc.closeConnection(conn) + if stopConnection { + mc.stopConnection(conn) } return conn diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 58a2158d..75fbe054 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -39,7 +39,7 @@ type testBatcherWrapper struct { state *state.State } -func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { +func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error { // Mark node as online in state before AddNode to match production behavior // This ensures the NodeStore has correct online status for change processing if t.state != nil { @@ -48,7 +48,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo } // First add the node to the real batcher - err := t.Batcher.AddNode(id, c, version) + err := t.Batcher.AddNode(id, c, version, stop) if err != nil { return err } @@ -543,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100), nil) // Wait for connection to be established assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -652,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { for i := range allNodes { node := &allNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullUpdate()) @@ -821,7 +821,7 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := &testData.Nodes[1] // Test AddNode with real node ID - _ = batcher.AddNode(tn.n.ID, tn.ch, 100) + _ = batcher.AddNode(tn.n.ID, tn.ch, 100, nil) if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") @@ -842,7 +842,7 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, 100*time.Millisecond) // Add the second node and verify update message - _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100) + _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100, nil) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. @@ -1043,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100), nil) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1149,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, 1) wg.Go(func() { - _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100), nil) }) // Add real work during connection chaos @@ -1163,7 +1163,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Go(func() { runtime.Gosched() // Yield to introduce timing variability - _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100), nil) }) // Remove second connection @@ -1254,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100), nil) batcher.AddWork(change.DERPMap()) // Consumer goroutine to validate data and detect channel issues @@ -1380,7 +1380,7 @@ func TestBatcherConcurrentClients(t *testing.T) { node := &stableNodes[i] ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100), nil) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1456,7 +1456,7 @@ func TestBatcherConcurrentClients(t *testing.T) { churningChannelsMutex.Unlock() - _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100), nil) // Consume updates to prevent blocking go func() { @@ -1774,7 +1774,7 @@ func XTestBatcherScalability(t *testing.T) { for i := range testNodes { node := &testNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) connectedNodesMutex.Lock() @@ -1891,6 +1891,7 @@ func XTestBatcherScalability(t *testing.T) { nodeID, channel, tailcfg.CapabilityVersion(100), + nil, ) connectedNodesMutex.Lock() @@ -2155,7 +2156,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time and wait for each to be connected for i := range allNodes { node := &allNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Wait for node to be connected @@ -2307,7 +2308,7 @@ func TestBatcherRapidReconnection(t *testing.T) { for i := range allNodes { node := &allNodes[i] - err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add node %d: %v", i, err) } @@ -2337,7 +2338,7 @@ func TestBatcherRapidReconnection(t *testing.T) { node := &allNodes[i] newChannels[i] = make(chan *tailcfg.MapResponse, 10) - err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100), nil) if err != nil { t.Errorf("Failed to reconnect node %d: %v", i, err) } @@ -2444,13 +2445,13 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 1: Connect first node with initial connection t.Logf("Phase 1: Connecting node 1 with first connection...") - err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add node1: %v", err) } // Connect second node for comparison - err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)) + err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add node2: %v", err) } @@ -2466,7 +2467,7 @@ func TestBatcherMultiConnection(t *testing.T) { secondChannel := make(chan *tailcfg.MapResponse, 10) - err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add second connection for node1: %v", err) } @@ -2479,7 +2480,7 @@ func TestBatcherMultiConnection(t *testing.T) { thirdChannel := make(chan *tailcfg.MapResponse, 10) - err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add third connection for node1: %v", err) } @@ -2718,9 +2719,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { defer node3.cleanup() // Connect all nodes to the batcher - require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))) - require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))) - require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100))) + require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil)) + require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100), nil)) + require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100), nil)) // Wait for all nodes to be connected assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -2813,7 +2814,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { nodeID := testData.Nodes[0].n.ID ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) - require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))) + require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100), nil)) assert.EventuallyWithT(t, func(c *assert.CollectT) { assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode") @@ -2844,8 +2845,8 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) ch2 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) - require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100))) - require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100))) + require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100), nil)) + require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100), nil)) assert.EventuallyWithT(t, func(c *assert.CollectT) { assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode") diff --git a/hscontrol/poll.go b/hscontrol/poll.go index ded86068..3179eb78 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand/v2" "net/http" + "sync/atomic" "time" "github.com/juanfont/headscale/hscontrol/types" @@ -14,7 +15,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util/zlog/zf" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/sasha-s/go-deadlock" "tailscale.com/tailcfg" "tailscale.com/util/zstdframe" ) @@ -33,11 +33,9 @@ type mapSession struct { ctx context.Context //nolint:containedctx capVer tailcfg.CapabilityVersion - cancelChMu deadlock.Mutex - - ch chan *tailcfg.MapResponse - cancelCh chan struct{} - cancelChOpen bool + ch chan *tailcfg.MapResponse + cancelCh chan struct{} + cancelChClosed atomic.Bool keepAlive time.Duration keepAliveTicker *time.Ticker @@ -64,9 +62,8 @@ func (h *Headscale) newMapSession( node: node, capVer: req.Version, - ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize), - cancelCh: make(chan struct{}), - cancelChOpen: true, + ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize), + cancelCh: make(chan struct{}), keepAlive: ka, keepAliveTicker: nil, @@ -92,6 +89,12 @@ func (m *mapSession) resetKeepAlive() { m.keepAliveTicker.Reset(m.keepAlive) } +func (m *mapSession) stopFromBatcher() { + if m.cancelChClosed.CompareAndSwap(false, true) { + close(m.cancelCh) + } +} + func (m *mapSession) beforeServeLongPoll() { if m.node.IsEphemeral() { m.h.ephemeralGC.Cancel(m.node.ID) @@ -146,10 +149,7 @@ func (m *mapSession) serveLongPoll() { // Clean up the session when the client disconnects defer func() { - m.cancelChMu.Lock() - m.cancelChOpen = false - close(m.cancelCh) - m.cancelChMu.Unlock() + m.stopFromBatcher() _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) @@ -224,7 +224,7 @@ func (m *mapSession) serveLongPoll() { // adding this before connecting it to the state ensure that // it does not miss any updates that might be sent in the split // time between the node connecting and the batcher being ready. - if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { //nolint:noinlineerr + if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer, m.stopFromBatcher); err != nil { //nolint:noinlineerr m.log.Error().Caller().Err(err).Msg("failed to add node to batcher") return } diff --git a/hscontrol/poll_test.go b/hscontrol/poll_test.go index 7ea924fe..7247f0c6 100644 --- a/hscontrol/poll_test.go +++ b/hscontrol/poll_test.go @@ -86,16 +86,22 @@ func (w *delayedSuccessResponseWriter) WriteCount() int { return w.writeCount } -// Reproducer outline: +// TestGitHubIssue3129_TransientlyBlockedWriteDoesNotLeaveLiveStaleSession +// tests the scenario reported in +// https://github.com/juanfont/headscale/issues/3129. +// +// Scenario: // 1. Start a real long-poll session for one node. -// 2. Make the first map write block briefly, so the session stops draining m.ch. +// 2. Block the first map write long enough for the session to stop draining +// its buffered map-response channel. // 3. While that write is blocked, queue enough updates to fill the buffered -// session channel and make the next batcher send hit the stale-send timeout. -// 4. Let the blocked write recover. The stale session should still flush the -// update that was already buffered before its channel was pruned. -// 5. After that buffered update is drained, the stale session must exit instead -// of lingering as an orphaned serveLongPoll goroutine. -func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { +// channel and make the next batcher send hit the stale-send timeout. +// 4. That stale-send path removes the session from the batcher, so without an +// explicit teardown hook the old serveLongPoll goroutine would stay alive +// but stop receiving future updates. +// 5. Release the blocked write and verify the batcher-side stop signal makes +// that stale session exit instead of lingering as an orphaned goroutine. +func TestGitHubIssue3129_TransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { t.Parallel() app := createTestApp(t) @@ -141,7 +147,7 @@ func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { t.Cleanup(func() { dummyCh := make(chan *tailcfg.MapResponse, 1) - _ = app.mapBatcher.AddNode(node.ID, dummyCh, tailcfg.CapabilityVersion(100)) + _ = app.mapBatcher.AddNode(node.ID, dummyCh, tailcfg.CapabilityVersion(100), nil) cancel() select { case <-serveDone: @@ -163,8 +169,8 @@ func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { }() // One update fills the buffered session channel while the first write is blocked. - // The second update then hits the 50ms stale-send timeout and the batcher prunes - // and closes that stale channel. + // The second update then hits the 50ms stale-send timeout, so the batcher prunes + // the stale connection and triggers its stop hook. app.mapBatcher.AddWork(change.SelfUpdate(node.ID), change.SelfUpdate(node.ID)) select { @@ -173,10 +179,6 @@ func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { t.Fatal("expected the blocked write to eventually complete") } - assert.Eventually(t, func() bool { - return writer.WriteCount() >= 2 - }, 2*time.Second, 20*time.Millisecond, "session should flush the update that was already buffered before the stale send") - assert.Eventually(t, func() bool { select { case <-streamsClosed: