Files
headscale/hscontrol/mapper/batcher_lockfree.go
Kristoffer Dalby 3ebe4d99c1 mapper/batcher: reduce lock contention with two-phase send
Rewrite multiChannelNodeConn.send() to use a two-phase approach:
1. RLock: snapshot connections slice (cheap pointer copy)
2. Unlock: send to all connections (50ms timeouts happen here)
3. Lock: remove failed connections by pointer identity

Previously, send() held the write lock for the entire duration of
sending to all connections. With N stale connections each timing out
at 50ms, this blocked addConnection/removeConnection for N*50ms.
The two-phase approach holds the lock only for O(N) pointer
operations, not for N*50ms I/O waits.
2026-03-14 02:52:28 -07:00

989 lines
29 KiB
Go

package mapper
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// Sentinel errors used by LockFreeBatcher and its connection types.
var (
	// ErrInitialMapSendTimeout is wrapped by AddNode when the initial map
	// could not be delivered to the client channel in time.
	ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")

	// ErrBatcherShuttingDown is wrapped by MapResponseFromChange when the
	// batcher's done channel fires before a result is available.
	ErrBatcherShuttingDown = errors.New("batcher shutting down")

	// ErrConnectionSendTimeout is wrapped by connectionEntry.send when the
	// receiver did not accept the message within the send timeout.
	ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")

	// errConnectionClosed is wrapped by connectionEntry.send when the entry
	// has already been marked closed.
	errConnectionClosed = errors.New("connection channel already closed")
)
// LockFreeBatcher batches changes per node and distributes the resulting map
// responses to connected clients. Node bookkeeping lives in concurrent xsync
// maps and the counters are atomics, so the batcher itself holds no global
// mutex; per-node state is guarded inside multiChannelNodeConn.
type LockFreeBatcher struct {
	// tick drives processBatchedChanges from the doWork loop.
	tick *time.Ticker
	// mapper is handed to every node connection and used to build responses.
	mapper *mapper
	// workers is the number of worker goroutines started by doWork.
	workers int

	// nodes holds the per-node connection state, keyed by node ID.
	nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
	// connected records connection status: a nil value means connected,
	// a non-nil value is the time the node disconnected.
	connected *xsync.Map[types.NodeID, *time.Time]

	// workCh is the queue consumed by the worker goroutines.
	workCh chan work
	// workChOnce ensures workCh is only closed once.
	workChOnce sync.Once
	// done is created by Start and closed by Close to signal shutdown.
	done chan struct{}
	// doneOnce ensures done is only closed once.
	doneOnce sync.Once
	// started ensures Start() only takes effect once.
	started atomic.Bool

	// Metrics.
	totalNodes      atomic.Int64
	workQueuedCount atomic.Int64
	workProcessed   atomic.Int64
	workErrors      atomic.Int64
}
// AddNode registers a new connection for node id and delivers its initial map.
//
// A connectionEntry is created for channel c and appended to the node's
// multiChannelNodeConn (created on first use, reused otherwise). The initial
// full-self map is generated through the worker pool and sent on c with a
// 5 second timeout. If generation or delivery fails, the connection is removed
// again and an error is returned; on success the node is marked connected.
//
// stop is stored on the entry and is invoked when the batcher later tears the
// connection down (for example after a failed send).
func (b *LockFreeBatcher) AddNode(
	id types.NodeID,
	c chan<- *tailcfg.MapResponse,
	version tailcfg.CapabilityVersion,
	stop func(),
) error {
	const initialMapSendTimeout = 5 * time.Second

	began := time.Now()
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	connID := generateConnectionID()
	created := time.Now()
	entry := &connectionEntry{
		id:      connID,
		c:       c,
		version: version,
		created: created,
		stop:    stop,
	}
	entry.lastUsed.Store(created.Unix())

	// Reuse the existing per-node state when the node is already known
	// (e.g. a rapid reconnect); only count the node when it is new.
	nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
	if !loaded {
		b.totalNodes.Add(1)
	}

	// Register the connection; addConnection takes the node's mutex.
	nodeConn.addConnection(entry)

	// Generate the initial map via the worker pool rather than inline.
	initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
	if err != nil {
		nlog.Error().Err(err).Msg("initial map generation failed")
		nodeConn.removeConnectionByChannel(c)

		return fmt.Errorf("generating initial map for node %d: %w", id, err)
	}

	// Blocking send with a timeout: the receiver may not be reading yet,
	// but it must not be allowed to stall this call indefinitely.
	select {
	case c <- initialMap:
		// Delivered.
	case <-time.After(initialMapSendTimeout):
		nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
		nlog.Debug().Caller().Dur("timeout.duration", initialMapSendTimeout).
			Msg("initial map send timed out because channel was blocked or receiver not ready")
		nodeConn.removeConnectionByChannel(c)

		return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
	}

	// A nil value in the connected map means "currently connected".
	b.connected.Store(id, nil)

	nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(began)).
		Int("active.connections", nodeConn.getActiveConnectionCount()).
		Msg("node connection established in batcher")

	return nil
}
// RemoveNode detaches the connection identified by channel c from node id.
//
// The matching connection entry is removed from the node's connection list;
// the node entry itself is kept so that a quick reconnect can reuse it. When
// no connections remain, the node is marked disconnected with the current
// time. The return value reports whether the node still has at least one
// active connection after the removal. It returns false for unknown nodes.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	nodeConn, exists := b.nodes.Load(id)
	if !exists || nodeConn == nil {
		nlog.Debug().Caller().Msg("removeNode called for non-existent node")
		return false
	}

	if removed := nodeConn.removeConnectionByChannel(c); !removed {
		nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
	}

	if !nodeConn.hasActiveConnections() {
		// Last connection gone: record the disconnect time but keep the
		// node entry around; it gets a fresh full map on reconnect.
		nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")

		disconnectedAt := time.Now()
		b.connected.Store(id, &disconnectedAt)

		return false
	}

	nlog.Debug().Caller().
		Int("active.connections", nodeConn.getActiveConnectionCount()).
		Msg("node connection removed but keeping online, other connections remain")

	return true
}
// AddWork adds the given changes to the batcher's pending batch.
// It is the exported entry point and forwards to addWork.
func (b *LockFreeBatcher) AddWork(r ...change.Change) {
	b.addWork(r...)
}
// Start creates the done channel and launches the doWork loop in a goroutine.
// Only the first call has any effect; subsequent calls return immediately.
func (b *LockFreeBatcher) Start() {
	if alreadyStarted := !b.started.CompareAndSwap(false, true); alreadyStarted {
		return
	}

	b.done = make(chan struct{})

	go b.doWork()
}
// Close shuts the batcher down: it closes the done channel (once) and stops
// every node's connections.
//
// workCh is intentionally NOT closed. queueWork sends on workCh from other
// goroutines (processBatchedChanges and MapResponseFromChange), and a send on
// a closed channel panics even inside a select that also watches done.
// Workers, queueWork and MapResponseFromChange all select on done, so closing
// done alone is sufficient to stop them.
//
// NOTE(review): the workChOnce field is no longer used by this method and can
// be dropped from the struct in a follow-up.
func (b *LockFreeBatcher) Close() {
	// Signal shutdown to all goroutines, only once. done is nil if Start
	// was never called.
	b.doneOnce.Do(func() {
		if b.done != nil {
			close(b.done)
		}
	})

	// Tear down the sessions behind every node's connections.
	b.nodes.Range(func(_ types.NodeID, conn *multiChannelNodeConn) bool {
		if conn != nil {
			conn.close()
		}

		return true
	})
}
// doWork starts the worker goroutines and then runs the batcher's main loop
// until done is closed. On every tick it flushes batched changes to the work
// queue; every five minutes it prunes nodes that have been offline too long.
func (b *LockFreeBatcher) doWork() {
	const cleanupInterval = 5 * time.Minute

	// Worker IDs are 1-based.
	for workerID := 1; workerID <= b.workers; workerID++ {
		go b.worker(workerID)
	}

	cleanupTicker := time.NewTicker(cleanupInterval)
	defer cleanupTicker.Stop()

	for {
		select {
		case <-b.done:
			log.Info().Msg("batcher done channel closed, stopping to feed workers")
			return
		case <-b.tick.C:
			b.processBatchedChanges()
		case <-cleanupTicker.C:
			b.cleanupOfflineNodes()
		}
	}
}
// worker consumes items from workCh until the channel is closed or done fires.
//
// Work items come in two flavours:
//   - asynchronous (resultCh == nil): the change is applied to the node, which
//     sends it to the node's connections; unknown nodes are skipped silently.
//   - synchronous (resultCh != nil): a map response is generated and handed
//     back to the waiting caller on resultCh.
func (b *LockFreeBatcher) worker(workerID int) {
	wlog := log.With().Int(zf.WorkerID, workerID).Logger()

	for {
		var w work

		select {
		case <-b.done:
			wlog.Debug().Msg("batcher shutting down, exiting worker")
			return
		case item, open := <-b.workCh:
			if !open {
				wlog.Debug().Msg("worker channel closing, shutting down")
				return
			}

			w = item
		}

		b.workProcessed.Add(1)

		nc, found := b.nodes.Load(w.nodeID)
		found = found && nc != nil

		// Asynchronous work: apply the change and move on.
		if w.resultCh == nil {
			if !found {
				continue
			}

			if err := nc.change(w.c); err != nil {
				b.workErrors.Add(1)
				wlog.Error().Err(err).
					Uint64(zf.NodeID, w.nodeID.Uint64()).
					Str(zf.Reason, w.c.Reason).
					Msg("failed to apply change")
			}

			continue
		}

		// Synchronous work: a caller is blocked waiting for the result.
		var result workResult

		if found {
			result.mapResponse, result.err = generateMapResponse(nc, b.mapper, w.c)

			switch {
			case result.err != nil:
				b.workErrors.Add(1)
				wlog.Error().Err(result.err).
					Uint64(zf.NodeID, w.nodeID.Uint64()).
					Str(zf.Reason, w.c.Reason).
					Msg("failed to generate map response for synchronous work")
			case result.mapResponse != nil:
				// Keep peer tracking in step with what the caller will send.
				nc.updateSentPeers(result.mapResponse)
			}
		} else {
			result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
			b.workErrors.Add(1)
			wlog.Error().Err(result.err).
				Uint64(zf.NodeID, w.nodeID.Uint64()).
				Msg("node not found for synchronous work")
		}

		select {
		case w.resultCh <- result:
		case <-b.done:
			return
		}
	}
}
// addWork forwards the given changes to addToBatch.
func (b *LockFreeBatcher) addWork(r ...change.Change) {
	b.addToBatch(r...)
}
// queueWork hands w to the worker pool, blocking until a slot in workCh is
// available or the batcher's done channel fires. The queued counter is
// incremented regardless of which of the two happens.
func (b *LockFreeBatcher) queueWork(w work) {
	b.workQueuedCount.Add(1)

	select {
	case <-b.done:
		// Shutting down; the work item is dropped.
	case b.workCh <- w:
		// Queued.
	}
}
// addToBatch records changes in the per-node pending lists; they are flushed
// to the work queue by processBatchedChanges.
//
// Before batching, any node listed in a change's PeersRemoved is dropped from
// the batcher. change.Change.PeersRemoved is ONLY populated when a node is
// actually deleted from the system (via change.NodeRemoved in
// state.DeleteNode). Policy changes that merely affect peer visibility set
// RequiresRuntimePeerComputation=true instead and compute removed peers at
// runtime into tailcfg.MapResponse.PeersRemoved (a different struct), so this
// cleanup never removes a node that is still connected. It prevents "node not
// found" errors when workers generate map responses for deleted nodes.
//
// See: https://github.com/juanfont/headscale/issues/2924
func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
	for _, c := range changes {
		for _, gone := range c.PeersRemoved {
			_, existed := b.nodes.LoadAndDelete(gone)
			if existed {
				b.totalNodes.Add(-1)
				log.Debug().
					Uint64(zf.NodeID, gone.Uint64()).
					Msg("removed deleted node from batcher")
			}

			b.connected.Delete(gone)
		}
	}

	// A full update supersedes everything else: replace each node's pending
	// list with a single full update and skip the individual changes.
	if change.HasFull(changes) {
		b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
			if nc != nil {
				nc.pendingMu.Lock()
				nc.pending = []change.Change{change.FullUpdate()}
				nc.pendingMu.Unlock()
			}

			return true
		})

		return
	}

	broadcast, targeted := change.SplitTargetedAndBroadcast(changes)

	// Targeted changes go only to the node they name.
	for _, c := range targeted {
		nc, ok := b.nodes.Load(c.TargetNode)
		if !ok || nc == nil {
			continue
		}

		nc.appendPending(c)
	}

	if len(broadcast) == 0 {
		return
	}

	// Broadcast changes go to every node, filtered per recipient.
	b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
		if nc == nil {
			return true
		}

		if relevant := change.FilterForNode(nodeID, broadcast); len(relevant) > 0 {
			nc.appendPending(relevant...)
		}

		return true
	})
}
// processBatchedChanges drains every node's pending change list and queues
// one asynchronous work item per drained change.
func (b *LockFreeBatcher) processBatchedChanges() {
	b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
		if nc == nil {
			return true
		}

		// drainPending yields nothing when the node has no pending changes,
		// in which case the loop body never runs.
		for _, ch := range nc.drainPending() {
			b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
		}

		return true
	})
}
// cleanupOfflineNodes removes nodes that have been offline for longer than the
// cleanup threshold (15 minutes) to prevent memory leaks.
//
// Candidates are collected from the connected map (entries with a non-nil
// disconnect time older than the threshold). Each candidate is then removed
// from the nodes map with Compute(), so that the "no active connections"
// check and the delete happen atomically. This prevents a TOCTOU race where a
// node reconnects between the hasActiveConnections() check and the delete.
//
// The "offline_duration" log field reports how long the node had actually been
// offline when the sweep started, not the threshold constant.
//
// TODO(kradalby): reevaluate if we want to keep this.
func (b *LockFreeBatcher) cleanupOfflineNodes() {
	const cleanupThreshold = 15 * time.Minute

	// staleNode pairs a cleanup candidate with its observed offline time.
	type staleNode struct {
		id      types.NodeID
		offline time.Duration
	}

	now := time.Now()

	var candidates []staleNode

	// Find nodes that have been offline for too long.
	b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
		if disconnectTime == nil {
			// nil means connected.
			return true
		}

		if offline := now.Sub(*disconnectTime); offline > cleanupThreshold {
			candidates = append(candidates, staleNode{id: nodeID, offline: offline})
		}

		return true
	})

	cleaned := 0

	for _, candidate := range candidates {
		deleted := false

		b.nodes.Compute(
			candidate.id,
			func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) {
				// Leave the entry alone if it is gone already or the node
				// has reconnected in the meantime.
				if !loaded || conn == nil || conn.hasActiveConnections() {
					return conn, xsync.CancelOp
				}

				deleted = true

				return conn, xsync.DeleteOp
			},
		)

		if !deleted {
			continue
		}

		log.Info().Uint64(zf.NodeID, candidate.id.Uint64()).
			Dur("offline_duration", candidate.offline).
			Msg("cleaning up node that has been offline for too long")
		b.connected.Delete(candidate.id)
		b.totalNodes.Add(-1)

		cleaned++
	}

	if cleaned > 0 {
		log.Info().Int(zf.CleanedNodes, cleaned).
			Msg("completed cleanup of long-offline nodes")
	}
}
// IsConnected reports whether node id is considered connected.
//
// A node is connected if it has at least one active connection, or if its
// entry in the connected map is nil (nil means connected, a timestamp means
// disconnected). Nodes unknown to both maps are reported as not connected.
// No grace period is applied.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
	if nodeConn, ok := b.nodes.Load(id); ok && nodeConn != nil && nodeConn.hasActiveConnections() {
		return true
	}

	marker, known := b.connected.Load(id)

	return known && marker == nil
}
// ConnectedMap builds a fresh map of node ID to connection status.
//
// Nodes with at least one active connection are reported as true. Every other
// node present in the connected map is reported according to its marker:
// nil means connected (true), a timestamp means disconnected (false).
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
	status := xsync.NewMap[types.NodeID, bool]()

	b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
		if nodeConn != nil && nodeConn.hasActiveConnections() {
			status.Store(id, true)
		}

		return true
	})

	b.connected.Range(func(id types.NodeID, marker *time.Time) bool {
		// LoadOrStore keeps an entry already written by the pass above and
		// only fills in nodes that have not been reported yet.
		status.LoadOrStore(id, marker == nil)

		return true
	})

	return status
}
// MapResponseFromChange generates a map response for node id synchronously by
// queueing a work item with a result channel and waiting for a worker to
// answer. If the batcher's done channel fires first, an error wrapping
// ErrBatcherShuttingDown is returned.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) {
	// Buffered so a worker can deliver the result without blocking.
	resultCh := make(chan workResult, 1)

	b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh})

	select {
	case <-b.done:
		return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
	case res := <-resultCh:
		return res.mapResponse, res.err
	}
}
// connectionEntry describes one client connection belonging to a node.
type connectionEntry struct {
	// id is the connection ID produced by generateConnectionID.
	id string
	// c is the channel map responses are delivered on.
	c chan<- *tailcfg.MapResponse
	// version is the capability version supplied when the connection was added.
	version tailcfg.CapabilityVersion
	// created is the time the entry was created.
	created time.Time
	// stop, if non-nil, is called once when the connection is stopped.
	stop func()
	// lastUsed is the Unix timestamp of the last successful send.
	lastUsed atomic.Int64
	// closed is set once the connection has been stopped.
	closed atomic.Bool
}
// multiChannelNodeConn holds the state for one node: its concurrent client
// connections, the changes waiting to be flushed, and the set of peers last
// sent to it.
type multiChannelNodeConn struct {
	id     types.NodeID
	mapper *mapper
	log    zerolog.Logger

	// mutex guards connections.
	mutex       sync.RWMutex
	connections []*connectionEntry

	// pendingMu guards pending independently of mutex, so that addToBatch
	// (which appends changes) does not contend with send() (which delivers
	// data to connections).
	pendingMu sync.Mutex
	pending   []change.Change

	// closeOnce makes close() effective only once.
	closeOnce sync.Once
	// updateCount is incremented once per completed send() broadcast.
	updateCount atomic.Int64

	// lastSentPeers tracks which peers were last sent to this node. It
	// allows policy changes to be expressed as diffs instead of full peer
	// lists (which clients interpret as "no change" when empty). It is an
	// xsync.Map so it can be accessed concurrently without extra locking.
	lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
}
// generateConnectionID returns a connection identifier made of 8 random bytes
// from crypto/rand, hex-encoded into a 16-character lowercase string.
func generateConnectionID() string {
	var raw [8]byte

	// The error is deliberately discarded, as in the original; crypto/rand
	// documents Read as never returning an error on current Go releases.
	_, _ = rand.Read(raw[:])

	return hex.EncodeToString(raw[:])
}
// newMultiChannelNodeConn returns an empty multiChannelNodeConn for node id
// with a node-scoped logger and an empty lastSentPeers map.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
	nodeLog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	mc := &multiChannelNodeConn{
		id:            id,
		mapper:        mapper,
		lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
		log:           nodeLog,
	}

	return mc
}
// close stops every connection currently registered for the node. It takes
// the write lock for the duration and only has an effect the first time it is
// called. The connections slice itself is left untouched.
func (mc *multiChannelNodeConn) close() {
	stopAll := func() {
		mc.mutex.Lock()
		defer mc.mutex.Unlock()

		for i := range mc.connections {
			mc.stopConnection(mc.connections[i])
		}
	}

	mc.closeOnce.Do(stopAll)
}
// stopConnection flags conn as closed and invokes its stop callback, if any.
// The compare-and-swap on the closed flag guarantees the callback runs at
// most once even when several cleanup paths race to stop the same connection.
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
	if !conn.closed.CompareAndSwap(false, true) {
		// Someone else already stopped it.
		return
	}

	if conn.stop == nil {
		return
	}

	conn.stop()
}
// removeConnectionAtIndexLocked removes the connection at index i from
// mc.connections and returns it, preserving the order of the remaining
// entries. If stopConnection is true, the removed connection's session is
// also stopped.
//
// The slot vacated at the end of the slice is set to nil so the backing array
// does not keep a removed connectionEntry (and its channel and stop closure)
// reachable after the slice shrinks.
//
// Caller must hold mc.mutex for writing. i must be a valid index.
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
	conn := mc.connections[i]

	last := len(mc.connections) - 1
	copy(mc.connections[i:], mc.connections[i+1:])
	mc.connections[last] = nil
	mc.connections = mc.connections[:last]

	if stopConnection {
		mc.stopConnection(conn)
	}

	return conn
}
// addConnection appends entry to the node's connection list under the write
// lock, logging how long the lock acquisition took.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
	chanPtr := fmt.Sprintf("%p", entry.c)

	waitBegan := time.Now()

	mc.log.Debug().Caller().Str(zf.Chan, chanPtr).Str(zf.ConnID, entry.id).
		Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")

	mc.mutex.Lock()
	waited := time.Since(waitBegan)

	defer mc.mutex.Unlock()

	mc.connections = append(mc.connections, entry)

	mc.log.Debug().Caller().Str(zf.Chan, chanPtr).Str(zf.ConnID, entry.id).
		Int("total_connections", len(mc.connections)).
		Dur("mutex_wait_time", waited).
		Msg("successfully added connection after mutex wait")
}
// removeConnectionByChannel removes the first connection whose channel equals
// c and reports whether one was found. The removed connection is not stopped.
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
	mc.mutex.Lock()
	defer mc.mutex.Unlock()

	// Locate the entry first, then remove it outside the search loop.
	idx := -1

	for i := range mc.connections {
		if mc.connections[i].c == c {
			idx = i
			break
		}
	}

	if idx < 0 {
		return false
	}

	mc.removeConnectionAtIndexLocked(idx, false)
	mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
		Int("remaining_connections", len(mc.connections)).
		Msg("successfully removed connection")

	return true
}
// hasActiveConnections reports whether at least one connection is currently
// registered for the node. It takes the read lock.
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
	mc.mutex.RLock()
	active := len(mc.connections) > 0
	mc.mutex.RUnlock()

	return active
}
// getActiveConnectionCount returns how many connections are currently
// registered for the node. It takes the read lock.
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
	mc.mutex.RLock()
	count := len(mc.connections)
	mc.mutex.RUnlock()

	return count
}
// appendPending adds changes to the end of the node's pending change list.
// It is guarded by pendingMu and never touches the connection mutex.
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	mc.pending = append(mc.pending, changes...)
}
// drainPending takes ownership of the node's pending changes: under pendingMu
// it returns the current list and resets the stored list to nil. The result
// is nil when nothing was pending.
func (mc *multiChannelNodeConn) drainPending() []change.Change {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	drained := mc.pending
	mc.pending = nil

	return drained
}
// send broadcasts data to all connections of the node using a multi-phase
// approach that avoids holding the write lock during potentially slow sends.
// Each stale connection can block for up to 50ms (see connectionEntry.send),
// so N stale connections under a single write lock would block for N*50ms.
//
//  1. RLock: copy the connection pointers into a private snapshot.
//  2. Unlocked: send to every connection in the snapshot (timeouts happen here).
//  3. Lock: remove only the failed connections, matched by pointer identity,
//     and stop their sessions.
//
// New connections added during step 2 are safe: they receive a full initial
// map via AddNode, so missing this particular update causes no data loss.
//
// A nil data or a node without connections is a no-op returning nil. The call
// succeeds if at least one connection accepted the data; otherwise the last
// send error is wrapped and returned.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
	if data == nil {
		return nil
	}

	// Phase 1: snapshot connections under read lock.
	mc.mutex.RLock()
	if len(mc.connections) == 0 {
		mc.mutex.RUnlock()
		mc.log.Debug().Caller().
			Msg("send: skipping send to node with no active connections (likely rapid reconnection)")

		return nil
	}

	// Copy the pointers into a separate slice with its own backing array, so
	// later in-place edits of mc.connections cannot alter what is iterated
	// in phase 2.
	snapshot := make([]*connectionEntry, len(mc.connections))
	copy(snapshot, mc.connections)
	mc.mutex.RUnlock()

	mc.log.Debug().Caller().
		Int("total_connections", len(snapshot)).
		Msg("send: broadcasting to all connections")

	// Phase 2: send to all connections without holding any lock.
	// Stale connection timeouts (50ms each) happen here without blocking
	// other goroutines that need the mutex.
	var (
		lastErr      error
		successCount int
		failed       []*connectionEntry
	)

	for _, conn := range snapshot {
		mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
			Str(zf.ConnID, conn.id).
			Msg("send: attempting to send to connection")

		err := conn.send(data)
		if err != nil {
			lastErr = err
			failed = append(failed, conn)

			mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
				Str(zf.ConnID, conn.id).
				Msg("send: connection send failed")

			continue
		}

		successCount++

		mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
			Str(zf.ConnID, conn.id).
			Msg("send: successfully sent to connection")
	}

	// Phase 3: write-lock only to remove failed connections.
	if len(failed) > 0 {
		mc.mutex.Lock()

		// Remove by pointer identity: only entries that still exist in the
		// current connections slice and match a failed pointer are dropped.
		// Connections added between phase 1 and 3 are not affected.
		failedSet := make(map[*connectionEntry]struct{}, len(failed))
		for _, f := range failed {
			failedSet[f] = struct{}{}
		}

		// Filter in place, reusing the backing array.
		clean := mc.connections[:0]

		for _, conn := range mc.connections {
			if _, isFailed := failedSet[conn]; !isFailed {
				clean = append(clean, conn)
				continue
			}

			mc.log.Debug().Caller().
				Str(zf.ConnID, conn.id).
				Msg("send: removing failed connection")
			// Tear down the owning session so the old serveLongPoll
			// goroutine exits instead of lingering as a stale session.
			mc.stopConnection(conn)
		}

		// Zero the slots past the kept entries so the backing array does
		// not keep removed connections reachable.
		clear(mc.connections[len(clean):])

		mc.connections = clean
		mc.mutex.Unlock()
	}

	mc.updateCount.Add(1)

	mc.log.Debug().
		Int("successful_sends", successCount).
		Int("failed_connections", len(failed)).
		Msg("send: completed broadcast")

	// Success if at least one send succeeded.
	if successCount > 0 {
		return nil
	}

	return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
}
// send delivers data on this connection's channel, giving the receiver up to
// 50ms to accept it.
//
// It returns nil when data is nil or the send succeeds (updating lastUsed),
// an error wrapping errConnectionClosed when the entry is already marked
// closed, and an error wrapping ErrConnectionSendTimeout when the receiver
// did not read in time, which is treated as a sign of a stale connection.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
	const staleSendTimeout = 50 * time.Millisecond

	if data == nil {
		return nil
	}

	// Refuse to use a connection that has already been stopped, e.g. when
	// Close() runs while workers are still processing.
	if entry.closed.Load() {
		return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
	}

	timeout := time.NewTimer(staleSendTimeout)
	defer timeout.Stop()

	select {
	case <-timeout.C:
		// Nobody is reading from the channel: likely a stale connection,
		// such as a forcefully terminated client whose channel stays open.
		return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
	case entry.c <- data:
		entry.lastUsed.Store(time.Now().Unix())
		return nil
	}
}
// nodeID returns the ID of the node this connection set belongs to.
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
	id := mc.id

	return id
}
// version returns the capability version of the first registered connection,
// or 0 when the node has no connections. All connections for a node are
// expected to share the same version in practice.
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
	mc.mutex.RLock()
	defer mc.mutex.RUnlock()

	if len(mc.connections) > 0 {
		return mc.connections[0].version
	}

	return 0
}
// updateSentPeers folds a sent MapResponse into the tracked peer set, so that
// later updates can be diffed against what the client already knows.
//
// A non-nil Peers list replaces the tracked set entirely; PeersChanged entries
// are then added and PeersRemoved entries deleted. A nil resp is ignored.
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
	if resp == nil {
		return
	}

	tracked := mc.lastSentPeers

	// A full peer list resets the tracked state before it is repopulated.
	if resp.Peers != nil {
		tracked.Clear()

		for i := range resp.Peers {
			tracked.Store(resp.Peers[i].ID, struct{}{})
		}
	}

	for i := range resp.PeersChanged {
		tracked.Store(resp.PeersChanged[i].ID, struct{}{})
	}

	for i := range resp.PeersRemoved {
		tracked.Delete(resp.PeersRemoved[i])
	}
}
// computePeerDiff returns the peers that are present in lastSentPeers but
// missing from currentPeers, i.e. the peers that have been removed since the
// last send. The result is nil when nothing was removed.
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
	stillPresent := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
	for i := range currentPeers {
		stillPresent[currentPeers[i]] = struct{}{}
	}

	var removed []tailcfg.NodeID

	mc.lastSentPeers.Range(func(sent tailcfg.NodeID, _ struct{}) bool {
		_, ok := stillPresent[sent]
		if !ok {
			removed = append(removed, sent)
		}

		return true
	})

	return removed
}
// change applies r to this node by delegating to handleNodeChange with the
// node's own mapper, and returns whatever error that reports.
func (mc *multiChannelNodeConn) change(r change.Change) error {
	err := handleNodeChange(mc, mc.mapper, r)

	return err
}
// DebugNodeInfo is the per-node record returned by LockFreeBatcher.Debug.
type DebugNodeInfo struct {
	// Connected reports whether the node is considered connected.
	Connected bool `json:"connected"`
	// ActiveConnections is the number of connections registered for the node.
	ActiveConnections int `json:"active_connections"`
}
// Debug returns a snapshot of connection status for every node known to the
// batcher, for use by the debug interface.
//
// Nodes in the nodes map are reported with their current connection count;
// they count as connected when that count is positive or when their marker in
// the connected map is nil. Nodes that appear only in the connected map are
// reported with zero connections and a status derived from the marker alone
// (nil means connected, a timestamp means disconnected). No grace period is
// applied.
func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
	result := make(map[types.NodeID]DebugNodeInfo)

	b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
		if nodeConn == nil {
			return true
		}

		nodeConn.mutex.RLock()
		active := len(nodeConn.connections)
		nodeConn.mutex.RUnlock()

		connected := active > 0
		if !connected {
			// No live connection: fall back to the connected map marker.
			marker, ok := b.connected.Load(id)
			connected = ok && marker == nil
		}

		result[id] = DebugNodeInfo{
			Connected:         connected,
			ActiveConnections: active,
		}

		return true
	})

	// Pick up nodes that only exist in the connected map.
	b.connected.Range(func(id types.NodeID, marker *time.Time) bool {
		if _, seen := result[id]; seen {
			return true
		}

		result[id] = DebugNodeInfo{
			Connected:         marker == nil,
			ActiveConnections: 0,
		}

		return true
	})

	return result
}
// DebugMapResponses returns the map responses reported by the underlying
// mapper's debugMapResponses, keyed by node ID.
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
	responses, err := b.mapper.debugMapResponses()

	return responses, err
}
// WorkErrors returns the number of work errors recorded so far.
// It is mainly intended for tests and debugging.
func (b *LockFreeBatcher) WorkErrors() int64 {
	count := b.workErrors.Load()

	return count
}