Files
headscale/hscontrol/mapper/batcher.go
Kristoffer Dalby 3276bda0c0 mapper/batcher: replace time.After with NewTimer to avoid timer leak
connectionEntry.send() is on the hot path: called once per connection
per broadcast tick. time.After allocates a timer that sits in the
runtime timer heap until it fires (50 ms), even when the channel send
succeeds immediately. At 1000 connected nodes, every tick leaks 1000
timers into the heap, creating continuous GC pressure.

Replace with time.NewTimer + defer timer.Stop() so the timer is
removed from the heap as soon as the fast-path send completes.
2026-03-14 02:52:28 -07:00

1155 lines
34 KiB
Go

package mapper
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// Errors reported while generating and dispatching map responses.
var (
	// ErrInvalidNodeID is returned when a map response is requested for node ID 0.
	ErrInvalidNodeID = errors.New("invalid nodeID")
	// ErrMapperNil is returned when no mapper is available to build a response.
	ErrMapperNil = errors.New("mapper is nil")
	// ErrNodeConnectionNil is returned when a change is applied to a nil node connection.
	ErrNodeConnectionNil = errors.New("nodeConnection is nil")
	// ErrNodeNotFoundMapper is returned when synchronous work targets a node the batcher does not track.
	ErrNodeNotFoundMapper = errors.New("node not found")
)
// mapResponseGenerated counts map response generation attempts, labelled with
// the categorised change type (change.Change.Type) rather than the free-form
// reason string.
var mapResponseGenerated = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: "headscale",
		Name:      "mapresponse_generated_total",
		Help:      "total count of mapresponses generated by response type",
	},
	[]string{"response_type"},
)
// NewBatcher returns a Batcher that flushes pending changes every batchTime
// and processes them with the given number of worker goroutines. No goroutines
// are started until Start is called.
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher {
	// The work queue size is an arbitrary multiple of the worker count; the
	// sizing should be revisited.
	queueSize := workers * 200

	b := &Batcher{
		mapper:  mapper,
		workers: workers,
	}
	b.tick = time.NewTicker(batchTime)
	b.workCh = make(chan work, queueSize)
	b.nodes = xsync.NewMap[types.NodeID, *multiChannelNodeConn]()
	b.connected = xsync.NewMap[types.NodeID, *time.Time]()

	return b
}
// NewBatcherAndMapper builds a mapper from cfg and state, wraps it in a
// Batcher configured from cfg.Tuning, and links the mapper back to that
// batcher before returning it.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) *Batcher {
	mp := newMapper(cfg, state)
	batcher := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, mp)
	mp.batcher = batcher

	return batcher
}
// nodeConnection interface for different connection implementations.
//
// generateMapResponse and handleNodeChange operate on this interface;
// multiChannelNodeConn implements it.
type nodeConnection interface {
	// nodeID returns the ID of the node this connection belongs to.
	nodeID() types.NodeID
	// version returns the client capability version used when building responses.
	version() tailcfg.CapabilityVersion
	// send delivers a map response to the node.
	send(data *tailcfg.MapResponse) error
	// computePeerDiff returns peers that were previously sent but are no longer in the current list.
	computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
	// updateSentPeers updates the tracking of which peers have been sent to this node.
	updateSentPeers(resp *tailcfg.MapResponse)
}
// generateMapResponse builds the [tailcfg.MapResponse] that node nc should
// receive for change r.
//
// It returns (nil, nil) when there is nothing to send: either r is empty, or
// r is self-only and targets a different node. Errors are returned for node ID
// 0, a nil mapper, or a failure while building the response.
//
// Three build strategies are used:
//   - changes needing runtime peer computation (e.g. policy changes) get a
//     policy response that includes the peers removed since the last send;
//   - a change that originated from the receiving node gets its self info;
//   - anything else is built directly from the change.
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
	nodeID, version := nc.nodeID(), nc.version()

	switch {
	case r.IsEmpty():
		return nil, nil //nolint:nilnil // Empty response means nothing to send
	case nodeID == 0:
		return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
	case mapper == nil:
		return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
	case r.IsSelfOnly() && r.TargetNode != nodeID:
		return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
	}

	// The receiving node caused this change itself (e.g. its tags were
	// edited via the admin API), so it must see its own updated node info.
	isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID

	// Metric uses the categorised change type, not the free-form reason.
	mapResponseGenerated.WithLabelValues(r.Type()).Inc()

	var (
		resp *tailcfg.MapResponse
		err  error
	)

	switch {
	case r.RequiresRuntimePeerComputation:
		peers := mapper.state.ListPeers(nodeID)
		peerIDs := make([]tailcfg.NodeID, 0, peers.Len())
		for _, p := range peers.All() {
			peerIDs = append(peerIDs, p.ID().NodeID())
		}

		removed := nc.computePeerDiff(peerIDs)
		resp, err = mapper.policyChangeResponse(nodeID, version, removed, peers, isSelfUpdate)
	case isSelfUpdate:
		resp, err = mapper.selfMapResponse(nodeID, version)
	default:
		resp, err = mapper.buildFromChange(nodeID, version, &r)
	}

	if err != nil {
		return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
	}

	return resp, nil
}
// handleNodeChange builds the map response for change r and sends it through
// nc. A change that yields no response is not an error. Peer tracking is
// updated only after the send succeeds.
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
	if nc == nil {
		return ErrNodeConnectionNil
	}

	id := nc.nodeID()
	log.Debug().Caller().Uint64(zf.NodeID, id.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started")

	resp, err := generateMapResponse(nc, mapper, r)
	switch {
	case err != nil:
		return fmt.Errorf("generating map response for node %d: %w", id, err)
	case resp == nil:
		// Some change types legitimately produce nothing to send.
		return nil
	}

	if sendErr := nc.send(resp); sendErr != nil {
		return fmt.Errorf("sending map response to node %d: %w", id, sendErr)
	}

	nc.updateSentPeers(resp)

	return nil
}
// workResult represents the result of processing a change.
// Workers deliver it on work.resultCh for synchronous requests.
type workResult struct {
	mapResponse *tailcfg.MapResponse // generated response; may be nil when there is nothing to send
	err         error                // generation error, or ErrNodeNotFoundMapper when the node is not tracked
}
// work represents a unit of work to be processed by workers.
type work struct {
	c      change.Change // the change to process
	nodeID types.NodeID  // node the change is processed for
	// resultCh is an optional channel for synchronous operations: when set, the worker
	// returns the generated response on it instead of sending it to the node
	// (see MapResponseFromChange).
	resultCh chan<- workResult
}
// Errors produced by the Batcher and its connections.
var (
	// errConnectionClosed is returned when sending to a connection that has already been stopped.
	errConnectionClosed = errors.New("connection channel already closed")
	// ErrInitialMapSendTimeout is returned by AddNode when the initial map is not accepted in time.
	ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
	// ErrBatcherShuttingDown is returned by MapResponseFromChange once the batcher is closing.
	ErrBatcherShuttingDown = errors.New("batcher shutting down")
	// ErrConnectionSendTimeout is returned when a connection does not accept a response within the send timeout.
	ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
)
// Batcher batches and distributes map responses to connected nodes.
// It uses concurrent maps, per-node mutexes, and a worker pool.
//
// Changes submitted through AddWork are accumulated per node and, on each tick,
// queued to the worker pool, which generates and sends the map responses.
type Batcher struct {
	tick    *time.Ticker // drives processBatchedChanges in doWork
	mapper  *mapper
	workers int // number of worker goroutines started by doWork
	// nodes holds per-node connection state. Entries are kept after the last
	// connection goes away so a reconnecting node can reuse them.
	nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
	// connected records connection status: a nil value means connected, a
	// non-nil value is the time the node's last connection was removed.
	connected *xsync.Map[types.NodeID, *time.Time]
	// Work queue channel
	workCh chan work
	// done is created by Start and closed by Close; doWork, workers, queueWork
	// and MapResponseFromChange all stop waiting once it is closed.
	done     chan struct{}
	doneOnce sync.Once   // Ensures done is only closed once
	started  atomic.Bool // Ensures Start() is only called once
	// Metrics
	totalNodes      atomic.Int64 // nodes currently held in nodes
	workQueuedCount atomic.Int64 // queueWork calls, including work dropped on shutdown
	workProcessed   atomic.Int64 // work items received by workers
	workErrors      atomic.Int64 // failed work items; exposed via WorkErrors
}
// AddNode registers a new connection for node id and delivers its initial map
// response on c.
//
// The connection is attached to the node's multiChannelNodeConn (an existing
// entry is reused when the node reconnects). A FullSelf map is then generated
// through the worker pool and sent on c with a 5 second timeout. On success the
// node is marked connected. If generation fails or the send times out, the
// connection is detached again and an error is returned (wrapping
// ErrInitialMapSendTimeout in the timeout case).
//
// The stop function tears down the owning session if this connection is later
// declared stale.
func (b *Batcher) AddNode(
	id types.NodeID,
	c chan<- *tailcfg.MapResponse,
	version tailcfg.CapabilityVersion,
	stop func(),
) error {
	// initialMapSendTimeout bounds how long we wait for the session to start
	// receiving on c before giving up on this connection.
	const initialMapSendTimeout = 5 * time.Second

	addNodeStart := time.Now()
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	connID := generateConnectionID()

	now := time.Now()
	newEntry := &connectionEntry{
		id:      connID,
		c:       c,
		version: version,
		created: now,
		stop:    stop,
	}
	newEntry.lastUsed.Store(now.Unix())

	// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection.
	nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
	if !loaded {
		b.totalNodes.Add(1)
	}

	nodeConn.addConnection(newEntry)

	// Use the worker pool for controlled concurrency instead of direct generation.
	initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
	if err != nil {
		nlog.Error().Err(err).Msg("initial map generation failed")
		nodeConn.removeConnectionByChannel(c)

		return fmt.Errorf("generating initial map for node %d: %w", id, err)
	}

	// Blocking send with a timeout: the receiver may not be reading yet.
	// NewTimer + Stop (rather than time.After) releases the timer as soon as
	// the usually-immediate send succeeds, matching connectionEntry.send.
	timer := time.NewTimer(initialMapSendTimeout)
	defer timer.Stop()

	select {
	case c <- initialMap:
		// Success
	case <-timer.C:
		nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
		nlog.Debug().Caller().Dur("timeout.duration", initialMapSendTimeout).
			Msg("initial map send timed out because channel was blocked or receiver not ready")
		nodeConn.removeConnectionByChannel(c)

		return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
	}

	b.connected.Store(id, nil) // nil = connected

	// The initial full map already contains all current state; later updates
	// arrive through the normal batching flow.
	nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
		Int("active.connections", nodeConn.getActiveConnectionCount()).
		Msg("node connection established in batcher")

	return nil
}
// RemoveNode detaches connection c from node id and reports whether the node
// still has other active connections afterwards.
//
// When the last connection goes away, the node entry is deliberately kept
// (so a quick reconnect can reuse it). The disconnect time is recorded in the
// connected map, where cleanupOfflineNodes eventually picks it up.
func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
	logger := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	nodeConn, ok := b.nodes.Load(id)
	if !ok || nodeConn == nil {
		logger.Debug().Caller().Msg("removeNode called for non-existent node")
		return false
	}

	if !nodeConn.removeConnectionByChannel(c) {
		logger.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
	}

	if nodeConn.hasActiveConnections() {
		logger.Debug().Caller().
			Int("active.connections", nodeConn.getActiveConnectionCount()).
			Msg("node connection removed but keeping online, other connections remain")

		return true
	}

	// No connections left: the node gets a fresh full map when it reconnects.
	logger.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")

	disconnectedAt := time.Now()
	b.connected.Store(id, &disconnectedAt)

	return false
}
// AddWork submits changes to the batcher. They are not processed right away:
// they are recorded as pending on the affected nodes and handed to the worker
// pool on the next batch tick.
func (b *Batcher) AddWork(r ...change.Change) {
	b.addWork(r...)
}
// Start creates the shutdown channel and launches the dispatch loop (which in
// turn starts the workers). Only the first call has any effect.
func (b *Batcher) Start() {
	if b.started.CompareAndSwap(false, true) {
		b.done = make(chan struct{})
		go b.doWork()
	}
}
// Close signals all batcher goroutines to stop and stops every node connection.
//
// Closing done is enough for workers and queueWork to exit. workCh is
// intentionally left open because processBatchedChanges or
// MapResponseFromChange may still be sending on it concurrently.
//
// NOTE(review): done is created by Start. If Close runs before Start, nothing
// is closed but doneOnce is still consumed, so a later Start cannot be stopped.
func (b *Batcher) Close() {
	b.doneOnce.Do(func() {
		if b.done == nil {
			return
		}
		close(b.done)
	})

	// Stop the sessions that own the channels feeding data to the clients.
	b.nodes.Range(func(_ types.NodeID, conn *multiChannelNodeConn) bool {
		if conn != nil {
			conn.close()
		}
		return true
	})
}
// doWork starts the worker pool and then runs the batcher's main loop:
// flushing batched changes on every tick and periodically reaping nodes that
// have been offline for a long time, until done is closed.
func (b *Batcher) doWork() {
	for id := 1; id <= b.workers; id++ {
		go b.worker(id)
	}

	// Reaper for nodes that have stayed disconnected for too long.
	reaper := time.NewTicker(5 * time.Minute)
	defer reaper.Stop()

	for {
		select {
		case <-b.done:
			log.Info().Msg("batcher done channel closed, stopping to feed workers")
			return
		case <-b.tick.C:
			b.processBatchedChanges()
		case <-reaper.C:
			b.cleanupOfflineNodes()
		}
	}
}
// worker is the body of one worker goroutine. It consumes b.workCh until the
// channel is closed or b.done is closed.
//
// Items with a resultCh are synchronous requests (see MapResponseFromChange).
// For these, the map response is generated and handed back to the waiting
// caller instead of being sent to the node. All other items are applied to the
// node's connections via multiChannelNodeConn.change. Items for nodes the
// batcher no longer tracks are dropped (async) or answered with
// ErrNodeNotFoundMapper (sync).
func (b *Batcher) worker(workerID int) {
	logger := log.With().Int(zf.WorkerID, workerID).Logger()

	for {
		var (
			w  work
			ok bool
		)

		select {
		case <-b.done:
			logger.Debug().Msg("batcher shutting down, exiting worker")
			return
		case w, ok = <-b.workCh:
		}

		if !ok {
			logger.Debug().Msg("worker channel closing, shutting down")
			return
		}

		b.workProcessed.Add(1)

		// Asynchronous request: apply the change to the node, if still tracked.
		if w.resultCh == nil {
			nc, found := b.nodes.Load(w.nodeID)
			if !found || nc == nil {
				continue
			}

			if err := nc.change(w.c); err != nil {
				b.workErrors.Add(1)
				logger.Error().Err(err).
					Uint64(zf.NodeID, w.nodeID.Uint64()).
					Str(zf.Reason, w.c.Reason).
					Msg("failed to apply change")
			}

			continue
		}

		// Synchronous request: build the response and return it to the caller.
		var result workResult

		nc, found := b.nodes.Load(w.nodeID)
		if !found || nc == nil {
			result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
			b.workErrors.Add(1)
			logger.Error().Err(result.err).
				Uint64(zf.NodeID, w.nodeID.Uint64()).
				Msg("node not found for synchronous work")
		} else {
			result.mapResponse, result.err = generateMapResponse(nc, b.mapper, w.c)
			if result.err != nil {
				b.workErrors.Add(1)
				logger.Error().Err(result.err).
					Uint64(zf.NodeID, w.nodeID.Uint64()).
					Str(zf.Reason, w.c.Reason).
					Msg("failed to generate map response for synchronous work")
			} else if result.mapResponse != nil {
				// Synchronous responses also count towards peer tracking.
				nc.updateSentPeers(result.mapResponse)
			}
		}

		select {
		case <-b.done:
			return
		case w.resultCh <- result:
		}
	}
}
// addWork is the internal entry point behind AddWork; it forwards the changes
// to addToBatch.
func (b *Batcher) addWork(r ...change.Change) {
	b.addToBatch(r...)
}
// queueWork hands w to the worker pool. It blocks while the buffered work
// queue is full. If the batcher shuts down first, w is dropped. The queued
// counter is incremented in either case.
func (b *Batcher) queueWork(w work) {
	b.workQueuedCount.Add(1)

	select {
	case <-b.done:
		// Shutting down; the work is dropped.
	case b.workCh <- w:
	}
}
// addToBatch adds changes to the pending batch.
//
// Nothing is sent from here: changes are appended to the pending list of each
// affected multiChannelNodeConn, and processBatchedChanges drains those lists
// into the worker pool on the next tick. A full update in the batch replaces
// every node's pending list with a single full update.
func (b *Batcher) addToBatch(changes ...change.Change) {
	// Clean up any nodes being permanently removed from the system.
	//
	// This handles the case where a node is deleted from state but the batcher
	// still has it registered. By cleaning up here, we prevent "node not found"
	// errors when workers try to generate map responses for deleted nodes.
	//
	// Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually
	// deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy
	// changes that affect peer visibility do NOT use this field - they set
	// RequiresRuntimePeerComputation=true and compute removed peers at runtime,
	// putting them in tailcfg.MapResponse.PeersRemoved (a different struct).
	// Therefore, this cleanup only removes nodes that are truly being deleted,
	// not nodes that are still connected but have lost visibility of certain peers.
	//
	// See: https://github.com/juanfont/headscale/issues/2924
	for _, ch := range changes {
		for _, removedID := range ch.PeersRemoved {
			if _, existed := b.nodes.LoadAndDelete(removedID); existed {
				b.totalNodes.Add(-1)
				log.Debug().
					Uint64(zf.NodeID, removedID.Uint64()).
					Msg("removed deleted node from batcher")
			}
			// Drop the status entry too, whether or not a node entry existed.
			b.connected.Delete(removedID)
		}
	}

	// Short circuit if any of the changes is a full update, which
	// means we can skip sending individual changes.
	if change.HasFull(changes) {
		b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
			if nc == nil {
				return true
			}
			// Replace (not append to) whatever was pending: the full update
			// supersedes it.
			nc.pendingMu.Lock()
			nc.pending = []change.Change{change.FullUpdate()}
			nc.pendingMu.Unlock()
			return true
		})
		return
	}

	broadcast, targeted := change.SplitTargetedAndBroadcast(changes)

	// Handle targeted changes - send only to the specific node.
	// Changes for nodes the batcher does not track are silently dropped.
	for _, ch := range targeted {
		if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil {
			nc.appendPending(ch)
		}
	}

	// Handle broadcast changes - send to all nodes, filtering as needed.
	// change.FilterForNode decides which broadcast changes apply to each node.
	if len(broadcast) > 0 {
		b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
			if nc == nil {
				return true
			}
			filtered := change.FilterForNode(nodeID, broadcast)
			if len(filtered) > 0 {
				nc.appendPending(filtered...)
			}
			return true
		})
	}
}
// processBatchedChanges drains every node's pending changes and queues each
// one as asynchronous work (no result channel) for the worker pool.
func (b *Batcher) processBatchedChanges() {
	b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool {
		if nc != nil {
			for _, ch := range nc.drainPending() {
				b.queueWork(work{c: ch, nodeID: id})
			}
		}
		return true
	})
}
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
//
// A node is a candidate when the connected map holds a disconnect timestamp
// older than cleanupThreshold. Each candidate is removed with Compute(), an
// atomic check-and-delete. Inside Compute, both "no active connections" and the
// disconnect timestamp are re-checked. The timestamp re-check matters: between
// the Range scan and Compute, the node may reconnect and disconnect again,
// leaving a fresh timestamp that must not be reaped yet.
//
// NOTE(review): AddNode loads an existing entry and only afterwards attaches
// its connection, so an entry can still be deleted here between those two
// steps. Compute narrows but does not close that window.
//
// TODO(kradalby): reevaluate if we want to keep this.
func (b *Batcher) cleanupOfflineNodes() {
	const cleanupThreshold = 15 * time.Minute

	now := time.Now()

	// offlineFor reports how long nodeID has been disconnected according to
	// the connected map. ok is false when the node is not marked disconnected.
	offlineFor := func(nodeID types.NodeID) (time.Duration, bool) {
		disconnectTime, ok := b.connected.Load(nodeID)
		if !ok || disconnectTime == nil {
			return 0, false
		}

		return now.Sub(*disconnectTime), true
	}

	// Find nodes that have been offline for too long.
	var nodesToCleanup []types.NodeID
	b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
		if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
			nodesToCleanup = append(nodesToCleanup, nodeID)
		}
		return true
	})

	cleaned := 0
	for _, nodeID := range nodesToCleanup {
		var (
			deleted bool
			offline time.Duration
		)

		b.nodes.Compute(
			nodeID,
			func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) {
				if !loaded || conn == nil || conn.hasActiveConnections() {
					return conn, xsync.CancelOp
				}

				d, disconnected := offlineFor(nodeID)
				if !disconnected || d <= cleanupThreshold {
					// Reconnected (or reconnected and left again) since the scan.
					return conn, xsync.CancelOp
				}

				deleted, offline = true, d

				return conn, xsync.DeleteOp
			},
		)

		if !deleted {
			continue
		}

		log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
			Dur("offline_duration", offline).
			Msg("cleaning up node that has been offline for too long")
		b.connected.Delete(nodeID)
		b.totalNodes.Add(-1)
		cleaned++
	}

	if cleaned > 0 {
		log.Info().Int(zf.CleanedNodes, cleaned).
			Msg("completed cleanup of long-offline nodes")
	}
}
// IsConnected reports whether node id is currently connected.
//
// A node with at least one active connection is connected. Otherwise the
// connected map decides: a nil value means connected, while a timestamp
// (disconnected) or a missing entry means not connected.
func (b *Batcher) IsConnected(id types.NodeID) bool {
	if nodeConn, ok := b.nodes.Load(id); ok && nodeConn != nil && nodeConn.hasActiveConnections() {
		return true
	}

	disconnectedAt, ok := b.connected.Load(id)

	return ok && disconnectedAt == nil
}
// ConnectedMap returns a new concurrent map of node ID to connection status.
//
// Nodes with active connections are reported as true. Every other node present
// in the connected map is reported as true if its value is nil (connected) and
// false if it holds a disconnect timestamp.
func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
	out := xsync.NewMap[types.NodeID, bool]()

	b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
		if nodeConn != nil && nodeConn.hasActiveConnections() {
			out.Store(id, true)
		}
		return true
	})

	b.connected.Range(func(id types.NodeID, disconnectedAt *time.Time) bool {
		if _, seen := out.Load(id); !seen {
			out.Store(id, disconnectedAt == nil)
		}
		return true
	})

	return out
}
// MapResponseFromChange generates the map response for change ch on behalf of
// node id. The work runs on the shared worker pool, and the call blocks until
// a worker answers or the batcher shuts down (ErrBatcherShuttingDown).
func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) {
	// Buffered so the worker's reply never has to wait for this goroutine.
	results := make(chan workResult, 1)
	b.queueWork(work{c: ch, nodeID: id, resultCh: results})

	select {
	case <-b.done:
		return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
	case res := <-results:
		return res.mapResponse, res.err
	}
}
// connectionEntry represents a single connection to a node.
// A node may have several of these at once (see multiChannelNodeConn).
type connectionEntry struct {
	id      string                      // unique connection ID (random hex from generateConnectionID)
	c       chan<- *tailcfg.MapResponse // channel map responses are sent on
	version tailcfg.CapabilityVersion   // client capability version passed to AddNode
	created time.Time                   // when AddNode created the entry
	stop    func()                      // tears down the owning session; invoked at most once via stopConnection
	lastUsed atomic.Int64 // Unix timestamp of last successful send
	closed   atomic.Bool  // Indicates if this connection has been closed; send refuses closed entries
}
// multiChannelNodeConn manages multiple concurrent connections for a single node.
// It also holds the node's batched (pending) changes and the set of peers
// last sent to it.
type multiChannelNodeConn struct {
	id     types.NodeID
	mapper *mapper
	log    zerolog.Logger // logger pre-tagged with the node ID
	// mutex guards connections.
	mutex       sync.RWMutex
	connections []*connectionEntry
	// pendingMu protects pending changes independently of the connection mutex.
	// This avoids contention between addToBatch (which appends changes) and
	// send() (which sends data to connections).
	pendingMu sync.Mutex
	pending   []change.Change
	// closeOnce makes close() stop the connections only once.
	closeOnce sync.Once
	// updateCount counts send() broadcasts that found at least one connection.
	updateCount atomic.Int64
	// lastSentPeers tracks which peers were last sent to this node.
	// This enables computing diffs for policy changes instead of sending
	// full peer lists (which clients interpret as "no change" when empty).
	// Using xsync.Map for lock-free concurrent access.
	lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
}
// generateConnectionID returns a random 16-character lowercase hex string
// (8 random bytes) used to tell a node's connections apart in logs.
// Errors from rand.Read are intentionally ignored.
func generateConnectionID() string {
	var raw [8]byte
	_, _ = rand.Read(raw[:])

	return hex.EncodeToString(raw[:])
}
// newMultiChannelNodeConn returns an empty connection set for node id, with an
// empty peer-tracking map and a logger pre-tagged with the node ID.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
	mc := &multiChannelNodeConn{id: id, mapper: mapper}
	mc.lastSentPeers = xsync.NewMap[tailcfg.NodeID, struct{}]()
	mc.log = log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	return mc
}
// close stops every connection currently attached to the node. Only the first
// call does anything. Connections are stopped, not removed from the slice.
func (mc *multiChannelNodeConn) close() {
	mc.closeOnce.Do(func() {
		mc.mutex.Lock()
		defer mc.mutex.Unlock()

		for i := range mc.connections {
			mc.stopConnection(mc.connections[i])
		}
	})
}
// stopConnection flags conn as closed and runs its stop callback. The
// compare-and-swap guarantees the callback runs at most once, even when
// several cleanup paths race on the same connection.
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
	if !conn.closed.CompareAndSwap(false, true) {
		return // another path already stopped it
	}

	if stop := conn.stop; stop != nil {
		stop()
	}
}
// removeConnectionAtIndexLocked removes the connection at index i from
// mc.connections and returns it. If stopConnection is true, it also stops
// that session (at most once, via mc.stopConnection).
//
// The vacated last slot of the backing array is set to nil. Otherwise the
// slice's spare capacity would keep the removed entry reachable, along with its
// channel and stop callback, until the slot happened to be overwritten.
//
// Caller must hold mc.mutex.
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
	conn := mc.connections[i]

	last := len(mc.connections) - 1
	copy(mc.connections[i:], mc.connections[i+1:])
	mc.connections[last] = nil
	mc.connections = mc.connections[:last]

	if stopConnection {
		mc.stopConnection(conn)
	}

	return conn
}
// addConnection appends entry to the node's connection list under the write
// lock. It logs, at debug level, how long the lock took to acquire.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
	waitStart := time.Now()
	chanPtr := fmt.Sprintf("%p", entry.c)

	mc.log.Debug().Caller().Str(zf.Chan, chanPtr).Str(zf.ConnID, entry.id).
		Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")

	mc.mutex.Lock()
	waited := time.Since(waitStart)
	defer mc.mutex.Unlock()

	mc.connections = append(mc.connections, entry)

	mc.log.Debug().Caller().Str(zf.Chan, chanPtr).Str(zf.ConnID, entry.id).
		Int("total_connections", len(mc.connections)).
		Dur("mutex_wait_time", waited).
		Msg("successfully added connection after mutex wait")
}
// removeConnectionByChannel detaches the first connection whose channel is c,
// without stopping its session. It reports whether a match was found.
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
	mc.mutex.Lock()
	defer mc.mutex.Unlock()

	idx := -1
	for i := range mc.connections {
		if mc.connections[i].c == c {
			idx = i
			break
		}
	}

	if idx < 0 {
		return false
	}

	mc.removeConnectionAtIndexLocked(idx, false)
	mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
		Int("remaining_connections", len(mc.connections)).
		Msg("successfully removed connection")

	return true
}
// hasActiveConnections reports whether at least one connection is attached.
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
	return mc.getActiveConnectionCount() > 0
}
// getActiveConnectionCount returns how many connections are attached, read
// under the connection read lock.
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
	mc.mutex.RLock()
	n := len(mc.connections)
	mc.mutex.RUnlock()

	return n
}
// appendPending adds changes to the end of this node's pending list. It only
// takes pendingMu, so it never contends with the connection mutex.
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	mc.pending = append(mc.pending, changes...)
}
// drainPending takes ownership of the pending list, leaving it empty. The
// result is nil when nothing was pending.
func (mc *multiChannelNodeConn) drainPending() []change.Change {
	mc.pendingMu.Lock()
	defer mc.pendingMu.Unlock()

	drained := mc.pending
	mc.pending = nil

	return drained
}
// send broadcasts data to all active connections for the node.
//
// To avoid holding the write lock during potentially slow sends (each stale
// connection can block for up to 50ms), the method snapshots connections under
// a read lock, sends without any lock held, then write-locks only to remove
// failures. New connections added between the snapshot and cleanup are safe:
// they receive a full initial map via AddNode, so missing this update causes
// no data loss.
//
// Per-connection debug logging is guarded by Event.Enabled. The channel pointer
// is rendered with fmt.Sprintf, and Go evaluates that argument even when the
// debug event is disabled. This loop runs for every connection of every node
// on every batch tick, so unconditional formatting is pure overhead when debug
// logging is off.
//
// Failed connections are stopped and removed. The vacated tail of the slice is
// cleared so the removed entries are not kept reachable.
//
// It returns nil when data is nil, when there are no connections, or when at
// least one send succeeded. Otherwise it wraps the last send error.
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
	if data == nil {
		return nil
	}

	// Snapshot connections under read lock.
	mc.mutex.RLock()
	if len(mc.connections) == 0 {
		mc.mutex.RUnlock()
		mc.log.Debug().Caller().
			Msg("send: skipping send to node with no active connections (likely rapid reconnection)")

		return nil
	}

	// Copy the slice so we can release the read lock before sending.
	snapshot := make([]*connectionEntry, len(mc.connections))
	copy(snapshot, mc.connections)
	mc.mutex.RUnlock()

	mc.log.Debug().Caller().
		Int("total_connections", len(snapshot)).
		Msg("send: broadcasting to all connections")

	// Send to all connections without holding any lock, so stale-connection
	// timeouts do not block goroutines that need the mutex.
	var (
		lastErr      error
		successCount int
		failed       []*connectionEntry
	)

	for _, conn := range snapshot {
		if ev := mc.log.Debug(); ev.Enabled() {
			ev.Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
				Str(zf.ConnID, conn.id).
				Msg("send: attempting to send to connection")
		}

		if err := conn.send(data); err != nil {
			lastErr = err
			failed = append(failed, conn)
			mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
				Str(zf.ConnID, conn.id).
				Msg("send: connection send failed")

			continue
		}

		successCount++

		if ev := mc.log.Debug(); ev.Enabled() {
			ev.Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
				Str(zf.ConnID, conn.id).
				Msg("send: successfully sent to connection")
		}
	}

	// Write-lock only to remove failed connections.
	if len(failed) > 0 {
		mc.mutex.Lock()

		// Remove by pointer identity: only entries that are still present and
		// failed are removed. Connections added since the snapshot survive.
		failedSet := make(map[*connectionEntry]struct{}, len(failed))
		for _, f := range failed {
			failedSet[f] = struct{}{}
		}

		kept := mc.connections[:0]
		for _, conn := range mc.connections {
			if _, isFailed := failedSet[conn]; !isFailed {
				kept = append(kept, conn)
				continue
			}

			mc.log.Debug().Caller().
				Str(zf.ConnID, conn.id).
				Msg("send: removing failed connection")
			// Tear down the owning session so the old serveLongPoll
			// goroutine exits instead of lingering as a stale session.
			mc.stopConnection(conn)
		}

		// Nil out the slots freed by the in-place filter.
		clear(mc.connections[len(kept):])
		mc.connections = kept
		mc.mutex.Unlock()
	}

	mc.updateCount.Add(1)

	mc.log.Debug().
		Int("successful_sends", successCount).
		Int("failed_connections", len(failed)).
		Msg("send: completed broadcast")

	// Success if at least one send succeeded.
	if successCount > 0 {
		return nil
	}

	return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
}
// send sends data to a single connection entry with timeout-based stale
// connection detection.
//
// It returns errConnectionClosed (wrapped) if the entry has already been
// stopped. It returns ErrConnectionSendTimeout (wrapped) if the channel does
// not accept data within 50ms. That typically means the client stopped reading,
// e.g. a Docker container that was killed while its channel still looks open.
// A nil data is a no-op.
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
	if data == nil {
		return nil
	}

	// Check if the connection has been closed to prevent send on closed channel panic.
	// This can happen during shutdown when Close() is called while workers are still processing.
	if entry.closed.Load() {
		return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
	}

	// Fast path: when the receiver is ready, a non-blocking send succeeds
	// immediately and no timer is allocated at all. This is the common case
	// on the hot path (called per connection per broadcast tick).
	select {
	case entry.c <- data:
		entry.lastUsed.Store(time.Now().Unix())
		return nil
	default:
	}

	// Slow path: wait a short time before declaring the connection stale.
	// time.NewTimer + Stop (instead of time.After) releases the timer as soon
	// as the send completes rather than when it would have fired.
	timer := time.NewTimer(50 * time.Millisecond) //nolint:mnd
	defer timer.Stop()

	select {
	case entry.c <- data:
		// Update last used timestamp on successful send.
		entry.lastUsed.Store(time.Now().Unix())
		return nil
	case <-timer.C:
		// Connection is likely stale - client isn't reading from channel.
		return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
	}
}
// nodeID reports which node this connection set belongs to.
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
	return mc.id
}
// version returns the capability version of the oldest attached connection
// (index 0), or 0 when none is attached. All of a node's connections are
// expected to share one version in practice.
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
	mc.mutex.RLock()
	defer mc.mutex.RUnlock()

	var v tailcfg.CapabilityVersion
	if len(mc.connections) > 0 {
		v = mc.connections[0].version
	}

	return v
}
// updateSentPeers records which peers the node now knows about after resp.
// It must be called after resp was successfully sent so future diffs are
// computed against what the client actually has.
//
// Processing order: a non-nil Peers list replaces the tracked set entirely,
// then PeersChanged entries are added and PeersRemoved entries are dropped.
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
	if resp == nil {
		return
	}

	if full := resp.Peers; full != nil {
		mc.lastSentPeers.Clear()
		for i := range full {
			mc.lastSentPeers.Store(full[i].ID, struct{}{})
		}
	}

	for i := range resp.PeersChanged {
		mc.lastSentPeers.Store(resp.PeersChanged[i].ID, struct{}{})
	}

	for _, gone := range resp.PeersRemoved {
		mc.lastSentPeers.Delete(gone)
	}
}
// computePeerDiff returns the peers that were last sent to the node but are
// missing from currentPeers, i.e. the ones the client should be told to drop.
// The result is nil when nothing was removed; its order is unspecified.
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
	stillPresent := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
	for _, id := range currentPeers {
		stillPresent[id] = struct{}{}
	}

	var removed []tailcfg.NodeID
	mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
		if _, ok := stillPresent[id]; !ok {
			removed = append(removed, id)
		}
		return true
	})

	return removed
}
// change builds the map response for r and sends it to this node's
// connections by delegating to handleNodeChange with the node's own mapper.
func (mc *multiChannelNodeConn) change(r change.Change) error {
	return handleNodeChange(mc, mc.mapper, r)
}
// DebugNodeInfo contains debug information about a node's connections.
// It is the per-node value returned by Batcher.Debug.
type DebugNodeInfo struct {
	// Connected is true when the node has active connections or the batcher's
	// connected map marks it as connected (nil value).
	Connected bool `json:"connected"`
	// ActiveConnections is the number of attached connections; 0 for nodes
	// that are only known from the connected map.
	ActiveConnections int `json:"active_connections"`
}
// Debug returns a snapshot of every node the batcher knows about, keyed by
// node ID, for the debug interface.
//
// Status is immediate (no grace period). A node is connected if it has active
// connections or, failing that, if the connected map holds nil for it. Nodes
// present only in the connected map are reported with zero active connections.
func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo {
	info := make(map[types.NodeID]DebugNodeInfo)

	b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
		if nodeConn == nil {
			return true
		}

		nodeConn.mutex.RLock()
		active := len(nodeConn.connections)
		nodeConn.mutex.RUnlock()

		connected := active > 0
		if !connected {
			if val, ok := b.connected.Load(id); ok && val == nil {
				connected = true
			}
		}

		info[id] = DebugNodeInfo{Connected: connected, ActiveConnections: active}

		return true
	})

	b.connected.Range(func(id types.NodeID, val *time.Time) bool {
		if _, seen := info[id]; !seen {
			// nil means connected, a timestamp means disconnected.
			info[id] = DebugNodeInfo{Connected: val == nil, ActiveConnections: 0}
		}
		return true
	})

	return info
}
// DebugMapResponses exposes the mapper's debug record of map responses,
// grouped by node ID, by delegating to mapper.debugMapResponses.
func (b *Batcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
	return b.mapper.debugMapResponses()
}
// WorkErrors reports how many work items have failed since the batcher was
// created: failed change applications, failed synchronous map generations,
// and synchronous requests for untracked nodes. Mainly for tests and debugging.
func (b *Batcher) WorkErrors() int64 {
	return b.workErrors.Load()
}