package mapper

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/juanfont/headscale/hscontrol/state"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/puzpuzpuz/xsync/v4"
	"github.com/rs/zerolog/log"
	"tailscale.com/tailcfg"
)

// Mapper errors.
var (
	ErrInvalidNodeID      = errors.New("invalid nodeID")
	ErrMapperNil          = errors.New("mapper is nil")
	ErrNodeConnectionNil  = errors.New("nodeConnection is nil")
	ErrNodeNotFoundMapper = errors.New("node not found")
)

// offlineNodeCleanupThreshold is how long a node must be disconnected
// before cleanupOfflineNodes removes its in-memory state.
const offlineNodeCleanupThreshold = 15 * time.Minute

// mapResponseGenerated counts generated map responses, labelled by the
// categorized response type (never the free-form reason string).
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
	Namespace: "headscale",
	Name:      "mapresponse_generated_total",
	Help:      "total count of mapresponses generated by response type",
}, []string{"response_type"})

// NewBatcher constructs a Batcher that flushes batched changes every
// batchTime using the given number of worker goroutines. Start() must be
// called before the batcher begins processing work.
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher {
	b := &Batcher{
		mapper:  mapper,
		workers: workers,
		tick:    time.NewTicker(batchTime),
		// The channel capacity was chosen arbitrarily; the sizing should
		// be revisited.
		workCh: make(chan work, workers*200),
		done:   make(chan struct{}),
		nodes:  xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
	}

	return b
}

// NewBatcherAndMapper creates a new Batcher with its mapper.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) *Batcher {
	mpr := newMapper(cfg, state)
	batcher := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, mpr)
	mpr.batcher = batcher

	return batcher
}

// nodeConnection interface for different connection implementations.
type nodeConnection interface { nodeID() types.NodeID version() tailcfg.CapabilityVersion send(data *tailcfg.MapResponse) error // computePeerDiff returns peers that were previously sent but are no longer in the current list. computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID) // updateSentPeers updates the tracking of which peers have been sent to this node. updateSentPeers(resp *tailcfg.MapResponse) } // generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change]. func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) { nodeID := nc.nodeID() version := nc.version() if r.IsEmpty() { return nil, nil //nolint:nilnil // Empty response means nothing to send } if nodeID == 0 { return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID) } if mapper == nil { return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID) } // Handle self-only responses if r.IsSelfOnly() && r.TargetNode != nodeID { return nil, nil //nolint:nilnil // No response needed for other nodes when self-only } // Check if this is a self-update (the changed node is the receiving node). // When true, ensure the response includes the node's self info so it sees // its own attribute changes (e.g., tags changed via admin API). 
isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID var ( mapResp *tailcfg.MapResponse err error ) // Track metric using categorized type, not free-form reason mapResponseGenerated.WithLabelValues(r.Type()).Inc() // Check if this requires runtime peer visibility computation (e.g., policy changes) if r.RequiresRuntimePeerComputation { currentPeers := mapper.state.ListPeers(nodeID) currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len()) for _, peer := range currentPeers.All() { currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID()) } removedPeers := nc.computePeerDiff(currentPeerIDs) // Include self node when this is a self-update (e.g., node's own tags changed) // so the node sees its updated self info along with new packet filters. mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers, isSelfUpdate) } else if isSelfUpdate { // Non-policy self-update: just send the self node info mapResp, err = mapper.selfMapResponse(nodeID, version) } else { mapResp, err = mapper.buildFromChange(nodeID, version, &r) } if err != nil { return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err) } return mapResp, nil } // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. 
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { return ErrNodeConnectionNil } nodeID := nc.nodeID() log.Debug().Caller().Uint64(zf.NodeID, nodeID.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started") data, err := generateMapResponse(nc, mapper, r) if err != nil { return fmt.Errorf("generating map response for node %d: %w", nodeID, err) } if data == nil { // No data to send is valid for some response types return nil } // Send the map response err = nc.send(data) if err != nil { return fmt.Errorf("sending map response to node %d: %w", nodeID, err) } // Update peer tracking after successful send nc.updateSentPeers(data) return nil } // workResult represents the result of processing a change. type workResult struct { mapResponse *tailcfg.MapResponse err error } // work represents a unit of work to be processed by workers. // All pending changes for a node are bundled into a single work item // so that one worker processes them sequentially. This prevents // out-of-order MapResponse delivery and races on lastSentPeers // that occur when multiple workers process changes for the same node. type work struct { changes []change.Change nodeID types.NodeID resultCh chan<- workResult // optional channel for synchronous operations } // Batcher errors. var ( errConnectionClosed = errors.New("connection channel already closed") ErrInitialMapSendTimeout = errors.New("sending initial map: timeout") ErrBatcherShuttingDown = errors.New("batcher shutting down") ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)") ) // Batcher batches and distributes map responses to connected nodes. // It uses concurrent maps, per-node mutexes, and a worker pool. // // Lifecycle: Call Start() to spawn workers, then Close() to shut down. // Close() blocks until all workers have exited. A Batcher must not // be reused after Close(). 
type Batcher struct {
	// tick drives periodic flushing of per-node pending changes.
	tick    *time.Ticker
	mapper  *mapper
	workers int

	// nodes maps a node ID to its (possibly multi-channel) connection state.
	nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]

	// Work queue channel
	workCh   chan work
	done     chan struct{}
	doneOnce sync.Once // Ensures done is only closed once

	// wg tracks the doWork and all worker goroutines so that Close()
	// can block until they have fully exited.
	wg      sync.WaitGroup
	started atomic.Bool // Ensures Start() is only called once

	// Metrics
	totalNodes      atomic.Int64
	workQueuedCount atomic.Int64
	workProcessed   atomic.Int64
	workErrors      atomic.Int64
}

// AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
// The stop function tears down the owning session if this connection is later declared stale.
func (b *Batcher) AddNode(
	id types.NodeID,
	c chan<- *tailcfg.MapResponse,
	version tailcfg.CapabilityVersion,
	stop func(),
) error {
	addNodeStart := time.Now()
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	// Generate connection ID
	connID := generateConnectionID()

	// Create new connection entry
	now := time.Now()
	newEntry := &connectionEntry{
		id:      connID,
		c:       c,
		version: version,
		created: now,
		stop:    stop,
	}
	// Initialize last used timestamp
	newEntry.lastUsed.Store(now.Unix())

	// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
	nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
	if !loaded {
		b.totalNodes.Add(1)
	}

	// Add connection to the list (lock-free)
	nodeConn.addConnection(newEntry)

	// Use the worker pool for controlled concurrency instead of direct generation
	initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
	if err != nil {
		nlog.Error().Err(err).Msg("initial map generation failed")
		// Roll back: drop the just-added connection and, if no other
		// connections remain, mark the node offline again.
		nodeConn.removeConnectionByChannel(c)
		if !nodeConn.hasActiveConnections() {
			nodeConn.markDisconnected()
		}
		return fmt.Errorf("generating initial map for node %d: %w", id, err)
	}

	// Use a blocking send with timeout for initial map since the channel should be ready
	// and we want to avoid the race condition where the receiver isn't ready yet
	select {
	case c <- initialMap:
		// Success
	case <-time.After(5 * time.Second): //nolint:mnd
		nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
		nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
			Msg("initial map send timed out because channel was blocked or receiver not ready")
		// Same rollback as the generation-failure path above.
		nodeConn.removeConnectionByChannel(c)
		if !nodeConn.hasActiveConnections() {
			nodeConn.markDisconnected()
		}
		return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
	}

	// Mark the node as connected now that the initial map was sent.
	nodeConn.markConnected()

	// Node will automatically receive updates through the normal flow
	// The initial full map already contains all current state
	nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
		Int("active.connections", nodeConn.getActiveConnectionCount()).
		Msg("node connection established in batcher")

	return nil
}

// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches one of the current connections, closes that specific connection,
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
// Reports if the node still has active connections after removal.
func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()

	nodeConn, exists := b.nodes.Load(id)
	if !exists || nodeConn == nil {
		nlog.Debug().Caller().Msg("removeNode called for non-existent node")
		return false
	}

	// Remove specific connection
	removed := nodeConn.removeConnectionByChannel(c)
	if !removed {
		nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
	}

	// Check if node has any remaining active connections
	if nodeConn.hasActiveConnections() {
		nlog.Debug().Caller().
			Int("active.connections", nodeConn.getActiveConnectionCount()).
			Msg("node connection removed but keeping online, other connections remain")
		return true // Node still has active connections
	}

	// No active connections - keep the node entry alive for rapid reconnections
	// The node will get a fresh full map when it reconnects
	nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
	nodeConn.markDisconnected()

	return false
}

// AddWork queues a change to be processed by the batcher.
func (b *Batcher) AddWork(r ...change.Change) {
	b.addToBatch(r...)
}

// Start spawns the dispatch goroutine (which in turn spawns the workers).
// It is idempotent: only the first call has any effect.
func (b *Batcher) Start() {
	if !b.started.CompareAndSwap(false, true) {
		return
	}
	b.wg.Add(1)
	go b.doWork()
}

// Close shuts the batcher down and blocks until all goroutines have exited.
func (b *Batcher) Close() {
	// Signal shutdown to all goroutines, only once.
	// Workers and queueWork both select on done, so closing it
	// is sufficient for graceful shutdown. We intentionally do NOT
	// close workCh here because processBatchedChanges or
	// MapResponseFromChange may still be sending on it concurrently.
	b.doneOnce.Do(func() {
		close(b.done)
	})

	// Wait for all worker goroutines (and doWork) to exit before
	// tearing down node connections. This prevents workers from
	// sending on connections that are being closed concurrently.
	b.wg.Wait()

	// Stop the ticker to prevent resource leaks.
	b.tick.Stop()

	// Close the underlying channels supplying the data to the clients.
	b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
		if conn == nil {
			return true
		}
		conn.close()
		return true
	})
}

// doWork spawns the worker pool and then dispatches periodic work
// (batch flushes and offline-node cleanup) until done is closed.
func (b *Batcher) doWork() {
	defer b.wg.Done()

	for i := range b.workers {
		b.wg.Add(1)
		go b.worker(i + 1)
	}

	// Create a cleanup ticker for removing truly disconnected nodes
	cleanupTicker := time.NewTicker(5 * time.Minute)
	defer cleanupTicker.Stop()

	for {
		select {
		case <-b.tick.C:
			// Process batched changes
			b.processBatchedChanges()
		case <-cleanupTicker.C:
			// Clean up nodes that have been offline for too long
			b.cleanupOfflineNodes()
		case <-b.done:
			log.Info().Msg("batcher done channel closed, stopping to feed workers")
			return
		}
	}
}

