From 4aca9d6568aec48173a1f69a21cfca06f6c9790b Mon Sep 17 00:00:00 2001 From: DM Date: Sun, 8 Mar 2026 05:50:23 +0300 Subject: [PATCH] poll: stop stale map sessions through an explicit teardown hook When stale-send cleanup prunes a connection from the batcher, the old serveLongPoll session needs an explicit stop signal. Pass a stop hook into AddNode and trigger it when that connection is removed, so the session exits through its normal cancel path instead of relying on channel closure from the batcher side. --- hscontrol/mapper/batcher.go | 2 +- hscontrol/mapper/batcher_lockfree.go | 30 ++++++++++------ hscontrol/mapper/batcher_test.go | 53 ++++++++++++++-------------- hscontrol/poll.go | 28 +++++++-------- hscontrol/poll_test.go | 32 +++++++++-------- 5 files changed, 79 insertions(+), 66 deletions(-) diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 1f092a9c..5b2adddc 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -36,7 +36,7 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher type Batcher interface { Start() Close() - AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool IsConnected(id types.NodeID) bool ConnectedMap() *xsync.Map[types.NodeID, bool] diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 239a6504..4d35c274 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -54,7 +54,13 @@ type LockFreeBatcher struct { // AddNode registers a new node connection with the batcher and sends an initial map response. // It creates or updates the node's connection data, validates the initial map generation, // and notifies other nodes that this node has come online. -func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { +// The stop function tears down the owning session if this connection is later declared stale. +func (b *LockFreeBatcher) AddNode( + id types.NodeID, + c chan<- *tailcfg.MapResponse, + version tailcfg.CapabilityVersion, + stop func(), +) error { addNodeStart := time.Now() nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() @@ -68,6 +74,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse c: c, version: version, created: now, + stop: stop, } // Initialize last used timestamp newEntry.lastUsed.Store(now.Unix()) @@ -511,6 +518,7 @@ type connectionEntry struct { c chan<- *tailcfg.MapResponse version tailcfg.CapabilityVersion created time.Time + stop func() lastUsed atomic.Int64 // Unix timestamp of last successful send closed atomic.Bool // Indicates if this connection has been closed } @@ -556,27 +564,29 @@ func (mc *multiChannelNodeConn) close() { defer mc.mutex.Unlock() for _, conn := range mc.connections { - mc.closeConnection(conn) + mc.stopConnection(conn) } } -// closeConnection closes connection channel at most once, even if multiple cleanup -// paths race to tear the same session down. -func (mc *multiChannelNodeConn) closeConnection(conn *connectionEntry) { +// stopConnection marks a connection as closed and tears down the owning session +// at most once, even if multiple cleanup paths race to remove it. +func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) { if conn.closed.CompareAndSwap(false, true) { - close(conn.c) + if conn.stop != nil { + conn.stop() + } } } // removeConnectionAtIndexLocked removes the active connection at index. -// If closeChannel is true, it also closes that session's map-response channel. +// If stopConnection is true, it also stops that session. // Caller must hold mc.mutex. -func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, closeChannel bool) *connectionEntry { +func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry { conn := mc.connections[i] mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) - if closeChannel { - mc.closeConnection(conn) + if stopConnection { + mc.stopConnection(conn) } return conn diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 58a2158d..75fbe054 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -39,7 +39,7 @@ type testBatcherWrapper struct { state *state.State } -func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { +func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error { // Mark node as online in state before AddNode to match production behavior // This ensures the NodeStore has correct online status for change processing if t.state != nil { @@ -48,7 +48,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo } // First add the node to the real batcher - err := t.Batcher.AddNode(id, c, version) + err := t.Batcher.AddNode(id, c, version, stop) if err != nil { return err } @@ -543,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100), nil) // Wait for connection to be established assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -652,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { for i := range allNodes { node := &allNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullUpdate()) @@ -821,7 +821,7 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := &testData.Nodes[1] // Test AddNode with real node ID - _ = batcher.AddNode(tn.n.ID, tn.ch, 100) + _ = batcher.AddNode(tn.n.ID, tn.ch, 100, nil) if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") @@ -842,7 +842,7 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, 100*time.Millisecond) // Add the second node and verify update message - _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100) + _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100, nil) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. @@ -1043,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100), nil) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1149,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, 1) wg.Go(func() { - _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100), nil) }) // Add real work during connection chaos @@ -1163,7 +1163,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Go(func() { runtime.Gosched() // Yield to introduce timing variability - _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100), nil) }) // Remove second connection @@ -1254,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100), nil) batcher.AddWork(change.DERPMap()) // Consumer goroutine to validate data and detect channel issues @@ -1380,7 +1380,7 @@ func TestBatcherConcurrentClients(t *testing.T) { node := &stableNodes[i] ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100), nil) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1456,7 +1456,7 @@ func TestBatcherConcurrentClients(t *testing.T) { churningChannelsMutex.Unlock() - _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100), nil) // Consume updates to prevent blocking go func() { @@ -1774,7 +1774,7 @@ func XTestBatcherScalability(t *testing.T) { for i := range testNodes { node := &testNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) connectedNodesMutex.Lock() @@ -1891,6 +1891,7 @@ func XTestBatcherScalability(t *testing.T) { nodeID, channel, tailcfg.CapabilityVersion(100), + nil, ) connectedNodesMutex.Lock() @@ -2155,7 +2156,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time and wait for each to be connected for i := range allNodes { node := &allNodes[i] - _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Wait for node to be connected @@ -2307,7 +2308,7 @@ func TestBatcherRapidReconnection(t *testing.T) { for i := range allNodes { node := &allNodes[i] - err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add node %d: %v", i, err) } @@ -2337,7 +2338,7 @@ func TestBatcherRapidReconnection(t *testing.T) { node := &allNodes[i] newChannels[i] = make(chan *tailcfg.MapResponse, 10) - err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100), nil) if err != nil { t.Errorf("Failed to reconnect node %d: %v", i, err) } @@ -2444,13 +2445,13 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 1: Connect first node with initial connection t.Logf("Phase 1: Connecting node 1 with first connection...") - err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add node1: %v", err) } // Connect second node for comparison - err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)) + err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add node2: %v", err) } @@ -2466,7 +2467,7 @@ func TestBatcherMultiConnection(t *testing.T) { secondChannel := make(chan *tailcfg.MapResponse, 10) - err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add second connection for node1: %v", err) } @@ -2479,7 +2480,7 @@ func TestBatcherMultiConnection(t *testing.T) { thirdChannel := make(chan *tailcfg.MapResponse, 10) - err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100), nil) if err != nil { t.Fatalf("Failed to add third connection for node1: %v", err) } @@ -2718,9 +2719,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { defer node3.cleanup() // Connect all nodes to the batcher - require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))) - require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))) - require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100))) + require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil)) + require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100), nil)) + require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100), nil)) // Wait for all nodes to be connected assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -2813,7 +2814,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { nodeID := testData.Nodes[0].n.ID ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) - require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))) + require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100), nil)) assert.EventuallyWithT(t, func(c *assert.CollectT) { assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode") @@ -2844,8 +2845,8 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) ch2 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) - require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100))) - require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100))) + require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100), nil)) + require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100), nil)) assert.EventuallyWithT(t, func(c *assert.CollectT) { assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode") diff --git a/hscontrol/poll.go b/hscontrol/poll.go index ded86068..3179eb78 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand/v2" "net/http" + "sync/atomic" "time" "github.com/juanfont/headscale/hscontrol/types" @@ -14,7 +15,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util/zlog/zf" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/sasha-s/go-deadlock" "tailscale.com/tailcfg" "tailscale.com/util/zstdframe" ) @@ -33,11 +33,9 @@ type mapSession struct { ctx context.Context //nolint:containedctx capVer tailcfg.CapabilityVersion - cancelChMu deadlock.Mutex - - ch chan *tailcfg.MapResponse - cancelCh chan struct{} - cancelChOpen bool + ch chan *tailcfg.MapResponse + cancelCh chan struct{} + cancelChClosed atomic.Bool keepAlive time.Duration keepAliveTicker *time.Ticker @@ -64,9 +62,8 @@ func (h *Headscale) newMapSession( node: node, capVer: req.Version, - ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize), - cancelCh: make(chan struct{}), - cancelChOpen: true, + ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize), + cancelCh: make(chan struct{}), keepAlive: ka, keepAliveTicker: nil, @@ -92,6 +89,12 @@ func (m *mapSession) resetKeepAlive() { m.keepAliveTicker.Reset(m.keepAlive) } +func (m *mapSession) stopFromBatcher() { + if m.cancelChClosed.CompareAndSwap(false, true) { + close(m.cancelCh) + } +} + func (m *mapSession) beforeServeLongPoll() { if m.node.IsEphemeral() { m.h.ephemeralGC.Cancel(m.node.ID) @@ -146,10 +149,7 @@ func (m *mapSession) serveLongPoll() { // Clean up the session when the client disconnects defer func() { - m.cancelChMu.Lock() - m.cancelChOpen = false - close(m.cancelCh) - m.cancelChMu.Unlock() + m.stopFromBatcher() _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) @@ -224,7 +224,7 @@ func (m *mapSession) serveLongPoll() { // adding this before connecting it to the state ensure that // it does not miss any updates that might be sent in the split // time between the node connecting and the batcher being ready. - if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { //nolint:noinlineerr + if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer, m.stopFromBatcher); err != nil { //nolint:noinlineerr m.log.Error().Caller().Err(err).Msg("failed to add node to batcher") return } diff --git a/hscontrol/poll_test.go b/hscontrol/poll_test.go index 7ea924fe..7247f0c6 100644 --- a/hscontrol/poll_test.go +++ b/hscontrol/poll_test.go @@ -86,16 +86,22 @@ func (w *delayedSuccessResponseWriter) WriteCount() int { return w.writeCount } -// Reproducer outline: +// TestGitHubIssue3129_TransientlyBlockedWriteDoesNotLeaveLiveStaleSession +// tests the scenario reported in +// https://github.com/juanfont/headscale/issues/3129. +// +// Scenario: // 1. Start a real long-poll session for one node. -// 2. Make the first map write block briefly, so the session stops draining m.ch. +// 2. Block the first map write long enough for the session to stop draining +// its buffered map-response channel. // 3. While that write is blocked, queue enough updates to fill the buffered -// session channel and make the next batcher send hit the stale-send timeout. -// 4. Let the blocked write recover. The stale session should still flush the -// update that was already buffered before its channel was pruned. -// 5. After that buffered update is drained, the stale session must exit instead -// of lingering as an orphaned serveLongPoll goroutine. -func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { +// channel and make the next batcher send hit the stale-send timeout. +// 4. That stale-send path removes the session from the batcher, so without an +// explicit teardown hook the old serveLongPoll goroutine would stay alive +// but stop receiving future updates. +// 5. Release the blocked write and verify the batcher-side stop signal makes +// that stale session exit instead of lingering as an orphaned goroutine. +func TestGitHubIssue3129_TransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { t.Parallel() app := createTestApp(t) @@ -141,7 +147,7 @@ func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { t.Cleanup(func() { dummyCh := make(chan *tailcfg.MapResponse, 1) - _ = app.mapBatcher.AddNode(node.ID, dummyCh, tailcfg.CapabilityVersion(100)) + _ = app.mapBatcher.AddNode(node.ID, dummyCh, tailcfg.CapabilityVersion(100), nil) cancel() select { case <-serveDone: @@ -163,8 +169,8 @@ func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { }() // One update fills the buffered session channel while the first write is blocked. - // The second update then hits the 50ms stale-send timeout and the batcher prunes - // and closes that stale channel. + // The second update then hits the 50ms stale-send timeout, so the batcher prunes + // the stale connection and triggers its stop hook. app.mapBatcher.AddWork(change.SelfUpdate(node.ID), change.SelfUpdate(node.ID)) select { @@ -173,10 +179,6 @@ func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) { t.Fatal("expected the blocked write to eventually complete") } - assert.Eventually(t, func() bool { - return writer.WriteCount() >= 2 - }, 2*time.Second, 20*time.Millisecond, "session should flush the update that was already buffered before the stale send") - assert.Eventually(t, func() bool { select { case <-streamsClosed: