Files
headscale/hscontrol/mapper/batcher.go
Kristoffer Dalby 87b8507ac9 mapper/batcher: replace connected map with per-node disconnectedAt
The Batcher's connected field (*xsync.Map[types.NodeID, *time.Time])
encoded three states via pointer semantics:

  - nil value:    node is connected
  - non-nil time: node disconnected at that timestamp
  - key missing:  node was never seen

This was error-prone (nil meaning 'connected' inverts Go idioms),
redundant with b.nodes + hasActiveConnections(), and required keeping
two parallel maps in sync. It also contained a bug in RemoveNode where
new(time.Now()) was used instead of &now, producing a zero time.

Replace the separate connected map with a disconnectedAt field on
multiChannelNodeConn (atomic.Pointer[time.Time]), tracked directly
on the object that already manages the node's connections.

Changes:
  - Add disconnectedAt field and helpers (markConnected, markDisconnected,
    isConnected, offlineDuration) to multiChannelNodeConn
  - Remove the connected field from Batcher
  - Simplify IsConnected from two map lookups to one
  - Simplify ConnectedMap and Debug from two-map iteration to one
  - Rewrite cleanupOfflineNodes to scan b.nodes directly
  - Remove the markDisconnectedIfNoConns helper
  - Update all tests and benchmarks

Fixes #3141
2026-03-16 02:22:56 -07:00

738 lines
22 KiB
Go