// worker consumes items from workCh until shutdown, handling both
// synchronous (resultCh set) and asynchronous work.
func (b *Batcher) worker(workerID int) {
	defer b.wg.Done()

	wlog := log.With().Int(zf.WorkerID, workerID).Logger()

	for {
		select {
		case w, ok := <-b.workCh:
			if !ok {
				wlog.Debug().Msg("worker channel closing, shutting down")
				return
			}

			b.workProcessed.Add(1)

			// Synchronous path: a caller is blocking on resultCh
			// waiting for a generated MapResponse (used by AddNode
			// for the initial map). Always contains a single change.
			if w.resultCh != nil {
				var result workResult
				if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
					// Hold workMu so concurrent async work for this
					// node waits until the initial map is sent.
					nc.workMu.Lock()
					var err error
					result.mapResponse, err = generateMapResponse(nc, b.mapper, w.changes[0])
					result.err = err
					if result.err != nil {
						b.workErrors.Add(1)
						wlog.Error().Err(result.err).
							Uint64(zf.NodeID, w.nodeID.Uint64()).
							Str(zf.Reason, w.changes[0].Reason).
							Msg("failed to generate map response for synchronous work")
					} else if result.mapResponse != nil {
						nc.updateSentPeers(result.mapResponse)
					}
					nc.workMu.Unlock()
				} else {
					result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
					b.workErrors.Add(1)
					wlog.Error().Err(result.err).
						Uint64(zf.NodeID, w.nodeID.Uint64()).
						Msg("node not found for synchronous work")
				}

				// Deliver the result, unless the batcher is shutting down
				// (in which case the caller is unblocked via b.done too).
				select {
				case w.resultCh <- result:
				case <-b.done:
					return
				}

				continue
			}

			// Async path: process all bundled changes sequentially.
			// workMu ensures that if another worker picks up the next
			// tick's bundle for the same node, it waits until we
			// finish — preventing out-of-order delivery and races
			// on lastSentPeers (Clear+Store vs Range).
			if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
				nc.workMu.Lock()
				for _, ch := range w.changes {
					err := nc.change(ch)
					if err != nil {
						b.workErrors.Add(1)
						wlog.Error().Err(err).
							Uint64(zf.NodeID, w.nodeID.Uint64()).
							Str(zf.Reason, ch.Reason).
							Msg("failed to apply change")
					}
				}
				nc.workMu.Unlock()
			}
		case <-b.done:
			wlog.Debug().Msg("batcher shutting down, exiting worker")
			return
		}
	}
}

// queueWork safely queues work.
func (b *Batcher) queueWork(w work) {
	b.workQueuedCount.Add(1)

	select {
	case b.workCh <- w:
		// Successfully queued
	case <-b.done:
		// Batcher is shutting down
		return
	}
}

// addToBatch adds changes to the pending batch.
func (b *Batcher) addToBatch(changes ...change.Change) {
	// Clean up any nodes being permanently removed from the system.
	//
	// This handles the case where a node is deleted from state but the batcher
	// still has it registered. By cleaning up here, we prevent "node not found"
	// errors when workers try to generate map responses for deleted nodes.
	//
	// Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually
	// deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy
	// changes that affect peer visibility do NOT use this field - they set
	// RequiresRuntimePeerComputation=true and compute removed peers at runtime,
	// putting them in tailcfg.MapResponse.PeersRemoved (a different struct).
	// Therefore, this cleanup only removes nodes that are truly being deleted,
	// not nodes that are still connected but have lost visibility of certain peers.
	//
	// See: https://github.com/juanfont/headscale/issues/2924
	for _, ch := range changes {
		for _, removedID := range ch.PeersRemoved {
			if _, existed := b.nodes.LoadAndDelete(removedID); existed {
				b.totalNodes.Add(-1)
				log.Debug().
					Uint64(zf.NodeID, removedID.Uint64()).
					Msg("removed deleted node from batcher")
			}
		}
	}

	// Short circuit if any of the changes is a full update, which
	// means we can skip sending individual changes.
	if change.HasFull(changes) {
		b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
			if nc == nil {
				return true
			}
			nc.pendingMu.Lock()
			nc.pending = []change.Change{change.FullUpdate()}
			nc.pendingMu.Unlock()
			return true
		})
		return
	}

	broadcast, targeted := change.SplitTargetedAndBroadcast(changes)

	// Handle targeted changes - send only to the specific node
	for _, ch := range targeted {
		if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil {
			nc.appendPending(ch)
		}
	}

	// Handle broadcast changes - send to all nodes, filtering as needed
	if len(broadcast) > 0 {
		b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
			if nc == nil {
				return true
			}
			filtered := change.FilterForNode(nodeID, broadcast)
			if len(filtered) > 0 {
				nc.appendPending(filtered...)
			}
			return true
		})
	}
}

