poll: stop stale map sessions through an explicit teardown hook

When stale-send cleanup prunes a connection from the batcher, the old serveLongPoll session needs an explicit stop signal. Pass a stop hook into AddNode and trigger it when that connection is removed, so the session exits through its normal cancel path rather than relying solely on channel closure from the batcher side (the map-response channel is still closed as before; the hook fires alongside it, at most once per connection).
This commit is contained in:
DM
2026-03-08 05:50:23 +03:00
committed by Kristoffer Dalby
parent 3daf45e88a
commit 4aca9d6568
5 changed files with 79 additions and 66 deletions

View File

@@ -36,7 +36,7 @@ type batcherFunc func(cfg *types.Config, state *state.State) Batcher
type Batcher interface {
Start()
Close()
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]

View File

@@ -54,7 +54,13 @@ type LockFreeBatcher struct {
// AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
// The stop function tears down the owning session if this connection is later declared stale.
func (b *LockFreeBatcher) AddNode(
id types.NodeID,
c chan<- *tailcfg.MapResponse,
version tailcfg.CapabilityVersion,
stop func(),
) error {
addNodeStart := time.Now()
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
@@ -68,6 +74,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
c: c,
version: version,
created: now,
stop: stop,
}
// Initialize last used timestamp
newEntry.lastUsed.Store(now.Unix())
@@ -511,6 +518,7 @@ type connectionEntry struct {
c chan<- *tailcfg.MapResponse
version tailcfg.CapabilityVersion
created time.Time
stop func()
lastUsed atomic.Int64 // Unix timestamp of last successful send
closed atomic.Bool // Indicates if this connection has been closed
}
@@ -556,27 +564,29 @@ func (mc *multiChannelNodeConn) close() {
defer mc.mutex.Unlock()
for _, conn := range mc.connections {
mc.closeConnection(conn)
mc.stopConnection(conn)
}
}
// closeConnection closes connection channel at most once, even if multiple cleanup
// paths race to tear the same session down.
func (mc *multiChannelNodeConn) closeConnection(conn *connectionEntry) {
// stopConnection marks a connection as closed and tears down the owning session
// at most once, even if multiple cleanup paths race to remove it.
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
if conn.closed.CompareAndSwap(false, true) {
close(conn.c)
if conn.stop != nil {
conn.stop()
}
}
}
// removeConnectionAtIndexLocked removes the active connection at index.
// If closeChannel is true, it also closes that session's map-response channel.
// If stopConnection is true, it also stops that session.
// Caller must hold mc.mutex.
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, closeChannel bool) *connectionEntry {
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
conn := mc.connections[i]
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
if closeChannel {
mc.closeConnection(conn)
if stopConnection {
mc.stopConnection(conn)
}
return conn

View File

@@ -39,7 +39,7 @@ type testBatcherWrapper struct {
state *state.State
}
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error {
// Mark node as online in state before AddNode to match production behavior
// This ensures the NodeStore has correct online status for change processing
if t.state != nil {
@@ -48,7 +48,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo
}
// First add the node to the real batcher
err := t.Batcher.AddNode(id, c, version)
err := t.Batcher.AddNode(id, c, version, stop)
if err != nil {
return err
}
@@ -543,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
testNode.start()
// Connect the node to the batcher
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100), nil)
// Wait for connection to be established
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -652,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
for i := range allNodes {
node := &allNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullUpdate())
@@ -821,7 +821,7 @@ func TestBatcherBasicOperations(t *testing.T) {
tn2 := &testData.Nodes[1]
// Test AddNode with real node ID
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
_ = batcher.AddNode(tn.n.ID, tn.ch, 100, nil)
if !batcher.IsConnected(tn.n.ID) {
t.Error("Node should be connected after AddNode")
@@ -842,7 +842,7 @@ func TestBatcherBasicOperations(t *testing.T) {
drainChannelTimeout(tn.ch, 100*time.Millisecond)
// Add the second node and verify update message
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100, nil)
assert.True(t, batcher.IsConnected(tn2.n.ID))
// First node should get an update that second node has connected.
@@ -1043,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
testNodes := testData.Nodes
ch := make(chan *tailcfg.MapResponse, 10)
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100), nil)
// Track update content for validation
var receivedUpdates []*tailcfg.MapResponse
@@ -1149,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
ch1 := make(chan *tailcfg.MapResponse, 1)
wg.Go(func() {
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100), nil)
})
// Add real work during connection chaos
@@ -1163,7 +1163,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
wg.Go(func() {
runtime.Gosched() // Yield to introduce timing variability
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100), nil)
})
// Remove second connection
@@ -1254,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, 5)
// Add node and immediately queue real work
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100), nil)
batcher.AddWork(change.DERPMap())
// Consumer goroutine to validate data and detect channel issues
@@ -1380,7 +1380,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
node := &stableNodes[i]
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
stableChannels[node.n.ID] = ch
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100), nil)
// Monitor updates for each stable client
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@@ -1456,7 +1456,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
churningChannelsMutex.Unlock()
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100), nil)
// Consume updates to prevent blocking
go func() {
@@ -1774,7 +1774,7 @@ func XTestBatcherScalability(t *testing.T) {
for i := range testNodes {
node := &testNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
connectedNodesMutex.Lock()
@@ -1891,6 +1891,7 @@ func XTestBatcherScalability(t *testing.T) {
nodeID,
channel,
tailcfg.CapabilityVersion(100),
nil,
)
connectedNodesMutex.Lock()
@@ -2155,7 +2156,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Connect nodes one at a time and wait for each to be connected
for i := range allNodes {
node := &allNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Wait for node to be connected
@@ -2307,7 +2308,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
for i := range allNodes {
node := &allNodes[i]
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil {
t.Fatalf("Failed to add node %d: %v", i, err)
}
@@ -2337,7 +2338,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
node := &allNodes[i]
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100), nil)
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
}
@@ -2444,13 +2445,13 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 1: Connect first node with initial connection
t.Logf("Phase 1: Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil {
t.Fatalf("Failed to add node1: %v", err)
}
// Connect second node for comparison
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))
err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil {
t.Fatalf("Failed to add node2: %v", err)
}
@@ -2466,7 +2467,7 @@ func TestBatcherMultiConnection(t *testing.T) {
secondChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100), nil)
if err != nil {
t.Fatalf("Failed to add second connection for node1: %v", err)
}
@@ -2479,7 +2480,7 @@ func TestBatcherMultiConnection(t *testing.T) {
thirdChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100), nil)
if err != nil {
t.Fatalf("Failed to add third connection for node1: %v", err)
}
@@ -2718,9 +2719,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) {
defer node3.cleanup()
// Connect all nodes to the batcher
require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)))
require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)))
require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100)))
require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil))
require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100), nil))
require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100), nil))
// Wait for all nodes to be connected
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -2813,7 +2814,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) {
nodeID := testData.Nodes[0].n.ID
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)))
require.NoError(t, lfb.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100), nil))
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode")
@@ -2844,8 +2845,8 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) {
ch1 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
ch2 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100)))
require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100)))
require.NoError(t, lfb.AddNode(nodeID, ch1, tailcfg.CapabilityVersion(100), nil))
require.NoError(t, lfb.AddNode(nodeID, ch2, tailcfg.CapabilityVersion(100), nil))
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.True(c, lfb.IsConnected(nodeID), "node should be connected after AddNode")