batcher: fix panics on already-closed channels

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2025-12-12 14:28:09 +01:00
committed by Kristoffer Dalby
parent c4600346f9
commit 616c0e895d
2 changed files with 58 additions and 34 deletions

View File

@@ -1,8 +1,8 @@
package mapper
import (
"context"
"crypto/rand"
"errors"
"fmt"
"sync"
"sync/atomic"
@@ -16,6 +16,8 @@ import (
"tailscale.com/types/ptr"
)
var errConnectionClosed = errors.New("connection channel already closed")
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
tick *time.Ticker
@@ -26,9 +28,9 @@ type LockFreeBatcher struct {
connected *xsync.Map[types.NodeID, *time.Time]
// Work queue channel
workCh chan work
ctx context.Context
cancel context.CancelFunc
workCh chan work
workChOnce sync.Once // Ensures workCh is only closed once
done chan struct{}
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
@@ -144,23 +146,20 @@ func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
}
// Start sets up the batcher's shutdown signalling and then launches doWork in
// its own goroutine.
//
// NOTE(review): this listing looks like a diff rendered without +/- markers;
// the context.WithCancel line and the done-channel line below are presumably
// the removed and added sides of the same change — confirm against the real
// file before relying on both being present.
func (b *LockFreeBatcher) Start() {
b.ctx, b.cancel = context.WithCancel(context.Background())
// done is closed by Close; doWork and the workers select on it to stop.
b.done = make(chan struct{})
go b.doWork()
}
func (b *LockFreeBatcher) Close() {
if b.cancel != nil {
b.cancel()
b.cancel = nil
// Signal shutdown to all goroutines
if b.done != nil {
close(b.done)
}
// Only close workCh once
select {
case <-b.workCh:
// Channel is already closed
default:
// Only close workCh once using sync.Once to prevent races
b.workChOnce.Do(func() {
close(b.workCh)
}
})
// Close the underlying channels supplying the data to the clients.
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
@@ -186,8 +185,8 @@ func (b *LockFreeBatcher) doWork() {
case <-cleanupTicker.C:
// Clean up nodes that have been offline for too long
b.cleanupOfflineNodes()
case <-b.ctx.Done():
log.Info().Msg("batcher context done, stopping to feed workers")
case <-b.done:
log.Info().Msg("batcher done channel closed, stopping to feed workers")
return
}
}
@@ -235,7 +234,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
// Send result
select {
case w.resultCh <- result:
case <-b.ctx.Done():
case <-b.done:
return
}
@@ -258,8 +257,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
Msg("failed to apply change")
}
}
case <-b.ctx.Done():
log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
case <-b.done:
log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
return
}
}
@@ -276,7 +275,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
select {
case b.workCh <- w:
// Successfully queued
case <-b.ctx.Done():
case <-b.done:
// Batcher is shutting down
return
}
@@ -443,7 +442,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
select {
case result := <-resultCh:
return result.mapResponse, result.err
case <-b.ctx.Done():
case <-b.done:
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
}
}
@@ -455,6 +454,7 @@ type connectionEntry struct {
version tailcfg.CapabilityVersion
created time.Time
lastUsed atomic.Int64 // Unix timestamp of last successful send
closed atomic.Bool // Indicates if this connection has been closed
}
// multiChannelNodeConn manages multiple concurrent connections for a single node.
@@ -488,6 +488,9 @@ func (mc *multiChannelNodeConn) close() {
defer mc.mutex.Unlock()
for _, conn := range mc.connections {
// Mark as closed before closing the channel to prevent
// send on closed channel panics from concurrent workers
conn.closed.Store(true)
close(conn.c)
}
}
@@ -620,6 +623,12 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
return nil
}
// Check if the connection has been closed to prevent send on closed channel panic.
// This can happen during shutdown when Close() is called while workers are still processing.
if entry.closed.Load() {
return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
}
// Use a short timeout to detect stale connections where the client isn't reading the channel.
// This is critical for detecting Docker containers that are forcefully terminated
// but still have channels that appear open.

View File

@@ -147,12 +147,12 @@ type node struct {
n *types.Node
ch chan *tailcfg.MapResponse
// Update tracking
// Update tracking (all accessed atomically for thread safety)
updateCount int64
patchCount int64
fullCount int64
maxPeersCount int
lastPeerCount int
maxPeersCount atomic.Int64
lastPeerCount atomic.Int64
stop chan struct{}
stopped chan struct{}
}
@@ -422,18 +422,32 @@ func (n *node) start() {
// Track update types
if info.IsFull {
atomic.AddInt64(&n.fullCount, 1)
n.lastPeerCount = info.PeerCount
// Update max peers seen
if info.PeerCount > n.maxPeersCount {
n.maxPeersCount = info.PeerCount
n.lastPeerCount.Store(int64(info.PeerCount))
// Update max peers seen using compare-and-swap for thread safety
for {
current := n.maxPeersCount.Load()
if int64(info.PeerCount) <= current {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
break
}
}
}
if info.IsPatch {
atomic.AddInt64(&n.patchCount, 1)
// For patches, we track how many patch items
if info.PatchCount > n.maxPeersCount {
n.maxPeersCount = info.PatchCount
// For patches, we track how many patch items using compare-and-swap
for {
current := n.maxPeersCount.Load()
if int64(info.PatchCount) <= current {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
break
}
}
}
}
@@ -465,8 +479,8 @@ func (n *node) cleanup() NodeStats {
TotalUpdates: atomic.LoadInt64(&n.updateCount),
PatchUpdates: atomic.LoadInt64(&n.patchCount),
FullUpdates: atomic.LoadInt64(&n.fullCount),
MaxPeersSeen: n.maxPeersCount,
LastPeerCount: n.lastPeerCount,
MaxPeersSeen: int(n.maxPeersCount.Load()),
LastPeerCount: int(n.lastPeerCount.Load()),
}
}
@@ -665,7 +679,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
connectedCount := 0
for i := range allNodes {
node := &allNodes[i]
currentMaxPeers := node.maxPeersCount
currentMaxPeers := int(node.maxPeersCount.Load())
if currentMaxPeers >= expectedPeers {
connectedCount++
}