// processBatchedChanges processes all pending batched changes.
func (b *Batcher) processBatchedChanges() {
	b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
		if nc == nil {
			return true
		}
		pending := nc.drainPending()
		if len(pending) == 0 {
			return true
		}
		// Queue a single work item containing all pending changes.
		// One item per node ensures a single worker processes them
		// sequentially, preventing out-of-order delivery.
		b.queueWork(work{changes: pending, nodeID: nodeID, resultCh: nil})
		return true
	})
}

// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
// Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node // reconnects between the hasActiveConnections() check and the Delete() call. func (b *Batcher) cleanupOfflineNodes() { var nodesToCleanup []types.NodeID // Find nodes that have been offline for too long by scanning b.nodes // and checking each node's disconnectedAt timestamp. b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { if nc != nil && !nc.hasActiveConnections() && nc.offlineDuration() > offlineNodeCleanupThreshold { nodesToCleanup = append(nodesToCleanup, nodeID) } return true }) // Clean up the identified nodes using Compute() for atomic check-and-delete. // This prevents a TOCTOU race where a node reconnects (adding an active // connection) between the hasActiveConnections() check and the Delete() call. cleaned := 0 for _, nodeID := range nodesToCleanup { b.nodes.Compute( nodeID, func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) { if !loaded || conn == nil || conn.hasActiveConnections() { return conn, xsync.CancelOp } // Perform all bookkeeping inside the Compute callback so // that a concurrent AddNode (which calls LoadOrStore on // b.nodes) cannot slip in between the delete and the // counter update. b.totalNodes.Add(-1) cleaned++ log.Info().Uint64(zf.NodeID, nodeID.Uint64()). Dur("offline_duration", offlineNodeCleanupThreshold). Msg("cleaning up node that has been offline for too long") return conn, xsync.DeleteOp }, ) } if cleaned > 0 { log.Info().Int(zf.CleanedNodes, cleaned). Msg("completed cleanup of long-offline nodes") } } // IsConnected is a lock-free read that checks if a node is connected. // A node is considered connected if it has active connections or has // not been marked as disconnected. 
func (b *Batcher) IsConnected(id types.NodeID) bool { nodeConn, exists := b.nodes.Load(id) if !exists || nodeConn == nil { return false } return nodeConn.isConnected() } // ConnectedMap returns a lock-free map of all known nodes and their // connection status (true = connected, false = disconnected). func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret := xsync.NewMap[types.NodeID, bool]() b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool { if nc != nil { ret.Store(id, nc.isConnected()) } return true }) return ret } // MapResponseFromChange queues work to generate a map response and waits for the result. // This allows synchronous map generation using the same worker pool. func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { resultCh := make(chan workResult, 1) // Queue the work with a result channel using the safe queueing method b.queueWork(work{changes: []change.Change{ch}, nodeID: id, resultCh: resultCh}) // Wait for the result select { case result := <-resultCh: return result.mapResponse, result.err case <-b.done: return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id) } } // DebugNodeInfo contains debug information about a node's connections. type DebugNodeInfo struct { Connected bool `json:"connected"` ActiveConnections int `json:"active_connections"` } // Debug returns a pre-baked map of node debug information for the debug interface. 
func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo { result := make(map[types.NodeID]DebugNodeInfo) b.nodes.Range(func(id types.NodeID, nc *multiChannelNodeConn) bool { if nc == nil { return true } result[id] = DebugNodeInfo{ Connected: nc.isConnected(), ActiveConnections: nc.getActiveConnectionCount(), } return true }) return result } func (b *Batcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { return b.mapper.debugMapResponses() } // WorkErrors returns the count of work errors encountered. // This is primarily useful for testing and debugging. func (b *Batcher) WorkErrors() int64 { return b.workErrors.Load() }