diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index c4e26810..57fe2ed5 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -171,8 +171,12 @@ type workResult struct { } // work represents a unit of work to be processed by workers. +// All pending changes for a node are bundled into a single work item +// so that one worker processes them sequentially. This prevents +// out-of-order MapResponse delivery and races on lastSentPeers +// that occur when multiple workers process changes for the same node. type work struct { - c change.Change + changes []change.Change nodeID types.NodeID resultCh chan<- workResult // optional channel for synchronous operations } @@ -417,29 +421,33 @@ func (b *Batcher) worker(workerID int) { b.workProcessed.Add(1) - // If the resultCh is set, it means that this is a work request - // where there is a blocking function waiting for the map that - // is being generated. - // This is used for synchronous map generation. + // Synchronous path: a caller is blocking on resultCh + // waiting for a generated MapResponse (used by AddNode + // for the initial map). Always contains a single change. if w.resultCh != nil { var result workResult if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { + // Hold workMu so concurrent async work for this + // node waits until the initial map is sent. + nc.workMu.Lock() + var err error - result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + result.mapResponse, err = generateMapResponse(nc, b.mapper, w.changes[0]) result.err = err if result.err != nil { b.workErrors.Add(1) wlog.Error().Err(result.err). Uint64(zf.NodeID, w.nodeID.Uint64()). - Str(zf.Reason, w.c.Reason). + Str(zf.Reason, w.changes[0].Reason). 
Msg("failed to generate map response for synchronous work") } else if result.mapResponse != nil { - // Update peer tracking for synchronous responses too nc.updateSentPeers(result.mapResponse) } + + nc.workMu.Unlock() } else { result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID) @@ -449,7 +457,6 @@ func (b *Batcher) worker(workerID int) { Msg("node not found for synchronous work") } - // Send result select { case w.resultCh <- result: case <-b.done: @@ -459,20 +466,24 @@ func (b *Batcher) worker(workerID int) { continue } - // If resultCh is nil, this is an asynchronous work request - // that should be processed and sent to the node instead of - // returned to the caller. + // Async path: process all bundled changes sequentially. + // workMu ensures that if another worker picks up the next + // tick's bundle for the same node, it waits until we + // finish — preventing out-of-order delivery and races + // on lastSentPeers (Clear+Store vs Range). if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { - // Apply change to node - this will handle offline nodes gracefully - // and queue work for when they reconnect - err := nc.change(w.c) - if err != nil { - b.workErrors.Add(1) - wlog.Error().Err(err). - Uint64(zf.NodeID, w.nodeID.Uint64()). - Str(zf.Reason, w.c.Reason). - Msg("failed to apply change") + nc.workMu.Lock() + for _, ch := range w.changes { + err := nc.change(ch) + if err != nil { + b.workErrors.Add(1) + wlog.Error().Err(err). + Uint64(zf.NodeID, w.nodeID.Uint64()). + Str(zf.Reason, ch.Reason). + Msg("failed to apply change") + } } + nc.workMu.Unlock() } case <-b.done: wlog.Debug().Msg("batcher shutting down, exiting worker") @@ -581,10 +592,10 @@ func (b *Batcher) processBatchedChanges() { return true } - // Send all batched changes for this node - for _, ch := range pending { - b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) - } + // Queue a single work item containing all pending changes. 
+ // One item per node ensures a single worker processes them + // sequentially, preventing out-of-order delivery. + b.queueWork(work{changes: pending, nodeID: nodeID, resultCh: nil}) return true }) @@ -721,7 +732,7 @@ func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tai resultCh := make(chan workResult, 1) // Queue the work with a result channel using the safe queueing method - b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh}) + b.queueWork(work{changes: []change.Change{ch}, nodeID: id, resultCh: resultCh}) // Wait for the result select { diff --git a/hscontrol/mapper/batcher_concurrency_test.go b/hscontrol/mapper/batcher_concurrency_test.go index d6cb0f83..61da0d38 100644 --- a/hscontrol/mapper/batcher_concurrency_test.go +++ b/hscontrol/mapper/batcher_concurrency_test.go @@ -390,20 +390,26 @@ func TestProcessBatchedChanges_ConcurrentAdd_NoDataLoss(t *testing.T) { // One final process to flush any remaining lb.b.processBatchedChanges() - // Count how many work items were actually queued - queuedWork := len(lb.b.workCh) + // Count total changes across all bundled work items in the channel. + // Each work item may contain multiple changes since processBatchedChanges + // bundles all pending changes per node into a single work item. + queuedChanges := 0 + + workItems := len(lb.b.workCh) + for range workItems { + w := <-lb.b.workCh + queuedChanges += len(w.changes) + } // Also count any still-pending remaining := len(getPendingForNode(lb.b, types.NodeID(1))) - total := queuedWork + remaining + total := queuedChanges + remaining added := int(addedCount.Load()) - t.Logf("added=%d, queued_work=%d, still_pending=%d, total_accounted=%d, lost=%d", - added, queuedWork, remaining, total, added-total) + t.Logf("added=%d, queued_changes=%d (in %d work items), still_pending=%d, total_accounted=%d, lost=%d", + added, queuedChanges, workItems, remaining, total, added-total) // Every added change must either be in the work queue or still pending. 
- // The Range→Delete race in processBatchedChanges causes inconsistency: - // changes can be lost (total < added) or duplicated (total > added). assert.Equal(t, added, total, "processBatchedChanges has %d inconsistent changes (%d added vs %d accounted) "+ "under concurrent access", @@ -422,6 +428,114 @@ func TestProcessBatchedChanges_EmptyPending(t *testing.T) { "no work should be queued when there are no pending changes") } +// TestProcessBatchedChanges_BundlesChangesPerNode verifies that multiple +// pending changes for the same node are bundled into a single work item. +// This prevents out-of-order delivery when different workers pick up +// separate changes for the same node. +func TestProcessBatchedChanges_BundlesChangesPerNode(t *testing.T) { + lb := setupLightweightBatcher(t, 3, 10) + defer lb.cleanup() + + // Add multiple pending changes for node 1 + if nc, ok := lb.b.nodes.Load(types.NodeID(1)); ok { + nc.appendPending(change.DERPMap()) + nc.appendPending(change.DNSConfig()) + nc.appendPending(change.PolicyOnly()) + } + // Single change for node 2 + if nc, ok := lb.b.nodes.Load(types.NodeID(2)); ok { + nc.appendPending(change.DERPMap()) + } + + lb.b.processBatchedChanges() + + // Should produce exactly 2 work items: one per node with pending changes. + // Node 3 had no pending changes, so no work item for it. 
+ assert.Len(t, lb.b.workCh, 2, + "should produce one work item per node, not per change") + + // Drain and verify the bundled changes are intact + totalChanges := 0 + + for range 2 { + w := <-lb.b.workCh + + totalChanges += len(w.changes) + if w.nodeID == types.NodeID(1) { + assert.Len(t, w.changes, 3, + "node 1's work item should contain all 3 changes") + } else { + assert.Len(t, w.changes, 1, + "node 2's work item should contain 1 change") + } + } + + assert.Equal(t, 4, totalChanges, "total changes across all work items") +} + +// TestWorkMu_PreventsInterTickRace verifies that workMu serializes change +// processing across consecutive batch ticks. Without workMu, two workers +// could process bundles from tick N and tick N+1 concurrently for the same +// node, causing out-of-order delivery and races on lastSentPeers. +func TestWorkMu_PreventsInterTickRace(t *testing.T) { + zerolog.SetGlobalLevel(zerolog.Disabled) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + mc := newMultiChannelNodeConn(1, nil) + ch := make(chan *tailcfg.MapResponse, 100) + entry := &connectionEntry{ + id: "test", + c: ch, + version: tailcfg.CapabilityVersion(100), + created: time.Now(), + } + entry.lastUsed.Store(time.Now().Unix()) + mc.addConnection(entry) + + // Track the order in which work completes + var ( + order []int + mu sync.Mutex + ) + + record := func(id int) { + mu.Lock() + + order = append(order, id) + mu.Unlock() + } + + var wg sync.WaitGroup + + // Simulate two workers grabbing consecutive tick bundles. + // Worker 1 holds workMu and sleeps, worker 2 must wait. 
+ wg.Go(func() { + mc.workMu.Lock() + // Simulate processing time for tick N's bundle + time.Sleep(50 * time.Millisecond) //nolint:forbidigo + record(1) + mc.workMu.Unlock() + }) + + // Small delay so worker 1 grabs the lock first + time.Sleep(5 * time.Millisecond) //nolint:forbidigo + + wg.Go(func() { + mc.workMu.Lock() + record(2) + mc.workMu.Unlock() + }) + + wg.Wait() + + mu.Lock() + defer mu.Unlock() + + require.Len(t, order, 2) + assert.Equal(t, 1, order[0], "worker 1 (tick N) should complete first") + assert.Equal(t, 2, order[1], "worker 2 (tick N+1) should complete second") +} + // ============================================================================ // cleanupOfflineNodes Tests // ============================================================================ @@ -562,8 +676,8 @@ func TestBatcher_QueueWorkDuringShutdown(t *testing.T) { go func() { lb.b.queueWork(work{ - c: change.DERPMap(), - nodeID: types.NodeID(1), + changes: []change.Change{change.DERPMap()}, + nodeID: types.NodeID(1), }) close(done) }() @@ -790,8 +904,8 @@ func TestBug5_WorkerPanicKillsWorkerPermanently(t *testing.T) { - // Without the nil guard, this would panic: nc.change(w.c) on nil nc. + // Without the nil guard, this would panic: nc.change(ch) on nil nc. 
for range 10 { lb.b.queueWork(work{ - c: change.DERPMap(), - nodeID: nilNodeID, + changes: []change.Change{change.DERPMap()}, + nodeID: nilNodeID, }) } @@ -801,7 +915,7 @@ func TestBug5_WorkerPanicKillsWorkerPermanently(t *testing.T) { for range 5 { resultCh := make(chan workResult, 1) lb.b.queueWork(work{ - c: change.DERPMap(), + changes: []change.Change{change.DERPMap()}, nodeID: nilNodeID, resultCh: resultCh, }) @@ -823,8 +937,8 @@ func TestBug5_WorkerPanicKillsWorkerPermanently(t *testing.T) { beforeValid := lb.b.workProcessed.Load() for range 5 { lb.b.queueWork(work{ - c: change.DERPMap(), - nodeID: types.NodeID(1), + changes: []change.Change{change.DERPMap()}, + nodeID: types.NodeID(1), }) } diff --git a/hscontrol/mapper/node_conn.go b/hscontrol/mapper/node_conn.go index f67ef727..1ee9d3e6 100644 --- a/hscontrol/mapper/node_conn.go +++ b/hscontrol/mapper/node_conn.go @@ -42,6 +42,13 @@ type multiChannelNodeConn struct { pendingMu sync.Mutex pending []change.Change + // workMu serializes change processing for this node across batch ticks. + // Without this, two workers could process consecutive ticks' bundles + // concurrently, causing out-of-order MapResponse delivery and races + // on lastSentPeers (Clear+Store in updateSentPeers vs Range in + // computePeerDiff). + workMu sync.Mutex + closeOnce sync.Once updateCount atomic.Int64