// Package mapper (batcher portion): batches change notifications and
// distributes generated map responses to connected nodes.
package mapper

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/juanfont/headscale/hscontrol/state"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/puzpuzpuz/xsync/v4"
	"github.com/rs/zerolog/log"
	"tailscale.com/tailcfg"
)

// Mapper errors.
var (
	ErrInvalidNodeID      = errors.New("invalid nodeID")
	ErrMapperNil          = errors.New("mapper is nil")
	ErrNodeConnectionNil  = errors.New("nodeConnection is nil")
	ErrNodeNotFoundMapper = errors.New("node not found")
)

// offlineNodeCleanupThreshold is how long a node must be disconnected
// before cleanupOfflineNodes removes its in-memory state.
const offlineNodeCleanupThreshold = 15 * time.Minute

// mapResponseGenerated counts every generated map response, labelled by
// the categorized response type (see change.Change.Type).
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
	Namespace: "headscale",
	Name:      "mapresponse_generated_total",
	Help:      "total count of mapresponses generated by response type",
}, []string{"response_type"})

// NewBatcher constructs a Batcher that flushes pending changes every
// batchTime using the given number of workers and the provided mapper.
// Call Start() before use and Close() to shut it down.
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher {
	return &Batcher{
		mapper:  mapper,
		workers: workers,
		tick:    time.NewTicker(batchTime),

		// The size of this channel is arbitrary chosen, the sizing should be revisited.
		workCh: make(chan work, workers*200),
		done:   make(chan struct{}),

		nodes:     xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
		connected: xsync.NewMap[types.NodeID, *time.Time](),
	}
}

// NewBatcherAndMapper creates a new Batcher with its mapper.
// The two are cross-linked: the mapper holds a reference back to the batcher.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) *Batcher {
	m := newMapper(cfg, state)
	b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
	m.batcher = b

	return b
}

