mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-25 10:08:41 +02:00
batcher: fix send on closed channel panic
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
committed by
Kristoffer Dalby
parent
c4600346f9
commit
616c0e895d
@@ -1,8 +1,8 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -16,6 +16,8 @@ import (
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var errConnectionClosed = errors.New("connection channel already closed")
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
tick *time.Ticker
|
||||
@@ -26,9 +28,9 @@ type LockFreeBatcher struct {
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
workCh chan work
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
workCh chan work
|
||||
workChOnce sync.Once // Ensures workCh is only closed once
|
||||
done chan struct{}
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
@@ -144,23 +146,20 @@ func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
b.ctx, b.cancel = context.WithCancel(context.Background())
|
||||
b.done = make(chan struct{})
|
||||
go b.doWork()
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
b.cancel = nil
|
||||
// Signal shutdown to all goroutines
|
||||
if b.done != nil {
|
||||
close(b.done)
|
||||
}
|
||||
|
||||
// Only close workCh once
|
||||
select {
|
||||
case <-b.workCh:
|
||||
// Channel is already closed
|
||||
default:
|
||||
// Only close workCh once using sync.Once to prevent races
|
||||
b.workChOnce.Do(func() {
|
||||
close(b.workCh)
|
||||
}
|
||||
})
|
||||
|
||||
// Close the underlying channels supplying the data to the clients.
|
||||
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
|
||||
@@ -186,8 +185,8 @@ func (b *LockFreeBatcher) doWork() {
|
||||
case <-cleanupTicker.C:
|
||||
// Clean up nodes that have been offline for too long
|
||||
b.cleanupOfflineNodes()
|
||||
case <-b.ctx.Done():
|
||||
log.Info().Msg("batcher context done, stopping to feed workers")
|
||||
case <-b.done:
|
||||
log.Info().Msg("batcher done channel closed, stopping to feed workers")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -235,7 +234,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
// Send result
|
||||
select {
|
||||
case w.resultCh <- result:
|
||||
case <-b.ctx.Done():
|
||||
case <-b.done:
|
||||
return
|
||||
}
|
||||
|
||||
@@ -258,8 +257,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
}
|
||||
case <-b.ctx.Done():
|
||||
log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
|
||||
case <-b.done:
|
||||
log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -276,7 +275,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
|
||||
select {
|
||||
case b.workCh <- w:
|
||||
// Successfully queued
|
||||
case <-b.ctx.Done():
|
||||
case <-b.done:
|
||||
// Batcher is shutting down
|
||||
return
|
||||
}
|
||||
@@ -443,7 +442,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.ctx.Done():
|
||||
case <-b.done:
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
}
|
||||
}
|
||||
@@ -455,6 +454,7 @@ type connectionEntry struct {
|
||||
version tailcfg.CapabilityVersion
|
||||
created time.Time
|
||||
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
||||
closed atomic.Bool // Indicates if this connection has been closed
|
||||
}
|
||||
|
||||
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
||||
@@ -488,6 +488,9 @@ func (mc *multiChannelNodeConn) close() {
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for _, conn := range mc.connections {
|
||||
// Mark as closed before closing the channel to prevent
|
||||
// send on closed channel panics from concurrent workers
|
||||
conn.closed.Store(true)
|
||||
close(conn.c)
|
||||
}
|
||||
}
|
||||
@@ -620,6 +623,12 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the connection has been closed to prevent send on closed channel panic.
|
||||
// This can happen during shutdown when Close() is called while workers are still processing.
|
||||
if entry.closed.Load() {
|
||||
return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
|
||||
}
|
||||
|
||||
// Use a short timeout to detect stale connections where the client isn't reading the channel.
|
||||
// This is critical for detecting Docker containers that are forcefully terminated
|
||||
// but still have channels that appear open.
|
||||
|
||||
Reference in New Issue
Block a user