mapper/batcher: restructure internals for correctness

Move per-node pending changes from a shared xsync.Map on the batcher
into multiChannelNodeConn, protected by a dedicated mutex. The new
appendPending/drainPending methods provide atomic append and drain
operations, eliminating data races in addToBatch and
processBatchedChanges.

Add sync.Once to multiChannelNodeConn.close() to make it idempotent,
preventing panics from concurrent close calls on the same channel.

Add started atomic.Bool to guard Start() against being called
multiple times, preventing orphaned goroutines.

Add comprehensive concurrency tests validating these changes.
This commit is contained in:
Kristoffer Dalby
2026-03-13 13:31:39 +00:00
parent 21e02e5d1f
commit 57070680a5
7 changed files with 1836 additions and 69 deletions

View File

@@ -52,10 +52,9 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
tick: time.NewTicker(batchTime), tick: time.NewTicker(batchTime),
// The size of this channel is arbitrary chosen, the sizing should be revisited. // The size of this channel is arbitrary chosen, the sizing should be revisited.
workCh: make(chan work, workers*200), workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](), connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
} }
} }

View File

@@ -150,13 +150,12 @@ func BenchmarkUpdateSentPeers(b *testing.B) {
// helper, it doesn't register cleanup and suppresses logging. // helper, it doesn't register cleanup and suppresses logging.
func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) { func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) {
b := &LockFreeBatcher{ b := &LockFreeBatcher{
tick: time.NewTicker(1 * time.Hour), // never fires during bench tick: time.NewTicker(1 * time.Hour), // never fires during bench
workers: 4, workers: 4,
workCh: make(chan work, 4*200), workCh: make(chan work, 4*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](), connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), done: make(chan struct{}),
done: make(chan struct{}),
} }
channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount) channels := make(map[types.NodeID]chan *tailcfg.MapResponse, nodeCount)
@@ -204,8 +203,8 @@ func BenchmarkAddToBatch_Broadcast(b *testing.B) {
for range b.N { for range b.N {
batcher.addToBatch(ch) batcher.addToBatch(ch)
// Clear pending to avoid unbounded growth // Clear pending to avoid unbounded growth
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool { batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
batcher.pendingChanges.Delete(id) nc.drainPending()
return true return true
}) })
} }
@@ -242,8 +241,8 @@ func BenchmarkAddToBatch_Targeted(b *testing.B) {
batcher.addToBatch(ch) batcher.addToBatch(ch)
// Clear pending periodically to avoid growth // Clear pending periodically to avoid growth
if i%100 == 99 { if i%100 == 99 {
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool { batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
batcher.pendingChanges.Delete(id) nc.drainPending()
return true return true
}) })
} }
@@ -298,7 +297,9 @@ func BenchmarkProcessBatchedChanges(b *testing.B) {
b.StopTimer() b.StopTimer()
// Seed pending changes // Seed pending changes
for i := 1; i <= nodeCount; i++ { for i := 1; i <= nodeCount; i++ {
batcher.pendingChanges.Store(types.NodeID(i), []change.Change{change.DERPMap()}) //nolint:gosec // benchmark if nc, ok := batcher.nodes.Load(types.NodeID(i)); ok { //nolint:gosec // benchmark
nc.appendPending(change.DERPMap())
}
} }
b.StartTimer() b.StartTimer()
@@ -411,8 +412,8 @@ func BenchmarkConcurrentAddToBatch(b *testing.B) {
case <-batcher.done: case <-batcher.done:
return return
default: default:
batcher.pendingChanges.Range(func(id types.NodeID, _ []change.Change) bool { batcher.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
batcher.pendingChanges.Delete(id) nc.drainPending()
return true return true
}) })
time.Sleep(time.Millisecond) //nolint:forbidigo // benchmark drain loop time.Sleep(time.Millisecond) //nolint:forbidigo // benchmark drain loop
@@ -646,7 +647,7 @@ func BenchmarkAddNode(b *testing.B) {
// Connect all nodes (measuring AddNode cost) // Connect all nodes (measuring AddNode cost)
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
} }
b.StopTimer() b.StopTimer()
@@ -707,7 +708,7 @@ func BenchmarkFullPipeline(b *testing.B) {
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil { if err != nil {
b.Fatalf("failed to add node %d: %v", i, err) b.Fatalf("failed to add node %d: %v", i, err)
} }
@@ -762,7 +763,7 @@ func BenchmarkMapResponseFromChange(b *testing.B) {
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil { if err != nil {
b.Fatalf("failed to add node %d: %v", i, err) b.Fatalf("failed to add node %d: %v", i, err)
} }

File diff suppressed because it is too large Load Diff

View File

@@ -41,8 +41,7 @@ type LockFreeBatcher struct {
done chan struct{} done chan struct{}
doneOnce sync.Once // Ensures done is only closed once doneOnce sync.Once // Ensures done is only closed once
// Batching state started atomic.Bool // Ensures Start() is only called once
pendingChanges *xsync.Map[types.NodeID, []change.Change]
// Metrics // Metrics
totalNodes atomic.Int64 totalNodes atomic.Int64
@@ -167,7 +166,12 @@ func (b *LockFreeBatcher) AddWork(r ...change.Change) {
} }
func (b *LockFreeBatcher) Start() { func (b *LockFreeBatcher) Start() {
if !b.started.CompareAndSwap(false, true) {
return
}
b.done = make(chan struct{}) b.done = make(chan struct{})
go b.doWork() go b.doWork()
} }
@@ -336,15 +340,16 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
} }
b.connected.Delete(removedID) b.connected.Delete(removedID)
b.pendingChanges.Delete(removedID)
} }
} }
// Short circuit if any of the changes is a full update, which // Short circuit if any of the changes is a full update, which
// means we can skip sending individual changes. // means we can skip sending individual changes.
if change.HasFull(changes) { if change.HasFull(changes) {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()}) nc.pendingMu.Lock()
nc.pending = []change.Change{change.FullUpdate()}
nc.pendingMu.Unlock()
return true return true
}) })
@@ -356,20 +361,18 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
// Handle targeted changes - send only to the specific node // Handle targeted changes - send only to the specific node
for _, ch := range targeted { for _, ch := range targeted {
pending, _ := b.pendingChanges.LoadOrStore(ch.TargetNode, []change.Change{}) if nc, ok := b.nodes.Load(ch.TargetNode); ok {
pending = append(pending, ch) nc.appendPending(ch)
b.pendingChanges.Store(ch.TargetNode, pending) }
} }
// Handle broadcast changes - send to all nodes, filtering as needed // Handle broadcast changes - send to all nodes, filtering as needed
if len(broadcast) > 0 { if len(broadcast) > 0 {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
filtered := change.FilterForNode(nodeID, broadcast) filtered := change.FilterForNode(nodeID, broadcast)
if len(filtered) > 0 { if len(filtered) > 0 {
pending, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{}) nc.appendPending(filtered...)
pending = append(pending, filtered...)
b.pendingChanges.Store(nodeID, pending)
} }
return true return true
@@ -379,12 +382,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
// processBatchedChanges processes all pending batched changes. // processBatchedChanges processes all pending batched changes.
func (b *LockFreeBatcher) processBatchedChanges() { func (b *LockFreeBatcher) processBatchedChanges() {
if b.pendingChanges == nil { b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
return pending := nc.drainPending()
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, pending []change.Change) bool {
if len(pending) == 0 { if len(pending) == 0 {
return true return true
} }
@@ -394,9 +393,6 @@ func (b *LockFreeBatcher) processBatchedChanges() {
b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
} }
// Clear the pending changes for this node
b.pendingChanges.Delete(nodeID)
return true return true
}) })
} }
@@ -532,6 +528,13 @@ type multiChannelNodeConn struct {
mutex sync.RWMutex mutex sync.RWMutex
connections []*connectionEntry connections []*connectionEntry
// pendingMu protects pending changes independently of the connection mutex.
// This avoids contention between addToBatch (which appends changes) and
// send() (which sends data to connections).
pendingMu sync.Mutex
pending []change.Change
closeOnce sync.Once
updateCount atomic.Int64 updateCount atomic.Int64
// lastSentPeers tracks which peers were last sent to this node. // lastSentPeers tracks which peers were last sent to this node.
@@ -560,12 +563,14 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
} }
func (mc *multiChannelNodeConn) close() { func (mc *multiChannelNodeConn) close() {
mc.mutex.Lock() mc.closeOnce.Do(func() {
defer mc.mutex.Unlock() mc.mutex.Lock()
defer mc.mutex.Unlock()
for _, conn := range mc.connections { for _, conn := range mc.connections {
mc.stopConnection(conn) mc.stopConnection(conn)
} }
})
} }
// stopConnection marks a connection as closed and tears down the owning session // stopConnection marks a connection as closed and tears down the owning session
@@ -647,6 +652,25 @@ func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
return len(mc.connections) return len(mc.connections)
} }
// appendPending adds the given changes to this node's pending queue.
// It is safe for concurrent use: access is serialized by pendingMu,
// which is independent of the connection mutex so batching never
// contends with send().
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()
	mc.pending = append(mc.pending, changes...)
}
// drainPending takes ownership of the node's pending change list in a
// single atomic step, leaving the queue empty. It returns nil when no
// changes were queued. Guarded by pendingMu, independent of the
// connection mutex.
func (mc *multiChannelNodeConn) drainPending() []change.Change {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()
	drained := mc.pending
	mc.pending = nil
	return drained
}
// send broadcasts data to all active connections for the node. // send broadcasts data to all active connections for the node.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
if data == nil { if data == nil {

View File

@@ -2302,8 +2302,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("=== RAPID RECONNECTION TEST ===") t.Logf("=== RAPID RECONNECTION TEST ===")
t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes)) t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes))
// Phase 1: Connect all nodes initially // Connect all nodes initially.
t.Logf("Phase 1: Connecting all nodes...") t.Logf("Connecting all nodes...")
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
@@ -2321,8 +2321,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
} }
}, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle") }, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle")
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) // Rapid disconnect ALL nodes (simulating nodes going down).
t.Logf("Phase 2: Rapid disconnect all nodes...") t.Logf("Rapid disconnect all nodes...")
for i := range allNodes { for i := range allNodes {
node := &allNodes[i] node := &allNodes[i]
@@ -2330,8 +2330,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Node %d RemoveNode result: %t", i, removed) t.Logf("Node %d RemoveNode result: %t", i, removed)
} }
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) // Rapid reconnect with NEW channels (simulating nodes coming back up).
t.Logf("Phase 3: Rapid reconnect with new channels...") t.Logf("Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i := range allNodes { for i := range allNodes {
@@ -2351,8 +2351,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
} }
}, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle") }, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle")
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR // Check debug status after reconnection.
t.Logf("Phase 4: Checking debug status...") t.Logf("Checking debug status...")
if debugBatcher, ok := batcher.(interface { if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any Debug() map[types.NodeID]any
@@ -2396,8 +2396,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Batcher does not implement Debug() method") t.Logf("Batcher does not implement Debug() method")
} }
// Phase 5: Test if "disconnected" nodes can actually receive updates // Test if "disconnected" nodes can actually receive updates.
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...") t.Logf("Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes // Send a change that should reach all nodes
batcher.AddWork(change.DERPMap()) batcher.AddWork(change.DERPMap())
@@ -2442,8 +2442,8 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Logf("=== MULTI-CONNECTION TEST ===") t.Logf("=== MULTI-CONNECTION TEST ===")
// Phase 1: Connect first node with initial connection // Connect first node with initial connection.
t.Logf("Phase 1: Connecting node 1 with first connection...") t.Logf("Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil) err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100), nil)
if err != nil { if err != nil {
@@ -2462,8 +2462,8 @@ func TestBatcherMultiConnection(t *testing.T) {
assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected") assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected")
}, time.Second, 10*time.Millisecond, "waiting for initial connections") }, time.Second, 10*time.Millisecond, "waiting for initial connections")
// Phase 2: Add second connection for node1 (multi-connection scenario) // Add second connection for node1 (multi-connection scenario).
t.Logf("Phase 2: Adding second connection for node 1...") t.Logf("Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10) secondChannel := make(chan *tailcfg.MapResponse, 10)
@@ -2475,8 +2475,8 @@ func TestBatcherMultiConnection(t *testing.T) {
// Yield to allow connection to be processed // Yield to allow connection to be processed
runtime.Gosched() runtime.Gosched()
// Phase 3: Add third connection for node1 // Add third connection for node1.
t.Logf("Phase 3: Adding third connection for node 1...") t.Logf("Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10) thirdChannel := make(chan *tailcfg.MapResponse, 10)
@@ -2488,8 +2488,8 @@ func TestBatcherMultiConnection(t *testing.T) {
// Yield to allow connection to be processed // Yield to allow connection to be processed
runtime.Gosched() runtime.Gosched()
// Phase 4: Verify debug status shows correct connection count // Verify debug status shows correct connection count.
t.Logf("Phase 4: Verifying debug status shows multiple connections...") t.Logf("Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface { if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any Debug() map[types.NodeID]any
@@ -2525,8 +2525,8 @@ func TestBatcherMultiConnection(t *testing.T) {
} }
} }
// Phase 5: Send update and verify ALL connections receive it // Send update and verify ALL connections receive it.
t.Logf("Phase 5: Testing update distribution to all connections...") t.Logf("Testing update distribution to all connections...")
// Clear any existing updates from all channels // Clear any existing updates from all channels
clearChannel := func(ch chan *tailcfg.MapResponse) { clearChannel := func(ch chan *tailcfg.MapResponse) {
@@ -2591,8 +2591,8 @@ func TestBatcherMultiConnection(t *testing.T) {
connection1Received, connection2Received, connection3Received) connection1Received, connection2Received, connection3Received)
} }
// Phase 6: Test connection removal and verify remaining connections still work // Test connection removal and verify remaining connections still work.
t.Logf("Phase 6: Testing connection removal...") t.Logf("Testing connection removal...")
// Remove the second connection // Remove the second connection
removed := batcher.RemoveNode(node1.n.ID, secondChannel) removed := batcher.RemoveNode(node1.n.ID, secondChannel)

6
hscontrol/util/norace.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build !race

package util

// RaceEnabled is true when the race detector is active.
// This file is compiled only when the race build tag is absent, so the
// constant is false here; race.go provides the true value under -race.
const RaceEnabled = false

6
hscontrol/util/race.go Normal file
View File

@@ -0,0 +1,6 @@
//go:build race

package util

// RaceEnabled is true when the race detector is active.
// This file is compiled only under the race build tag (-race), so the
// constant is true here; norace.go provides the false value otherwise.
const RaceEnabled = true