package mapper
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// Mapper errors.
var (
// ErrInvalidNodeID is returned by generateMapResponse when the
// connection reports a node ID of 0.
ErrInvalidNodeID = errors.New("invalid nodeID")
// ErrMapperNil is returned by generateMapResponse when it is called
// without a mapper.
ErrMapperNil = errors.New("mapper is nil")
// ErrNodeConnectionNil is returned by handleNodeChange when it is
// called with a nil nodeConnection.
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
// ErrNodeNotFoundMapper is reported by a worker when synchronous work
// targets a node that is not registered in the batcher.
ErrNodeNotFoundMapper = errors.New("node not found")
)
// offlineNodeCleanupThreshold is how long a node must be disconnected
// before cleanupOfflineNodes removes its in-memory state.
//
// The cleanup pass itself is driven by a 5-minute ticker in doWork, so an
// entry can outlive this threshold by up to one cleanup interval.
const offlineNodeCleanupThreshold = 15 * time.Minute
// mapResponseGenerated counts generated map responses, labelled by
// response type. generateMapResponse increments it with the change's
// categorized Type() as the "response_type" label.
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type",
}, []string{"response_type"})
// NewBatcher constructs a Batcher that flushes pending changes every
// batchTime and processes them with the given number of workers.
//
// No goroutines are started here; call Start to begin processing.
// batchTime is passed straight to time.NewTicker, which panics on
// non-positive durations.
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher {
	// The work queue capacity is arbitrarily chosen; the sizing should be revisited.
	const queuedItemsPerWorker = 200

	b := new(Batcher)
	b.mapper = mapper
	b.workers = workers
	b.tick = time.NewTicker(batchTime)
	b.workCh = make(chan work, workers*queuedItemsPerWorker)
	b.done = make(chan struct{})
	b.nodes = xsync.NewMap[types.NodeID, *multiChannelNodeConn]()

	return b
}
// NewBatcherAndMapper builds a mapper from cfg and state, creates a Batcher
// around it using the batch delay and worker count from cfg.Tuning, and
// links the two together before returning the Batcher.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) *Batcher {
	mp := newMapper(cfg, state)
	batcher := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, mp)

	// Wire the batcher back into the mapper it was built with.
	mp.batcher = batcher

	return batcher
}
// nodeConnection interface for different connection implementations.
// It is the view of a node's connection that generateMapResponse and
// handleNodeChange need: identity, capability version, delivery, and
// tracking of which peers the node has already been told about.
type nodeConnection interface {
// nodeID returns the ID of the node this connection belongs to.
nodeID() types.NodeID
// version returns the capability version used when building responses
// for this node.
version() tailcfg.CapabilityVersion
// send delivers a map response to the node.
send(data *tailcfg.MapResponse) error
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
// updateSentPeers updates the tracking of which peers have been sent to this node.
updateSentPeers(resp *tailcfg.MapResponse)
}
// generateMapResponse builds the [tailcfg.MapResponse] that the node behind nc
// should receive for the given [change.Change].
//
// It returns (nil, nil) when there is nothing to send: the change is empty, or
// it is a self-only change targeted at a different node. It returns
// ErrInvalidNodeID for a zero node ID and ErrMapperNil for a nil mapper.
// nc must be non-nil; its node ID and version are read before any validation.
//
// Nothing is sent and peer tracking is not updated here; callers do that.
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
	nodeID := nc.nodeID()
	version := nc.version()

	switch {
	case r.IsEmpty():
		return nil, nil //nolint:nilnil // Empty response means nothing to send
	case nodeID == 0:
		return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
	case mapper == nil:
		return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
	case r.IsSelfOnly() && r.TargetNode != nodeID:
		return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
	}

	// A self-update is a change that originated from the receiving node
	// itself. Its response must carry the node's own info so it observes
	// its own attribute changes (e.g., tags changed via admin API).
	selfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID

	// The metric is keyed on the categorized type, not the free-form reason.
	mapResponseGenerated.WithLabelValues(r.Type()).Inc()

	var (
		resp *tailcfg.MapResponse
		err  error
	)

	switch {
	case r.RequiresRuntimePeerComputation:
		// Peer visibility has to be computed now (e.g., policy changes):
		// diff the peers visible at this moment against what was sent before.
		peers := mapper.state.ListPeers(nodeID)

		peerIDs := make([]tailcfg.NodeID, 0, peers.Len())
		for _, p := range peers.All() {
			peerIDs = append(peerIDs, p.ID().NodeID())
		}

		gone := nc.computePeerDiff(peerIDs)

		// On a self-update the self node is included so the node sees its
		// updated self info together with the new packet filters.
		resp, err = mapper.policyChangeResponse(nodeID, version, gone, peers, selfUpdate)

	case selfUpdate:
		// Non-policy self-update: only the self node info is needed.
		resp, err = mapper.selfMapResponse(nodeID, version)

	default:
		resp, err = mapper.buildFromChange(nodeID, version, &r)
	}

	if err != nil {
		return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
	}

	return resp, nil
}
// handleNodeChange generates the [tailcfg.MapResponse] for a [change.Change]
// and sends it to the node behind nc.
//
// It returns ErrNodeConnectionNil for a nil nc, and a wrapped error when
// generation or sending fails. When generation yields no response there is
// nothing to send and nil is returned. Peer tracking is updated only after a
// successful send.
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
	if nc == nil {
		return ErrNodeConnectionNil
	}

	id := nc.nodeID()

	log.Debug().Caller().Uint64(zf.NodeID, id.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started")

	resp, genErr := generateMapResponse(nc, mapper, r)
	switch {
	case genErr != nil:
		return fmt.Errorf("generating map response for node %d: %w", id, genErr)
	case resp == nil:
		// Some changes legitimately produce nothing for this node.
		return nil
	}

	if sendErr := nc.send(resp); sendErr != nil {
		return fmt.Errorf("sending map response to node %d: %w", id, sendErr)
	}

	// Record what the node has now been told about.
	nc.updateSentPeers(resp)

	return nil
}
// workResult represents the result of processing a change.
// It is what a worker hands back on work.resultCh for synchronous work.
type workResult struct {
// mapResponse is the generated response. It can be nil even when err is
// nil, because generateMapResponse returns (nil, nil) when there is
// nothing to send.
mapResponse *tailcfg.MapResponse
// err is the generation error, or ErrNodeNotFoundMapper (wrapped) when
// the target node is not registered.
err error
}
// work represents a unit of work to be processed by workers.
// All pending changes for a node are bundled into a single work item
// so that one worker processes them sequentially. This prevents
// out-of-order MapResponse delivery and races on lastSentPeers
// that occur when multiple workers process changes for the same node.
type work struct {
// changes are the changes to process, in order. For synchronous work
// (resultCh != nil) only changes[0] is read by the worker.
changes []change.Change
// nodeID is the node the changes are destined for.
nodeID types.NodeID
// resultCh, when non-nil, marks the item as synchronous: the worker
// generates a response and returns it here instead of sending it.
resultCh chan<- workResult // optional channel for synchronous operations
}
// Batcher errors.
var (
// errConnectionClosed reports that a connection channel was already closed.
// NOTE(review): not referenced in this part of the file; presumably used
// by the connection types defined elsewhere — confirm.
errConnectionClosed = errors.New("connection channel already closed")
// ErrInitialMapSendTimeout is returned by AddNode when the initial map
// cannot be delivered to the client channel before the send timeout.
ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
// ErrBatcherShuttingDown is returned by MapResponseFromChange when the
// batcher is closed before a result is available.
ErrBatcherShuttingDown = errors.New("batcher shutting down")
// ErrConnectionSendTimeout reports a timed-out send to a connection channel.
// NOTE(review): not referenced in this part of the file; presumably used
// by the connection send path defined elsewhere — confirm.
ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
)
// Batcher batches and distributes map responses to connected nodes.
// It uses concurrent maps, per-node mutexes, and a worker pool.
//
// Lifecycle: Call Start() to spawn workers, then Close() to shut down.
// Close() blocks until all workers have exited. A Batcher must not
// be reused after Close().
type Batcher struct {
// tick fires once per batch interval; each tick makes doWork flush the
// pending changes of every node into the work queue.
tick *time.Ticker
// mapper generates the map responses.
mapper *mapper
// workers is the number of worker goroutines spawned by doWork.
workers int
// nodes holds the per-node connection state, keyed by node ID. Entries
// are kept after a node disconnects and are removed when the node is
// deleted or has been offline for too long.
nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
// Work queue channel
// It is never closed; shutdown is signalled through done instead.
workCh chan work
// done is closed by Close() to tell doWork, the workers, and any
// blocked queueing or waiting callers to stop.
done chan struct{}
doneOnce sync.Once // Ensures done is only closed once
// wg tracks the doWork and all worker goroutines so that Close()
// can block until they have fully exited.
wg sync.WaitGroup
started atomic.Bool // Ensures Start() is only called once
// Metrics
// totalNodes is the number of entries in nodes.
totalNodes atomic.Int64
// workQueuedCount counts calls to queueWork.
workQueuedCount atomic.Int64
// workProcessed counts work items picked up by workers.
workProcessed atomic.Int64
// workErrors counts failures while processing work; see WorkErrors().
workErrors atomic.Int64
}
// AddNode registers a connection for node id with the batcher and delivers
// the node's initial map response on c.
//
// The node's entry in the batcher is created on first use and reused on
// reconnect. The initial map is generated through the worker pool and then
// sent on c with a bounded wait. If generation fails or the send times out,
// the connection is removed again, the node is marked disconnected when no
// other connection remains, and an error is returned
// (ErrInitialMapSendTimeout for the timeout case). On success the node is
// marked connected.
//
// The stop function tears down the owning session if this connection is
// later declared stale.
func (b *Batcher) AddNode(
	id types.NodeID,
	c chan<- *tailcfg.MapResponse,
	version tailcfg.CapabilityVersion,
	stop func(),
) error {
	// How long the receiver gets to accept the initial map.
	const initialMapSendTimeout = 5 * time.Second

	began := time.Now()
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	connID := generateConnectionID()
	created := time.Now()

	entry := &connectionEntry{
		id:      connID,
		c:       c,
		version: version,
		created: created,
		stop:    stop,
	}
	entry.lastUsed.Store(created.Unix())

	// Reuse the entry of a node that went offline recently; otherwise
	// register a fresh one and account for it.
	nodeConn, existing := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
	if !existing {
		b.totalNodes.Add(1)
	}

	nodeConn.addConnection(entry)

	// rollback undoes the registration of this connection after a failed
	// initial map, flagging the node offline if nothing else is connected.
	rollback := func() {
		nodeConn.removeConnectionByChannel(c)

		if !nodeConn.hasActiveConnections() {
			nodeConn.markDisconnected()
		}
	}

	// Generate through the worker pool rather than inline, so concurrency
	// stays bounded by the pool.
	initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
	if err != nil {
		nlog.Error().Err(err).Msg("initial map generation failed")
		rollback()

		return fmt.Errorf("generating initial map for node %d: %w", id, err)
	}

	// Block (with a deadline) instead of a non-blocking send: the receiver
	// may not be ready to read yet right after connecting.
	select {
	case <-time.After(initialMapSendTimeout):
		nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
		nlog.Debug().Caller().Dur("timeout.duration", initialMapSendTimeout).
			Msg("initial map send timed out because channel was blocked or receiver not ready")
		rollback()

		return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
	case c <- initialMap:
		// Delivered.
	}

	// The initial map is out, so the node now counts as connected. Further
	// updates reach it through the regular batching flow.
	nodeConn.markConnected()

	nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(began)).
		Int("active.connections", nodeConn.getActiveConnectionCount()).
		Msg("node connection established in batcher")

	return nil
}
// RemoveNode detaches the connection identified by channel c from node id.
//
// The node's entry is never deleted here: when the last connection goes
// away the node is only marked disconnected, so a quick reconnect can reuse
// the entry. It reports whether the node still has active connections
// afterwards; an unknown node yields false.
func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	nodeConn, ok := b.nodes.Load(id)
	if !ok || nodeConn == nil {
		nlog.Debug().Caller().Msg("removeNode called for non-existent node")
		return false
	}

	if !nodeConn.removeConnectionByChannel(c) {
		nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
	}

	if !nodeConn.hasActiveConnections() {
		// Last connection gone: keep the entry around for a rapid
		// reconnect. The node gets a fresh full map when it returns.
		nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
		nodeConn.markDisconnected()

		return false
	}

	nlog.Debug().Caller().
		Int("active.connections", nodeConn.getActiveConnectionCount()).
		Msg("node connection removed but keeping online, other connections remain")

	return true
}
// AddWork queues a change to be processed by the batcher.
// The changes are only recorded as pending per node (via addToBatch);
// they are handed to the workers on the next batch tick.
func (b *Batcher) AddWork(r ...change.Change) {
b.addToBatch(r...)
}
// Start launches the batcher's dispatch goroutine, which in turn spawns the
// workers. Only the first call has any effect; later calls are no-ops.
func (b *Batcher) Start() {
	if b.started.CompareAndSwap(false, true) {
		b.wg.Add(1)

		go b.doWork()
	}
}
// Close shuts the batcher down: it signals all goroutines to stop, waits for
// them to exit, stops the batch ticker, and closes every node connection.
func (b *Batcher) Close() {
	// Closing done (exactly once) is the whole shutdown signal: workers
	// and queueWork both select on it. workCh is deliberately left open,
	// because processBatchedChanges or MapResponseFromChange may still be
	// sending on it concurrently.
	b.doneOnce.Do(func() { close(b.done) })

	// Let doWork and every worker exit before touching the connections,
	// so no worker sends on a connection that is being closed.
	b.wg.Wait()

	// Release the ticker's resources.
	b.tick.Stop()

	// Close the channels that feed data to the clients.
	b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
		if nc != nil {
			nc.close()
		}

		return true
	})
}
// doWork is the batcher's dispatch loop. It spawns the worker goroutines
// (IDs 1..workers), then flushes pending changes on every batch tick and
// evicts long-offline nodes on a slower cleanup tick, until done is closed.
func (b *Batcher) doWork() {
	defer b.wg.Done()

	// How often nodes that stayed offline are considered for eviction.
	const cleanupInterval = 5 * time.Minute

	for workerID := 1; workerID <= b.workers; workerID++ {
		b.wg.Add(1)

		go b.worker(workerID)
	}

	cleanup := time.NewTicker(cleanupInterval)
	defer cleanup.Stop()

	for {
		select {
		case <-b.done:
			log.Info().Msg("batcher done channel closed, stopping to feed workers")
			return

		case <-b.tick.C:
			// Hand every node's pending changes to the workers.
			b.processBatchedChanges()

		case <-cleanup.C:
			// Drop state of nodes that have been offline for too long.
			b.cleanupOfflineNodes()
		}
	}
}
// worker is the body of a single worker goroutine. It receives work items
// from workCh until done is closed (or workCh is closed), handling two kinds
// of item:
//
//   - synchronous (resultCh != nil): generate a map response for changes[0]
//     and hand it back on resultCh; the response is not sent to the node here;
//   - asynchronous (resultCh == nil): apply every bundled change to the node
//     via nc.change, in order.
//
// Both paths hold the node's workMu while they run, so work for one node is
// never processed by two workers at once. Failures are logged and counted in
// workErrors; they do not stop the worker.
func (b *Batcher) worker(workerID int) {
defer b.wg.Done()
wlog := log.With().Int(zf.WorkerID, workerID).Logger()
for {
select {
case w, ok := <-b.workCh:
if !ok {
wlog.Debug().Msg("worker channel closing, shutting down")
return
}
b.workProcessed.Add(1)
// Synchronous path: a caller is blocking on resultCh
// waiting for a generated MapResponse (used by AddNode
// for the initial map). Always contains a single change.
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
// Hold workMu so concurrent async work for this
// node waits until the initial map is sent.
nc.workMu.Lock()
var err error
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.changes[0])
result.err = err
if result.err != nil {
b.workErrors.Add(1)
wlog.Error().Err(result.err).
Uint64(zf.NodeID, w.nodeID.Uint64()).
Str(zf.Reason, w.changes[0].Reason).
Msg("failed to generate map response for synchronous work")
} else if result.mapResponse != nil {
// Peer tracking is updated here because the caller, not
// this worker, delivers the response.
nc.updateSentPeers(result.mapResponse)
}
nc.workMu.Unlock()
} else {
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
b.workErrors.Add(1)
wlog.Error().Err(result.err).
Uint64(zf.NodeID, w.nodeID.Uint64()).
Msg("node not found for synchronous work")
}
// Hand the result back, but give up if the batcher shuts down
// before the send can complete.
select {
case w.resultCh <- result:
case <-b.done:
return
}
continue
}
// Async path: process all bundled changes sequentially.
// workMu ensures that if another worker picks up the next
// tick's bundle for the same node, it waits until we
// finish — preventing out-of-order delivery and races
// on lastSentPeers (Clear+Store vs Range).
// Work for a node that is no longer registered is dropped silently.
if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
nc.workMu.Lock()
for _, ch := range w.changes {
err := nc.change(ch)
if err != nil {
// A failed change is logged and counted; the remaining
// changes in the bundle are still applied.
b.workErrors.Add(1)
wlog.Error().Err(err).
Uint64(zf.NodeID, w.nodeID.Uint64()).
Str(zf.Reason, ch.Reason).
Msg("failed to apply change")
}
}
nc.workMu.Unlock()
}
case <-b.done:
wlog.Debug().Msg("batcher shutting down, exiting worker")
return
}
}
}
// queueWork hands w to the worker pool, blocking while the queue is full.
// If the batcher shuts down before the item can be queued, the item is
// dropped. The queued-work counter is bumped on every call, whether or not
// the item ends up in the queue.
func (b *Batcher) queueWork(w work) {
	b.workQueuedCount.Add(1)

	select {
	case <-b.done:
		// Shutting down: drop the item.
	case b.workCh <- w:
		// Queued.
	}
}
// addToBatch records changes as pending work on the affected nodes. Nothing
// is dispatched here; pending changes are drained on the next batch tick.
//
// It proceeds in three steps:
//
//  1. Nodes listed in any change's PeersRemoved are dropped from the batcher.
//  2. If any change is a full update, every node's pending list is replaced
//     by a single full update and the individual changes are skipped.
//  3. Otherwise targeted changes go to their target node only, and broadcast
//     changes go to every node after per-node filtering.
func (b *Batcher) addToBatch(changes ...change.Change) {
	// Step 1: forget nodes that are being permanently removed from the
	// system, so workers do not hit "node not found" errors when generating
	// map responses for nodes that were deleted from state.
	//
	// Safety: change.Change.PeersRemoved is ONLY populated when nodes are
	// actually deleted (via change.NodeRemoved in state.DeleteNode). Policy
	// changes that merely alter peer visibility set
	// RequiresRuntimePeerComputation=true instead and report removed peers in
	// tailcfg.MapResponse.PeersRemoved (a different struct), so connected
	// nodes that only lost sight of some peers are never dropped here.
	//
	// See: https://github.com/juanfont/headscale/issues/2924
	for i := range changes {
		for _, removedID := range changes[i].PeersRemoved {
			if _, existed := b.nodes.LoadAndDelete(removedID); !existed {
				continue
			}

			b.totalNodes.Add(-1)
			log.Debug().
				Uint64(zf.NodeID, removedID.Uint64()).
				Msg("removed deleted node from batcher")
		}
	}

	// Step 2: a full update supersedes everything else that is pending.
	if change.HasFull(changes) {
		b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
			if nc != nil {
				nc.pendingMu.Lock()
				nc.pending = []change.Change{change.FullUpdate()}
				nc.pendingMu.Unlock()
			}

			return true
		})

		return
	}

	// Step 3: route the remaining changes.
	broadcast, targeted := change.SplitTargetedAndBroadcast(changes)

	// Targeted changes reach only the node they name, if it is registered.
	for _, ch := range targeted {
		nc, ok := b.nodes.Load(ch.TargetNode)
		if !ok || nc == nil {
			continue
		}

		nc.appendPending(ch)
	}

	if len(broadcast) == 0 {
		return
	}

	// Broadcast changes reach every node, minus whatever the per-node
	// filter removes.
	b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
		if nc == nil {
			return true
		}

		if relevant := change.FilterForNode(nodeID, broadcast); len(relevant) > 0 {
			nc.appendPending(relevant...)
		}

		return true
	})
}
// processBatchedChanges drains the pending changes of every node and queues
// them for the workers. Nodes with nothing pending are skipped.
func (b *Batcher) processBatchedChanges() {
	b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
		if nc == nil {
			return true
		}

		// Exactly one work item per node: a single worker then applies the
		// node's changes in order, which rules out out-of-order delivery.
		if batch := nc.drainPending(); len(batch) > 0 {
			b.queueWork(work{nodeID: nodeID, changes: batch})
		}

		return true
	})
}
// cleanupOfflineNodes removes nodes that have been offline for longer than
// offlineNodeCleanupThreshold, to prevent memory leaks.
//
// It works in two phases. A scan over b.nodes collects candidates: nodes with
// no active connections whose offline duration exceeds the threshold. Each
// candidate is then deleted through Compute(), which re-evaluates the same
// conditions atomically with the delete. Re-checking is required because the
// node may have changed between the scan and the delete:
//
//   - it may have reconnected (it now has active connections), or
//   - it may have reconnected and disconnected again, in which case it has no
//     active connections but has only been offline for a short time.
//
// In both cases the entry is kept.
func (b *Batcher) cleanupOfflineNodes() {
	// Phase 1: find candidates by scanning b.nodes and checking each
	// node's connection state and offline duration.
	var candidates []types.NodeID

	b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
		if nc != nil && !nc.hasActiveConnections() && nc.offlineDuration() > offlineNodeCleanupThreshold {
			candidates = append(candidates, nodeID)
		}

		return true
	})

	// Phase 2: atomic check-and-delete per candidate.
	cleaned := 0

	for _, nodeID := range candidates {
		b.nodes.Compute(
			nodeID,
			func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) {
				if !loaded || conn == nil || conn.hasActiveConnections() {
					return conn, xsync.CancelOp
				}

				// Re-read the offline duration: a reconnect followed by a
				// new disconnect since the scan resets it, and such a node
				// must not be evicted yet.
				offlineFor := conn.offlineDuration()
				if offlineFor <= offlineNodeCleanupThreshold {
					return conn, xsync.CancelOp
				}

				// Perform all bookkeeping inside the Compute callback so
				// that a concurrent AddNode (which calls LoadOrStore on
				// b.nodes) cannot slip in between the delete and the
				// counter update.
				b.totalNodes.Add(-1)

				cleaned++

				// Log how long the node was actually offline, not the
				// threshold it exceeded.
				log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
					Dur("offline_duration", offlineFor).
					Msg("cleaning up node that has been offline for too long")

				return conn, xsync.DeleteOp
			},
		)
	}

	if cleaned > 0 {
		log.Info().Int(zf.CleanedNodes, cleaned).
			Msg("completed cleanup of long-offline nodes")
	}
}
// IsConnected reports whether node id is currently connected. It is a
// single lock-free lookup in b.nodes: unknown nodes are reported as not
// connected, known nodes report their own isConnected() state.
func (b *Batcher) IsConnected(id types.NodeID) bool {
	if nc, known := b.nodes.Load(id); known && nc != nil {
		return nc.isConnected()
	}

	return false
}
// ConnectedMap returns a freshly built concurrent map with one entry per
// node known to the batcher, giving its connection status
// (true = connected, false = disconnected).
func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
	status := xsync.NewMap[types.NodeID, bool]()

	b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool {
		if nc == nil {
			return true
		}

		status.Store(id, nc.isConnected())

		return true
	})

	return status
}
// MapResponseFromChange generates the map response for a single change
// synchronously, using the same worker pool as asynchronous work. It queues
// the request and blocks until a worker answers or the batcher shuts down,
// in which case ErrBatcherShuttingDown (wrapped) is returned.
func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) {
	// Capacity 1 lets the worker deliver its answer without blocking.
	reply := make(chan workResult, 1)

	b.queueWork(work{
		nodeID:   id,
		changes:  []change.Change{ch},
		resultCh: reply,
	})

	select {
	case <-b.done:
		return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
	case res := <-reply:
		return res.mapResponse, res.err
	}
}
// DebugNodeInfo contains debug information about a node's connections.
// It is the per-node value returned by Batcher.Debug.
type DebugNodeInfo struct {
// Connected is the node's isConnected() state at the time of the snapshot.
Connected bool `json:"connected"`
// ActiveConnections is the node's active connection count at the time
// of the snapshot.
ActiveConnections int `json:"active_connections"`
}
// Debug returns a snapshot of connection debug information for every node
// known to the batcher, keyed by node ID, for use by the debug interface.
func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo {
	snapshot := make(map[types.NodeID]DebugNodeInfo)

	b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool {
		if nc != nil {
			snapshot[id] = DebugNodeInfo{
				Connected:         nc.isConnected(),
				ActiveConnections: nc.getActiveConnectionCount(),
			}
		}

		return true
	})

	return snapshot
}
// DebugMapResponses returns map responses per node for debugging. It is a
// direct pass-through to the mapper's debugMapResponses, including its error.
func (b *Batcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
return b.mapper.debugMapResponses()
}
// WorkErrors returns the count of work errors encountered.
// This is primarily useful for testing and debugging.
// The value is read atomically from the counter that workers increment
// whenever processing a work item fails.
func (b *Batcher) WorkErrors() int64 {
return b.workErrors.Load()
}