diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go
index 88647c60..239a6504 100644
--- a/hscontrol/mapper/batcher_lockfree.go
+++ b/hscontrol/mapper/batcher_lockfree.go
@@ -556,13 +556,32 @@ func (mc *multiChannelNodeConn) close() {
 	defer mc.mutex.Unlock()
 
 	for _, conn := range mc.connections {
-		// Mark as closed before closing the channel to prevent
-		// send on closed channel panics from concurrent workers
-		conn.closed.Store(true)
+		mc.closeConnection(conn)
+	}
+}
+
+// closeConnection closes the connection's channel at most once, even if
+// multiple cleanup paths race to tear the same session down.
+func (mc *multiChannelNodeConn) closeConnection(conn *connectionEntry) {
+	if conn.closed.CompareAndSwap(false, true) {
 		close(conn.c)
 	}
 }
 
+// removeConnectionAtIndexLocked removes the active connection at index i.
+// If closeChannel is true, it also closes that session's map-response channel.
+// Caller must hold mc.mutex.
+func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, closeChannel bool) *connectionEntry {
+	conn := mc.connections[i]
+	mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
+
+	if closeChannel {
+		mc.closeConnection(conn)
+	}
+
+	return conn
+}
+
 // addConnection adds a new connection.
 func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
 	mutexWaitStart := time.Now()
@@ -590,8 +609,7 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
 
 	for i, entry := range mc.connections {
 		if entry.c == c {
-			// Remove this connection
-			mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
+			mc.removeConnectionAtIndexLocked(i, false)
 			mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
 				Int("remaining_connections", len(mc.connections)).
 				Msg("successfully removed connection")
@@ -673,10 +691,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
 	// Remove failed connections (in reverse order to maintain indices)
 	for i := len(failedConnections) - 1; i >= 0; i-- {
 		idx := failedConnections[i]
+		entry := mc.removeConnectionAtIndexLocked(idx, true)
 		mc.log.Debug().Caller().
-			Str(zf.ConnID, mc.connections[idx].id).
-			Msg("send: removing failed connection")
-		mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
+			Str(zf.ConnID, entry.id).
+			Msg("send: removed failed connection")
 	}
 
 	mc.updateCount.Add(1)
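
Reviewer aside (not part of the patch): a minimal, self-contained sketch of the close-once idiom this hunk switches to, using toy types rather than the headscale ones. The old Store(true)-then-close pair was not atomic as a unit, so two cleanup paths could both reach close(conn.c); CompareAndSwap(false, true) succeeds for exactly one caller, so the loser becomes a no-op instead of panicking with "close of closed channel".

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// entry stands in for connectionEntry: a channel plus a closed flag.
type entry struct {
	closed atomic.Bool
	c      chan int
}

// closeOnce mirrors closeConnection: only the caller whose CAS flips
// closed from false to true actually closes the channel.
func (e *entry) closeOnce() {
	if e.closed.CompareAndSwap(false, true) {
		close(e.c)
	}
}

func main() {
	e := &entry{c: make(chan int)}

	// Two cleanup paths racing, as close() and send()-failure pruning can.
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			e.closeOnce() // safe either way: at most one real close happens
		}()
	}
	wg.Wait()

	fmt.Println("channel closed exactly once")
}
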
+ Msg("send: removed failed connection") } mc.updateCount.Add(1) diff --git a/hscontrol/poll_test.go b/hscontrol/poll_test.go new file mode 100644 index 00000000..7ea924fe --- /dev/null +++ b/hscontrol/poll_test.go @@ -0,0 +1,188 @@ +package hscontrol + +import ( + "context" + "net/http" + "sync" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +type delayedSuccessResponseWriter struct { + header http.Header + + firstWriteDelay time.Duration + + firstWriteStarted chan struct{} + firstWriteStartedOnce sync.Once + + firstWriteFinished chan struct{} + firstWriteFinishedOnce sync.Once + + mu sync.Mutex + writeCount int +} + +func newDelayedSuccessResponseWriter(firstWriteDelay time.Duration) *delayedSuccessResponseWriter { + return &delayedSuccessResponseWriter{ + header: make(http.Header), + firstWriteDelay: firstWriteDelay, + firstWriteStarted: make(chan struct{}), + firstWriteFinished: make(chan struct{}), + } +} + +func (w *delayedSuccessResponseWriter) Header() http.Header { + return w.header +} + +func (w *delayedSuccessResponseWriter) WriteHeader(int) {} + +func (w *delayedSuccessResponseWriter) Write(data []byte) (int, error) { + w.mu.Lock() + w.writeCount++ + writeCount := w.writeCount + w.mu.Unlock() + + if writeCount == 1 { + // Only the first write is delayed. This simulates a transiently wedged map response: + // long enough to make the batcher time out future sends, + // but short enough that the old session can still recover if we leave it alive + w.firstWriteStartedOnce.Do(func() { + close(w.firstWriteStarted) + }) + + time.Sleep(w.firstWriteDelay) + + w.firstWriteFinishedOnce.Do(func() { + close(w.firstWriteFinished) + }) + } + + return len(data), nil +} + +func (w *delayedSuccessResponseWriter) Flush() {} + +func (w *delayedSuccessResponseWriter) FirstWriteStarted() <-chan struct{} { + return w.firstWriteStarted +} + +func (w *delayedSuccessResponseWriter) FirstWriteFinished() <-chan struct{} { + return w.firstWriteFinished +} + +func (w *delayedSuccessResponseWriter) WriteCount() int { + w.mu.Lock() + defer w.mu.Unlock() + + return w.writeCount +} + +// Reproducer outline: +// 1. Start a real long-poll session for one node. +// 2. Make the first map write block briefly, so the session stops draining m.ch. +// 3. While that write is blocked, queue enough updates to fill the buffered +// session channel and make the next batcher send hit the stale-send timeout. +// 4. Let the blocked write recover. The stale session should still flush the +// update that was already buffered before its channel was pruned. +// 5. After that buffered update is drained, the stale session must exit instead +// of lingering as an orphaned serveLongPoll goroutine. 
+func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) {
+	t.Parallel()
+
+	app := createTestApp(t)
+	user := app.state.CreateUserForTest("poll-stale-session-user")
+	createdNode := app.state.CreateRegisteredNodeForTest(user, "poll-stale-session-node")
+	require.NoError(t, app.state.UpdatePolicyManagerUsersForTest())
+
+	app.cfg.Tuning.BatchChangeDelay = 20 * time.Millisecond
+	app.cfg.Tuning.NodeMapSessionBufferedChanSize = 1
+
+	app.mapBatcher.Close()
+	require.NoError(t, app.state.Close())
+
+	reloadedState, err := state.NewState(app.cfg)
+	require.NoError(t, err)
+	app.state = reloadedState
+
+	app.mapBatcher = mapper.NewBatcherAndMapper(app.cfg, app.state)
+	app.mapBatcher.Start()
+
+	t.Cleanup(func() {
+		app.mapBatcher.Close()
+		require.NoError(t, app.state.Close())
+	})
+
+	nodeView, ok := app.state.GetNodeByID(createdNode.ID)
+	require.True(t, ok, "expected node to be present in NodeStore after reload")
+	require.True(t, nodeView.Valid(), "expected valid node view after reload")
+	node := nodeView.AsStruct()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	writer := newDelayedSuccessResponseWriter(250 * time.Millisecond)
+	session := app.newMapSession(ctx, tailcfg.MapRequest{
+		Stream:  true,
+		Version: tailcfg.CapabilityVersion(100),
+	}, writer, node)
+
+	serveDone := make(chan struct{})
+	go func() {
+		session.serveLongPoll()
+		close(serveDone)
+	}()
+
+	t.Cleanup(func() {
+		dummyCh := make(chan *tailcfg.MapResponse, 1)
+		_ = app.mapBatcher.AddNode(node.ID, dummyCh, tailcfg.CapabilityVersion(100))
+		cancel()
+		select {
+		case <-serveDone:
+		case <-time.After(2 * time.Second):
+		}
+		_ = app.mapBatcher.RemoveNode(node.ID, dummyCh)
+	})
+
+	select {
+	case <-writer.FirstWriteStarted():
+	case <-time.After(2 * time.Second):
+		t.Fatal("expected initial map write to start")
+	}
+
+	streamsClosed := make(chan struct{})
+	go func() {
+		app.clientStreamsOpen.Wait()
+		close(streamsClosed)
+	}()
+
+	// One update fills the buffered session channel while the first write is blocked.
+	// The second update then hits the 50ms stale-send timeout and the batcher prunes
+	// and closes that stale channel.
+	app.mapBatcher.AddWork(change.SelfUpdate(node.ID), change.SelfUpdate(node.ID))
+
+	select {
+	case <-writer.FirstWriteFinished():
+	case <-time.After(2 * time.Second):
+		t.Fatal("expected the blocked write to eventually complete")
+	}
+
+	assert.Eventually(t, func() bool {
+		return writer.WriteCount() >= 2
+	}, 2*time.Second, 20*time.Millisecond, "session should flush the update that was already buffered before the stale send")
+
+	assert.Eventually(t, func() bool {
+		select {
+		case <-streamsClosed:
+			return true
+		default:
+			return false
+		}
+	}, time.Second, 20*time.Millisecond, "after stale-send cleanup, the stale session should exit")
+}
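
Reviewer aside (not part of the patch): a minimal, self-contained sketch of the timing the reproducer outline describes, with simplified names and hard-coded numbers; the size-1 buffer and 50ms stale-send timeout mirror the test's tuning, not the batcher's real API. A reader wedged in its first write stops draining its buffered channel; the sender fills the buffer, hits the timeout on the next send, and prunes the session by closing the channel; the reader must still flush the buffered update and then observe the close and exit rather than leak.

package main

import (
	"fmt"
	"time"
)

func main() {
	// Stand-in for the session's map-response channel, buffer size 1
	// (NodeMapSessionBufferedChanSize = 1 in the test).
	ch := make(chan string, 1)

	done := make(chan struct{})
	go func() { // the long-poll session
		defer close(done)
		for update := range ch { // ends once the channel is closed and drained
			fmt.Println("wrote:", update)
			time.Sleep(250 * time.Millisecond) // each write is wedged, like the test's first write
		}
		fmt.Println("session observed the close and exited")
	}()

	// sendOrPrune stands in for the batcher's send path: give up on a
	// session whose channel cannot accept an update within the timeout.
	sendOrPrune := func(update string) {
		select {
		case ch <- update:
		case <-time.After(50 * time.Millisecond): // stale-send timeout
			close(ch) // prune: the batcher abandons this session's channel
			fmt.Println("pruned stale session channel")
		}
	}

	sendOrPrune("update-1")           // the session picks this up and wedges in Write
	time.Sleep(10 * time.Millisecond) // let the session take update-1 off the channel
	sendOrPrune("update-2")           // fills the size-1 buffer while the write is wedged
	sendOrPrune("update-3")           // times out after 50ms, so the channel is closed

	// The session still flushes update-2 (buffered before the prune); the
	// range loop then ends and the goroutine exits instead of lingering.
	<-done
}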