batcher: fix closed panic
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
commit 616c0e895d (parent c4600346f9), committed by Kristoffer Dalby

@@ -1,8 +1,8 @@
package mapper

import (
    "context"
    "crypto/rand"
    "errors"
    "fmt"
    "sync"
    "sync/atomic"
@@ -16,6 +16,8 @@ import (
    "tailscale.com/types/ptr"
)

var errConnectionClosed = errors.New("connection channel already closed")

// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
    tick *time.Ticker
@@ -26,9 +28,9 @@ type LockFreeBatcher struct {
    connected *xsync.Map[types.NodeID, *time.Time]

    // Work queue channel
    workCh     chan work
    ctx        context.Context
    cancel     context.CancelFunc
    workCh     chan work
    workChOnce sync.Once // Ensures workCh is only closed once
    done       chan struct{}

    // Batching state
    pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
@@ -144,23 +146,20 @@ func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
}

func (b *LockFreeBatcher) Start() {
    b.ctx, b.cancel = context.WithCancel(context.Background())
    b.done = make(chan struct{})
    go b.doWork()
}

func (b *LockFreeBatcher) Close() {
    if b.cancel != nil {
        b.cancel()
        b.cancel = nil
    // Signal shutdown to all goroutines
    if b.done != nil {
        close(b.done)
    }

    // Only close workCh once
    select {
    case <-b.workCh:
        // Channel is already closed
    default:
    // Only close workCh once using sync.Once to prevent races
    b.workChOnce.Do(func() {
        close(b.workCh)
    }
    })

    // Close the underlying channels supplying the data to the clients.
    b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
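The old Close() tried to detect an already-closed workCh with a select/default probe, which can race when several goroutines reach Close() at the same time; the new code delegates to sync.Once so the channel is closed exactly once. A minimal, self-contained sketch of that close-once pattern (the type and field names below are illustrative, not headscale's):

package main

import (
    "fmt"
    "sync"
)

type batcher struct {
    workCh    chan int
    closeOnce sync.Once
}

func (b *batcher) Close() {
    // sync.Once guarantees the channel is closed exactly once, even when
    // Close is invoked from several goroutines concurrently.
    b.closeOnce.Do(func() {
        close(b.workCh)
    })
}

func main() {
    b := &batcher{workCh: make(chan int)}
    var wg sync.WaitGroup
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            b.Close() // safe to call concurrently; only one close happens
        }()
    }
    wg.Wait()
    _, open := <-b.workCh
    fmt.Println("workCh open:", open) // prints: workCh open: false
}

sync.Once.Do also blocks concurrent callers until the first call returns, so no caller can observe the channel half-closed.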
@@ -186,8 +185,8 @@ func (b *LockFreeBatcher) doWork() {
        case <-cleanupTicker.C:
            // Clean up nodes that have been offline for too long
            b.cleanupOfflineNodes()
        case <-b.ctx.Done():
            log.Info().Msg("batcher context done, stopping to feed workers")
        case <-b.done:
            log.Info().Msg("batcher done channel closed, stopping to feed workers")
            return
        }
    }
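doWork and the workers now select on b.done in addition to b.ctx.Done(), so closing the done channel broadcasts shutdown to every goroutine at once. A small sketch of that broadcast-by-close pattern (all names here are made up for the example):

package main

import (
    "fmt"
    "sync"
)

// worker drains jobs until the done channel is closed.
func worker(id int, jobs <-chan int, done <-chan struct{}, wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case j := <-jobs:
            fmt.Printf("worker %d handled job %d\n", id, j)
        case <-done:
            // A closed channel is always ready to receive, so a single
            // close(done) wakes every worker selecting on it.
            fmt.Printf("worker %d shutting down\n", id)
            return
        }
    }
}

func main() {
    jobs := make(chan int)
    done := make(chan struct{})
    var wg sync.WaitGroup
    for i := 1; i <= 2; i++ {
        wg.Add(1)
        go worker(i, jobs, done, &wg)
    }
    jobs <- 1 // handled by one of the workers
    close(done)
    wg.Wait()
}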
@@ -235,7 +234,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
            // Send result
            select {
            case w.resultCh <- result:
            case <-b.ctx.Done():
            case <-b.done:
                return
            }

@@ -258,8 +257,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
                        Msg("failed to apply change")
                }
            }
        case <-b.ctx.Done():
            log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
        case <-b.done:
            log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
            return
        }
    }
@@ -276,7 +275,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
    select {
    case b.workCh <- w:
        // Successfully queued
    case <-b.ctx.Done():
    case <-b.done:
        // Batcher is shutting down
        return
    }
@@ -443,7 +442,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
    select {
    case result := <-resultCh:
        return result.mapResponse, result.err
    case <-b.ctx.Done():
    case <-b.done:
        return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
    }
}
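MapResponseFromChange follows a request/response shape: the queued work item carries its own result channel, and the caller waits on either that channel or shutdown. A rough sketch of the same shape, with hypothetical types and names:

package main

import (
    "errors"
    "fmt"
)

var errShuttingDown = errors.New("batcher shutting down")

// request carries its own result channel, like the work items above.
type request struct {
    id       int
    resultCh chan string
}

// ask queues a request and waits for either its result or shutdown.
func ask(workCh chan<- request, done <-chan struct{}, id int) (string, error) {
    req := request{id: id, resultCh: make(chan string, 1)} // buffered so the worker never blocks on the reply
    select {
    case workCh <- req:
    case <-done:
        return "", fmt.Errorf("request %d: %w", id, errShuttingDown)
    }
    select {
    case res := <-req.resultCh:
        return res, nil
    case <-done:
        return "", fmt.Errorf("request %d: %w", id, errShuttingDown)
    }
}

func main() {
    workCh := make(chan request)
    done := make(chan struct{})
    go func() {
        for req := range workCh {
            req.resultCh <- fmt.Sprintf("map response for node %d", req.id)
        }
    }()
    fmt.Println(ask(workCh, done, 42))
}

Buffering the result channel with capacity 1 lets the worker reply without blocking even if the caller has already given up and returned.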
@@ -455,6 +454,7 @@ type connectionEntry struct {
    version  tailcfg.CapabilityVersion
    created  time.Time
    lastUsed atomic.Int64 // Unix timestamp of last successful send
    closed   atomic.Bool  // Indicates if this connection has been closed
}

// multiChannelNodeConn manages multiple concurrent connections for a single node.
@@ -488,6 +488,9 @@ func (mc *multiChannelNodeConn) close() {
    defer mc.mutex.Unlock()

    for _, conn := range mc.connections {
        // Mark as closed before closing the channel to prevent
        // send on closed channel panics from concurrent workers
        conn.closed.Store(true)
        close(conn.c)
    }
}
@@ -620,6 +623,12 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
        return nil
    }

    // Check if the connection has been closed to prevent send on closed channel panic.
    // This can happen during shutdown when Close() is called while workers are still processing.
    if entry.closed.Load() {
        return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
    }

    // Use a short timeout to detect stale connections where the client isn't reading the channel.
    // This is critical for detecting Docker containers that are forcefully terminated
    // but still have channels that appear open.
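The send() guard pairs with close() above: the connection is flagged closed (atomically) before its channel is closed, and send() refuses to write once the flag is set, returning a wrapped errConnectionClosed that callers can test with errors.Is. A sketch of that pattern with illustrative names; note that the flag alone narrows, rather than fully removes, the window for a send-on-closed-channel panic, and in the commit connections are also closed under the node connection's mutex:

package main

import (
    "errors"
    "fmt"
    "sync/atomic"
    "time"
)

var errConnectionClosed = errors.New("connection channel already closed")

type conn struct {
    id     string
    c      chan string
    closed atomic.Bool
}

// send refuses to write once the connection has been flagged closed.
func (e *conn) send(msg string) error {
    if e.closed.Load() {
        return fmt.Errorf("connection %s: %w", e.id, errConnectionClosed)
    }
    select {
    case e.c <- msg:
        return nil
    case <-time.After(50 * time.Millisecond):
        return fmt.Errorf("connection %s: send timeout", e.id)
    }
}

// close sets the flag first so concurrent senders back off before the
// channel is actually closed.
func (e *conn) close() {
    e.closed.Store(true)
    close(e.c)
}

func main() {
    c := &conn{id: "node-1", c: make(chan string, 1)}
    fmt.Println(c.send("hello")) // <nil>
    c.close()
    err := c.send("world")
    fmt.Println(errors.Is(err, errConnectionClosed)) // true
}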
@@ -147,12 +147,12 @@ type node struct {
    n  *types.Node
    ch chan *tailcfg.MapResponse

    // Update tracking
    // Update tracking (all accessed atomically for thread safety)
    updateCount   int64
    patchCount    int64
    fullCount     int64
    maxPeersCount int
    lastPeerCount int
    maxPeersCount atomic.Int64
    lastPeerCount atomic.Int64
    stop          chan struct{}
    stopped       chan struct{}
}
@@ -422,18 +422,32 @@ func (n *node) start() {
        // Track update types
        if info.IsFull {
            atomic.AddInt64(&n.fullCount, 1)
            n.lastPeerCount = info.PeerCount
            // Update max peers seen
            if info.PeerCount > n.maxPeersCount {
                n.maxPeersCount = info.PeerCount
            n.lastPeerCount.Store(int64(info.PeerCount))
            // Update max peers seen using compare-and-swap for thread safety
            for {
                current := n.maxPeersCount.Load()
                if int64(info.PeerCount) <= current {
                    break
                }

                if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
                    break
                }
            }
        }

        if info.IsPatch {
            atomic.AddInt64(&n.patchCount, 1)
            // For patches, we track how many patch items
            if info.PatchCount > n.maxPeersCount {
                n.maxPeersCount = info.PatchCount
            // For patches, we track how many patch items using compare-and-swap
            for {
                current := n.maxPeersCount.Load()
                if int64(info.PatchCount) <= current {
                    break
                }

                if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
                    break
                }
            }
        }
    }
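Because maxPeersCount is now an atomic.Int64, a plain "if x > max { max = x }" is no longer safe under concurrent writers; the loop above retries CompareAndSwap until the candidate is installed or a larger value is already present. The same idea as a standalone helper (storeMax is a hypothetical name, not from the commit):

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// storeMax atomically raises *m to v if v is larger than the stored value.
func storeMax(m *atomic.Int64, v int64) {
    for {
        current := m.Load()
        if v <= current {
            return // an equal or larger value is already stored
        }
        if m.CompareAndSwap(current, v) {
            return // we installed v before anyone else changed the value
        }
        // CAS lost a race with another writer; re-read and try again.
    }
}

func main() {
    var max atomic.Int64
    var wg sync.WaitGroup
    for _, v := range []int64{3, 42, 7, 19} {
        wg.Add(1)
        go func(v int64) {
            defer wg.Done()
            storeMax(&max, v)
        }(v)
    }
    wg.Wait()
    fmt.Println(max.Load()) // 42
}

Each retry re-reads the current value, so a concurrent writer can only ever make the stored maximum larger, never smaller.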
@@ -465,8 +479,8 @@ func (n *node) cleanup() NodeStats {
        TotalUpdates:  atomic.LoadInt64(&n.updateCount),
        PatchUpdates:  atomic.LoadInt64(&n.patchCount),
        FullUpdates:   atomic.LoadInt64(&n.fullCount),
        MaxPeersSeen:  n.maxPeersCount,
        LastPeerCount: n.lastPeerCount,
        MaxPeersSeen:  int(n.maxPeersCount.Load()),
        LastPeerCount: int(n.lastPeerCount.Load()),
    }
}
@@ -665,7 +679,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
    connectedCount := 0
    for i := range allNodes {
        node := &allNodes[i]
        currentMaxPeers := node.maxPeersCount

        currentMaxPeers := int(node.maxPeersCount.Load())
        if currentMaxPeers >= expectedPeers {
            connectedCount++
        }