// nodeConnection interface for different connection implementations.
type nodeConnection interface { nodeID() types.NodeID version() tailcfg.CapabilityVersion send(data *tailcfg.MapResponse) error // computePeerDiff returns peers that were previously sent but are no longer in the current list. computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID) // updateSentPeers updates the tracking of which peers have been sent to this node. updateSentPeers(resp *tailcfg.MapResponse) } // generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change]. func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) { nodeID := nc.nodeID() version := nc.version() if r.IsEmpty() { return nil, nil //nolint:nilnil // Empty response means nothing to send } if nodeID == 0 { return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID) } if mapper == nil { return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID) } // Handle self-only responses if r.IsSelfOnly() && r.TargetNode != nodeID { return nil, nil //nolint:nilnil // No response needed for other nodes when self-only } // Check if this is a self-update (the changed node is the receiving node). // When true, ensure the response includes the node's self info so it sees // its own attribute changes (e.g., tags changed via admin API). 
isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID var ( mapResp *tailcfg.MapResponse err error ) // Track metric using categorized type, not free-form reason mapResponseGenerated.WithLabelValues(r.Type()).Inc() // Check if this requires runtime peer visibility computation (e.g., policy changes) if r.RequiresRuntimePeerComputation { currentPeers := mapper.state.ListPeers(nodeID) currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len()) for _, peer := range currentPeers.All() { currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID()) } removedPeers := nc.computePeerDiff(currentPeerIDs) // Include self node when this is a self-update (e.g., node's own tags changed) // so the node sees its updated self info along with new packet filters. mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers, isSelfUpdate) } else if isSelfUpdate { // Non-policy self-update: just send the self node info mapResp, err = mapper.selfMapResponse(nodeID, version) } else { mapResp, err = mapper.buildFromChange(nodeID, version, &r) } if err != nil { return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err) } return mapResp, nil } // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. 
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { return ErrNodeConnectionNil } nodeID := nc.nodeID() log.Debug().Caller().Uint64(zf.NodeID, nodeID.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started") data, err := generateMapResponse(nc, mapper, r) if err != nil { return fmt.Errorf("generating map response for node %d: %w", nodeID, err) } if data == nil { // No data to send is valid for some response types return nil } // Send the map response err = nc.send(data) if err != nil { return fmt.Errorf("sending map response to node %d: %w", nodeID, err) } // Update peer tracking after successful send nc.updateSentPeers(data) return nil } // workResult represents the result of processing a change. type workResult struct { mapResponse *tailcfg.MapResponse err error } // work represents a unit of work to be processed by workers. type work struct { c change.Change nodeID types.NodeID resultCh chan<- workResult // optional channel for synchronous operations } // Batcher errors. var ( errConnectionClosed = errors.New("connection channel already closed") ErrInitialMapSendTimeout = errors.New("sending initial map: timeout") ErrBatcherShuttingDown = errors.New("batcher shutting down") ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)") ) // Batcher batches and distributes map responses to connected nodes. // It uses concurrent maps, per-node mutexes, and a worker pool. // // Lifecycle: Call Start() to spawn workers, then Close() to shut down. // Close() blocks until all workers have exited. A Batcher must not // be reused after Close(). 
type Batcher struct { tick *time.Ticker mapper *mapper workers int nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] connected *xsync.Map[types.NodeID, *time.Time] // Work queue channel workCh chan work done chan struct{} doneOnce sync.Once // Ensures done is only closed once // wg tracks the doWork and all worker goroutines so that Close() // can block until they have fully exited. wg sync.WaitGroup started atomic.Bool // Ensures Start() is only called once // Metrics totalNodes atomic.Int64 workQueuedCount atomic.Int64 workProcessed atomic.Int64 workErrors atomic.Int64 } // AddNode registers a new node connection with the batcher and sends an initial map response. // It creates or updates the node's connection data, validates the initial map generation, // and notifies other nodes that this node has come online. // The stop function tears down the owning session if this connection is later declared stale. func (b *Batcher) AddNode( id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func(), ) error { addNodeStart := time.Now() nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() // Generate connection ID connID := generateConnectionID() // Create new connection entry now := time.Now() newEntry := &connectionEntry{ id: connID, c: c, version: version, created: now, stop: stop, } // Initialize last used timestamp newEntry.lastUsed.Store(now.Unix()) // Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper)) if !loaded { b.totalNodes.Add(1) } // Add connection to the list (lock-free) nodeConn.addConnection(newEntry) // Use the worker pool for controlled concurrency instead of direct generation initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id)) if err != nil { nlog.Error().Err(err).Msg("initial map generation failed") nodeConn.removeConnectionByChannel(c) b.markDisconnectedIfNoConns(id, 
nodeConn) return fmt.Errorf("generating initial map for node %d: %w", id, err) } // Use a blocking send with timeout for initial map since the channel should be ready // and we want to avoid the race condition where the receiver isn't ready yet select { case c <- initialMap: // Success case <-time.After(5 * time.Second): //nolint:mnd nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout") nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd Msg("initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) b.markDisconnectedIfNoConns(id, nodeConn) return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id) } // Update connection status b.connected.Store(id, nil) // nil = connected // Node will automatically receive updates through the normal flow // The initial full map already contains all current state nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)). Int("active.connections", nodeConn.getActiveConnectionCount()). Msg("node connection established in batcher") return nil } // RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. // It validates the connection channel matches one of the current connections, closes that specific connection, // and keeps the node entry alive for rapid reconnections instead of aggressive deletion. // Reports if the node still has active connections after removal. 
func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() nodeConn, exists := b.nodes.Load(id) if !exists || nodeConn == nil { nlog.Debug().Caller().Msg("removeNode called for non-existent node") return false } // Remove specific connection removed := nodeConn.removeConnectionByChannel(c) if !removed { nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid") } // Check if node has any remaining active connections if nodeConn.hasActiveConnections() { nlog.Debug().Caller(). Int("active.connections", nodeConn.getActiveConnectionCount()). Msg("node connection removed but keeping online, other connections remain") return true // Node still has active connections } // No active connections - keep the node entry alive for rapid reconnections // The node will get a fresh full map when it reconnects nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection") b.connected.Store(id, new(time.Now())) return false } // AddWork queues a change to be processed by the batcher. func (b *Batcher) AddWork(r ...change.Change) { b.addToBatch(r...) } func (b *Batcher) Start() { if !b.started.CompareAndSwap(false, true) { return } b.wg.Add(1) go b.doWork() } func (b *Batcher) Close() { // Signal shutdown to all goroutines, only once. // Workers and queueWork both select on done, so closing it // is sufficient for graceful shutdown. We intentionally do NOT // close workCh here because processBatchedChanges or // MapResponseFromChange may still be sending on it concurrently. b.doneOnce.Do(func() { close(b.done) }) // Wait for all worker goroutines (and doWork) to exit before // tearing down node connections. This prevents workers from // sending on connections that are being closed concurrently. b.wg.Wait() // Stop the ticker to prevent resource leaks. 
b.tick.Stop() // Close the underlying channels supplying the data to the clients. b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { if conn == nil { return true } conn.close() return true }) } func (b *Batcher) doWork() { defer b.wg.Done() for i := range b.workers { b.wg.Add(1) go b.worker(i + 1) } // Create a cleanup ticker for removing truly disconnected nodes cleanupTicker := time.NewTicker(5 * time.Minute) defer cleanupTicker.Stop() for { select { case <-b.tick.C: // Process batched changes b.processBatchedChanges() case <-cleanupTicker.C: // Clean up nodes that have been offline for too long b.cleanupOfflineNodes() case <-b.done: log.Info().Msg("batcher done channel closed, stopping to feed workers") return } } } func (b *Batcher) worker(workerID int) { defer b.wg.Done() wlog := log.With().Int(zf.WorkerID, workerID).Logger() for { select { case w, ok := <-b.workCh: if !ok { wlog.Debug().Msg("worker channel closing, shutting down") return } b.workProcessed.Add(1) // If the resultCh is set, it means that this is a work request // where there is a blocking function waiting for the map that // is being generated. // This is used for synchronous map generation. if w.resultCh != nil { var result workResult if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { var err error result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) result.err = err if result.err != nil { b.workErrors.Add(1) wlog.Error().Err(result.err). Uint64(zf.NodeID, w.nodeID.Uint64()). Str(zf.Reason, w.c.Reason). Msg("failed to generate map response for synchronous work") } else if result.mapResponse != nil { // Update peer tracking for synchronous responses too nc.updateSentPeers(result.mapResponse) } } else { result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID) b.workErrors.Add(1) wlog.Error().Err(result.err). Uint64(zf.NodeID, w.nodeID.Uint64()). 
Msg("node not found for synchronous work") } // Send result select { case w.resultCh <- result: case <-b.done: return } continue } // If resultCh is nil, this is an asynchronous work request // that should be processed and sent to the node instead of // returned to the caller. if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { // Apply change to node - this will handle offline nodes gracefully // and queue work for when they reconnect err := nc.change(w.c) if err != nil { b.workErrors.Add(1) wlog.Error().Err(err). Uint64(zf.NodeID, w.nodeID.Uint64()). Str(zf.Reason, w.c.Reason). Msg("failed to apply change") } } case <-b.done: wlog.Debug().Msg("batcher shutting down, exiting worker") return } } } // queueWork safely queues work. func (b *Batcher) queueWork(w work) { b.workQueuedCount.Add(1) select { case b.workCh <- w: // Successfully queued case <-b.done: // Batcher is shutting down return } } // addToBatch adds changes to the pending batch. func (b *Batcher) addToBatch(changes ...change.Change) { // Clean up any nodes being permanently removed from the system. // // This handles the case where a node is deleted from state but the batcher // still has it registered. By cleaning up here, we prevent "node not found" // errors when workers try to generate map responses for deleted nodes. // // Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually // deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy // changes that affect peer visibility do NOT use this field - they set // RequiresRuntimePeerComputation=true and compute removed peers at runtime, // putting them in tailcfg.MapResponse.PeersRemoved (a different struct). // Therefore, this cleanup only removes nodes that are truly being deleted, // not nodes that are still connected but have lost visibility of certain peers. 
// // See: https://github.com/juanfont/headscale/issues/2924 for _, ch := range changes { for _, removedID := range ch.PeersRemoved { if _, existed := b.nodes.LoadAndDelete(removedID); existed { b.totalNodes.Add(-1) log.Debug(). Uint64(zf.NodeID, removedID.Uint64()). Msg("removed deleted node from batcher") } b.connected.Delete(removedID) } } // Short circuit if any of the changes is a full update, which // means we can skip sending individual changes. if change.HasFull(changes) { b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { if nc == nil { return true } nc.pendingMu.Lock() nc.pending = []change.Change{change.FullUpdate()} nc.pendingMu.Unlock() return true }) return } broadcast, targeted := change.SplitTargetedAndBroadcast(changes) // Handle targeted changes - send only to the specific node for _, ch := range targeted { if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil { nc.appendPending(ch) } } // Handle broadcast changes - send to all nodes, filtering as needed if len(broadcast) > 0 { b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { if nc == nil { return true } filtered := change.FilterForNode(nodeID, broadcast) if len(filtered) > 0 { nc.appendPending(filtered...) } return true }) } } // processBatchedChanges processes all pending batched changes. func (b *Batcher) processBatchedChanges() { b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { if nc == nil { return true } pending := nc.drainPending() if len(pending) == 0 { return true } // Send all batched changes for this node for _, ch := range pending { b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) } return true }) } // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. // Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node // reconnects between the hasActiveConnections() check and the Delete() call. // TODO(kradalby): reevaluate if we want to keep this. 
func (b *Batcher) cleanupOfflineNodes() { now := time.Now() var nodesToCleanup []types.NodeID // Find nodes that have been offline for too long b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { if disconnectTime != nil && now.Sub(*disconnectTime) > offlineNodeCleanupThreshold { nodesToCleanup = append(nodesToCleanup, nodeID) } return true }) // Clean up the identified nodes using Compute() for atomic check-and-delete. // This prevents a TOCTOU race where a node reconnects (adding an active // connection) between the hasActiveConnections() check and the Delete() call. cleaned := 0 for _, nodeID := range nodesToCleanup { b.nodes.Compute( nodeID, func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) { if !loaded || conn == nil || conn.hasActiveConnections() { return conn, xsync.CancelOp } // Perform all bookkeeping inside the Compute callback so // that a concurrent AddNode (which calls LoadOrStore on // b.nodes) cannot slip in between the delete and the // connected/counter updates. b.connected.Delete(nodeID) b.totalNodes.Add(-1) cleaned++ log.Info().Uint64(zf.NodeID, nodeID.Uint64()). Dur("offline_duration", offlineNodeCleanupThreshold). Msg("cleaning up node that has been offline for too long") return conn, xsync.DeleteOp }, ) } if cleaned > 0 { log.Info().Int(zf.CleanedNodes, cleaned). Msg("completed cleanup of long-offline nodes") } } // IsConnected is lock-free read that checks if a node has any active connections. func (b *Batcher) IsConnected(id types.NodeID) bool { // First check if we have active connections for this node if nodeConn, exists := b.nodes.Load(id); exists && nodeConn != nil { if nodeConn.hasActiveConnections() { return true } } // Check disconnected timestamp with grace period val, ok := b.connected.Load(id) if !ok { return false } // nil means connected if val == nil { return true } return false } // ConnectedMap returns a lock-free map of all connected nodes. 
func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret := xsync.NewMap[types.NodeID, bool]() // First, add all nodes with active connections b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { if nodeConn == nil { return true } if nodeConn.hasActiveConnections() { ret.Store(id, true) } return true }) // Then add all entries from the connected map b.connected.Range(func(id types.NodeID, val *time.Time) bool { // Only add if not already added as connected above if _, exists := ret.Load(id); !exists { if val == nil { // nil means connected ret.Store(id, true) } else { // timestamp means disconnected ret.Store(id, false) } } return true }) return ret } // markDisconnectedIfNoConns stores a disconnect timestamp in b.connected // when the node has no remaining active connections. This prevents // IsConnected from returning a stale true after all connections have been // removed on an error path (e.g. AddNode failure). func (b *Batcher) markDisconnectedIfNoConns(id types.NodeID, nc *multiChannelNodeConn) { if !nc.hasActiveConnections() { now := time.Now() b.connected.Store(id, &now) } } // MapResponseFromChange queues work to generate a map response and waits for the result. // This allows synchronous map generation using the same worker pool. func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { resultCh := make(chan workResult, 1) // Queue the work with a result channel using the safe queueing method b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh}) // Wait for the result select { case result := <-resultCh: return result.mapResponse, result.err case <-b.done: return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id) } } // DebugNodeInfo contains debug information about a node's connections. 
type DebugNodeInfo struct { Connected bool `json:"connected"` ActiveConnections int `json:"active_connections"` } // Debug returns a pre-baked map of node debug information for the debug interface. func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo { result := make(map[types.NodeID]DebugNodeInfo) // Get all nodes with their connection status using immediate connection logic // (no grace period) for debug purposes b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { if nodeConn == nil { return true } activeConnCount := nodeConn.getActiveConnectionCount() // Use immediate connection status: if active connections exist, node is connected // If not, check the connected map for nil (connected) vs timestamp (disconnected) connected := false if activeConnCount > 0 { connected = true } else { // Check connected map for immediate status if val, ok := b.connected.Load(id); ok && val == nil { connected = true } } result[id] = DebugNodeInfo{ Connected: connected, ActiveConnections: activeConnCount, } return true }) // Add all entries from the connected map to capture both connected and disconnected nodes b.connected.Range(func(id types.NodeID, val *time.Time) bool { // Only add if not already processed above if _, exists := result[id]; !exists { // Use immediate connection status for debug (no grace period) connected := (val == nil) // nil means connected, timestamp means disconnected result[id] = DebugNodeInfo{ Connected: connected, ActiveConnections: 0, } } return true }) return result } func (b *Batcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { return b.mapper.debugMapResponses() } // WorkErrors returns the count of work errors encountered. // This is primarily useful for testing and debugging. func (b *Batcher) WorkErrors() int64 { return b.workErrors.Load() }