diff --git a/hscontrol/app.go b/hscontrol/app.go index 87da6f87..c57c7be0 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -101,7 +101,7 @@ type Headscale struct { // Things that generate changes extraRecordMan *dns.ExtraRecordsMan authProvider AuthProvider - mapBatcher mapper.Batcher + mapBatcher *mapper.Batcher clientStreamsOpen sync.WaitGroup } diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 93200b95..52189af8 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/arl/statsviz" - "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/prometheus/client_golang/prometheus/promhttp" "tailscale.com/tsweb" @@ -329,38 +328,18 @@ func (h *Headscale) debugBatcher() string { var nodes []nodeStatus - // Try to get detailed debug info if we have a LockFreeBatcher - if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok { - debugInfo := batcher.Debug() - for nodeID, info := range debugInfo { - nodes = append(nodes, nodeStatus{ - id: nodeID, - connected: info.Connected, - activeConnections: info.ActiveConnections, - }) - totalNodes++ - - if info.Connected { - connectedCount++ - } - } - } else { - // Fallback to basic connection info - connectedMap := h.mapBatcher.ConnectedMap() - connectedMap.Range(func(nodeID types.NodeID, connected bool) bool { - nodes = append(nodes, nodeStatus{ - id: nodeID, - connected: connected, - activeConnections: 0, - }) - totalNodes++ - - if connected { - connectedCount++ - } - - return true + debugInfo := h.mapBatcher.Debug() + for nodeID, info := range debugInfo { + nodes = append(nodes, nodeStatus{ + id: nodeID, + connected: info.Connected, + activeConnections: info.ActiveConnections, }) + totalNodes++ + + if info.Connected { + connectedCount++ + } } // Sort by node ID @@ -410,28 +389,13 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { TotalNodes: 0, } - // Try to get detailed debug info if we have a 
LockFreeBatcher - if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok { - debugInfo := batcher.Debug() - for nodeID, debugData := range debugInfo { - info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ - Connected: debugData.Connected, - ActiveConnections: debugData.ActiveConnections, - } - info.TotalNodes++ + debugInfo := h.mapBatcher.Debug() + for nodeID, debugData := range debugInfo { + info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ + Connected: debugData.Connected, + ActiveConnections: debugData.ActiveConnections, } - } else { - // Fallback to basic connection info - connectedMap := h.mapBatcher.ConnectedMap() - connectedMap.Range(func(nodeID types.NodeID, connected bool) bool { - info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ - Connected: connected, - ActiveConnections: 0, - } - info.TotalNodes++ - - return true - }) + info.TotalNodes++ } return info diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 2c1bf94e..becaec71 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -1,8 +1,12 @@ package mapper import ( + "crypto/rand" + "encoding/hex" "errors" "fmt" + "sync" + "sync/atomic" "time" "github.com/juanfont/headscale/hscontrol/state" @@ -12,6 +16,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" ) @@ -30,23 +35,8 @@ var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ Help: "total count of mapresponses generated by response type", }, []string{"response_type"}) -type batcherFunc func(cfg *types.Config, state *state.State) Batcher - -// Batcher defines the common interface for all batcher implementations. 
-type Batcher interface { - Start() - Close() - AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error - RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool - IsConnected(id types.NodeID) bool - ConnectedMap() *xsync.Map[types.NodeID, bool] - AddWork(r ...change.Change) - MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) - DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) -} - -func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher { - return &LockFreeBatcher{ +func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher { + return &Batcher{ mapper: mapper, workers: workers, tick: time.NewTicker(batchTime), @@ -58,8 +48,8 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB } } -// NewBatcherAndMapper creates a Batcher implementation. -func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher { +// NewBatcherAndMapper creates a new Batcher with its mapper. +func NewBatcherAndMapper(cfg *types.Config, state *state.State) *Batcher { m := newMapper(cfg, state) b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m) m.batcher = b @@ -184,3 +174,973 @@ type work struct { nodeID types.NodeID resultCh chan<- workResult // optional channel for synchronous operations } + +// Batcher errors. +var ( + errConnectionClosed = errors.New("connection channel already closed") + ErrInitialMapSendTimeout = errors.New("sending initial map: timeout") + ErrBatcherShuttingDown = errors.New("batcher shutting down") + ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)") +) + +// Batcher batches and distributes map responses to connected nodes. +// It uses concurrent maps, per-node mutexes, and a worker pool. 
+type Batcher struct { + tick *time.Ticker + mapper *mapper + workers int + + nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] + connected *xsync.Map[types.NodeID, *time.Time] + + // Work queue channel + workCh chan work + done chan struct{} + doneOnce sync.Once // Ensures done is only closed once + + started atomic.Bool // Ensures Start() is only called once + + // Metrics + totalNodes atomic.Int64 + workQueuedCount atomic.Int64 + workProcessed atomic.Int64 + workErrors atomic.Int64 +} + +// AddNode registers a new node connection with the batcher and sends an initial map response. +// It creates or updates the node's connection data, validates the initial map generation, +// and notifies other nodes that this node has come online. +// The stop function tears down the owning session if this connection is later declared stale. +func (b *Batcher) AddNode( + id types.NodeID, + c chan<- *tailcfg.MapResponse, + version tailcfg.CapabilityVersion, + stop func(), +) error { + addNodeStart := time.Now() + nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() + + // Generate connection ID + connID := generateConnectionID() + + // Create new connection entry + now := time.Now() + newEntry := &connectionEntry{ + id: connID, + c: c, + version: version, + created: now, + stop: stop, + } + // Initialize last used timestamp + newEntry.lastUsed.Store(now.Unix()) + + // Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection + nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper)) + + if !loaded { + b.totalNodes.Add(1) + } + + // Add connection to the list (lock-free) + nodeConn.addConnection(newEntry) + + // Use the worker pool for controlled concurrency instead of direct generation + initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id)) + if err != nil { + nlog.Error().Err(err).Msg("initial map generation failed") + nodeConn.removeConnectionByChannel(c) + + return fmt.Errorf("generating 
initial map for node %d: %w", id, err) + } + + // Use a blocking send with timeout for initial map since the channel should be ready + // and we want to avoid the race condition where the receiver isn't ready yet + select { + case c <- initialMap: + // Success + case <-time.After(5 * time.Second): //nolint:mnd + nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout") + nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd + Msg("initial map send timed out because channel was blocked or receiver not ready") + nodeConn.removeConnectionByChannel(c) + + return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id) + } + + // Update connection status + b.connected.Store(id, nil) // nil = connected + + // Node will automatically receive updates through the normal flow + // The initial full map already contains all current state + + nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)). + Int("active.connections", nodeConn.getActiveConnectionCount()). + Msg("node connection established in batcher") + + return nil +} + +// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. +// It validates the connection channel matches one of the current connections, closes that specific connection, +// and keeps the node entry alive for rapid reconnections instead of aggressive deletion. +// Reports if the node still has active connections after removal. 
+func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
+	nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
+
+	nodeConn, exists := b.nodes.Load(id)
+	if !exists || nodeConn == nil {
+		nlog.Debug().Caller().Msg("removeNode called for non-existent node")
+		return false
+	}
+
+	// Remove specific connection
+	removed := nodeConn.removeConnectionByChannel(c)
+	if !removed {
+		nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
+	}
+
+	// Check if node has any remaining active connections
+	if nodeConn.hasActiveConnections() {
+		nlog.Debug().Caller().
+			Int("active.connections", nodeConn.getActiveConnectionCount()).
+			Msg("node connection removed but keeping online, other connections remain")
+
+		return true // Node still has active connections
+	}
+
+	// No active connections - keep the node entry alive for rapid reconnections
+	// The node will get a fresh full map when it reconnects
+	nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
+	b.connected.Store(id, func() *time.Time { t := time.Now(); return &t }())
+
+	return false
+}
+
+// AddWork queues a change to be processed by the batcher.
+func (b *Batcher) AddWork(r ...change.Change) {
+	b.addWork(r...)
+}
+
+func (b *Batcher) Start() {
+	if !b.started.CompareAndSwap(false, true) {
+		return
+	}
+
+	b.done = make(chan struct{})
+
+	go b.doWork()
+}
+
+func (b *Batcher) Close() {
+	// Signal shutdown to all goroutines, only once.
+	// Workers and queueWork both select on done, so closing it
+	// is sufficient for graceful shutdown. We intentionally do NOT
+	// close workCh here because processBatchedChanges or
+	// MapResponseFromChange may still be sending on it concurrently.
+	b.doneOnce.Do(func() {
+		if b.done != nil {
+			close(b.done)
+		}
+	})
+
+	// Close the underlying channels supplying the data to the clients.
+ b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { + if conn == nil { + return true + } + + conn.close() + + return true + }) +} + +func (b *Batcher) doWork() { + for i := range b.workers { + go b.worker(i + 1) + } + + // Create a cleanup ticker for removing truly disconnected nodes + cleanupTicker := time.NewTicker(5 * time.Minute) + defer cleanupTicker.Stop() + + for { + select { + case <-b.tick.C: + // Process batched changes + b.processBatchedChanges() + case <-cleanupTicker.C: + // Clean up nodes that have been offline for too long + b.cleanupOfflineNodes() + case <-b.done: + log.Info().Msg("batcher done channel closed, stopping to feed workers") + return + } + } +} + +func (b *Batcher) worker(workerID int) { + wlog := log.With().Int(zf.WorkerID, workerID).Logger() + + for { + select { + case w, ok := <-b.workCh: + if !ok { + wlog.Debug().Msg("worker channel closing, shutting down") + return + } + + b.workProcessed.Add(1) + + // If the resultCh is set, it means that this is a work request + // where there is a blocking function waiting for the map that + // is being generated. + // This is used for synchronous map generation. + if w.resultCh != nil { + var result workResult + + if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { + var err error + + result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + + result.err = err + if result.err != nil { + b.workErrors.Add(1) + wlog.Error().Err(result.err). + Uint64(zf.NodeID, w.nodeID.Uint64()). + Str(zf.Reason, w.c.Reason). + Msg("failed to generate map response for synchronous work") + } else if result.mapResponse != nil { + // Update peer tracking for synchronous responses too + nc.updateSentPeers(result.mapResponse) + } + } else { + result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID) + + b.workErrors.Add(1) + wlog.Error().Err(result.err). + Uint64(zf.NodeID, w.nodeID.Uint64()). 
+ Msg("node not found for synchronous work") + } + + // Send result + select { + case w.resultCh <- result: + case <-b.done: + return + } + + continue + } + + // If resultCh is nil, this is an asynchronous work request + // that should be processed and sent to the node instead of + // returned to the caller. + if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { + // Apply change to node - this will handle offline nodes gracefully + // and queue work for when they reconnect + err := nc.change(w.c) + if err != nil { + b.workErrors.Add(1) + wlog.Error().Err(err). + Uint64(zf.NodeID, w.nodeID.Uint64()). + Str(zf.Reason, w.c.Reason). + Msg("failed to apply change") + } + } + case <-b.done: + wlog.Debug().Msg("batcher shutting down, exiting worker") + return + } + } +} + +func (b *Batcher) addWork(r ...change.Change) { + b.addToBatch(r...) +} + +// queueWork safely queues work. +func (b *Batcher) queueWork(w work) { + b.workQueuedCount.Add(1) + + select { + case b.workCh <- w: + // Successfully queued + case <-b.done: + // Batcher is shutting down + return + } +} + +// addToBatch adds changes to the pending batch. +func (b *Batcher) addToBatch(changes ...change.Change) { + // Clean up any nodes being permanently removed from the system. + // + // This handles the case where a node is deleted from state but the batcher + // still has it registered. By cleaning up here, we prevent "node not found" + // errors when workers try to generate map responses for deleted nodes. + // + // Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually + // deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy + // changes that affect peer visibility do NOT use this field - they set + // RequiresRuntimePeerComputation=true and compute removed peers at runtime, + // putting them in tailcfg.MapResponse.PeersRemoved (a different struct). 
+ // Therefore, this cleanup only removes nodes that are truly being deleted, + // not nodes that are still connected but have lost visibility of certain peers. + // + // See: https://github.com/juanfont/headscale/issues/2924 + for _, ch := range changes { + for _, removedID := range ch.PeersRemoved { + if _, existed := b.nodes.LoadAndDelete(removedID); existed { + b.totalNodes.Add(-1) + log.Debug(). + Uint64(zf.NodeID, removedID.Uint64()). + Msg("removed deleted node from batcher") + } + + b.connected.Delete(removedID) + } + } + + // Short circuit if any of the changes is a full update, which + // means we can skip sending individual changes. + if change.HasFull(changes) { + b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { + if nc == nil { + return true + } + + nc.pendingMu.Lock() + nc.pending = []change.Change{change.FullUpdate()} + nc.pendingMu.Unlock() + + return true + }) + + return + } + + broadcast, targeted := change.SplitTargetedAndBroadcast(changes) + + // Handle targeted changes - send only to the specific node + for _, ch := range targeted { + if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil { + nc.appendPending(ch) + } + } + + // Handle broadcast changes - send to all nodes, filtering as needed + if len(broadcast) > 0 { + b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { + if nc == nil { + return true + } + + filtered := change.FilterForNode(nodeID, broadcast) + + if len(filtered) > 0 { + nc.appendPending(filtered...) + } + + return true + }) + } +} + +// processBatchedChanges processes all pending batched changes. 
+func (b *Batcher) processBatchedChanges() { + b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { + if nc == nil { + return true + } + + pending := nc.drainPending() + if len(pending) == 0 { + return true + } + + // Send all batched changes for this node + for _, ch := range pending { + b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) + } + + return true + }) +} + +// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. +// Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node +// reconnects between the hasActiveConnections() check and the Delete() call. +// TODO(kradalby): reevaluate if we want to keep this. +func (b *Batcher) cleanupOfflineNodes() { + cleanupThreshold := 15 * time.Minute + now := time.Now() + + var nodesToCleanup []types.NodeID + + // Find nodes that have been offline for too long + b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { + if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold { + nodesToCleanup = append(nodesToCleanup, nodeID) + } + + return true + }) + + // Clean up the identified nodes using Compute() for atomic check-and-delete. + // This prevents a TOCTOU race where a node reconnects (adding an active + // connection) between the hasActiveConnections() check and the Delete() call. + cleaned := 0 + + for _, nodeID := range nodesToCleanup { + deleted := false + + b.nodes.Compute( + nodeID, + func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) { + if !loaded || conn == nil || conn.hasActiveConnections() { + return conn, xsync.CancelOp + } + + deleted = true + + return conn, xsync.DeleteOp + }, + ) + + if deleted { + log.Info().Uint64(zf.NodeID, nodeID.Uint64()). + Dur("offline_duration", cleanupThreshold). 
+ Msg("cleaning up node that has been offline for too long") + + b.connected.Delete(nodeID) + b.totalNodes.Add(-1) + + cleaned++ + } + } + + if cleaned > 0 { + log.Info().Int(zf.CleanedNodes, cleaned). + Msg("completed cleanup of long-offline nodes") + } +} + +// IsConnected is lock-free read that checks if a node has any active connections. +func (b *Batcher) IsConnected(id types.NodeID) bool { + // First check if we have active connections for this node + if nodeConn, exists := b.nodes.Load(id); exists && nodeConn != nil { + if nodeConn.hasActiveConnections() { + return true + } + } + + // Check disconnected timestamp with grace period + val, ok := b.connected.Load(id) + if !ok { + return false + } + + // nil means connected + if val == nil { + return true + } + + return false +} + +// ConnectedMap returns a lock-free map of all connected nodes. +func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { + ret := xsync.NewMap[types.NodeID, bool]() + + // First, add all nodes with active connections + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + if nodeConn == nil { + return true + } + + if nodeConn.hasActiveConnections() { + ret.Store(id, true) + } + + return true + }) + + // Then add all entries from the connected map + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // Only add if not already added as connected above + if _, exists := ret.Load(id); !exists { + if val == nil { + // nil means connected + ret.Store(id, true) + } else { + // timestamp means disconnected + ret.Store(id, false) + } + } + + return true + }) + + return ret +} + +// MapResponseFromChange queues work to generate a map response and waits for the result. +// This allows synchronous map generation using the same worker pool. 
+func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { + resultCh := make(chan workResult, 1) + + // Queue the work with a result channel using the safe queueing method + b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh}) + + // Wait for the result + select { + case result := <-resultCh: + return result.mapResponse, result.err + case <-b.done: + return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id) + } +} + +// connectionEntry represents a single connection to a node. +type connectionEntry struct { + id string // unique connection ID + c chan<- *tailcfg.MapResponse + version tailcfg.CapabilityVersion + created time.Time + stop func() + lastUsed atomic.Int64 // Unix timestamp of last successful send + closed atomic.Bool // Indicates if this connection has been closed +} + +// multiChannelNodeConn manages multiple concurrent connections for a single node. +type multiChannelNodeConn struct { + id types.NodeID + mapper *mapper + log zerolog.Logger + + mutex sync.RWMutex + connections []*connectionEntry + + // pendingMu protects pending changes independently of the connection mutex. + // This avoids contention between addToBatch (which appends changes) and + // send() (which sends data to connections). + pendingMu sync.Mutex + pending []change.Change + + closeOnce sync.Once + updateCount atomic.Int64 + + // lastSentPeers tracks which peers were last sent to this node. + // This enables computing diffs for policy changes instead of sending + // full peer lists (which clients interpret as "no change" when empty). + // Using xsync.Map for lock-free concurrent access. + lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}] +} + +// generateConnectionID generates a unique connection identifier. 
+func generateConnectionID() string { + bytes := make([]byte, 8) + _, _ = rand.Read(bytes) + + return hex.EncodeToString(bytes) +} + +// newMultiChannelNodeConn creates a new multi-channel node connection. +func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn { + return &multiChannelNodeConn{ + id: id, + mapper: mapper, + lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](), + log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(), + } +} + +func (mc *multiChannelNodeConn) close() { + mc.closeOnce.Do(func() { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for _, conn := range mc.connections { + mc.stopConnection(conn) + } + }) +} + +// stopConnection marks a connection as closed and tears down the owning session +// at most once, even if multiple cleanup paths race to remove it. +func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) { + if conn.closed.CompareAndSwap(false, true) { + if conn.stop != nil { + conn.stop() + } + } +} + +// removeConnectionAtIndexLocked removes the active connection at index. +// If stopConnection is true, it also stops that session. +// Caller must hold mc.mutex. +func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry { + conn := mc.connections[i] + mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) + + if stopConnection { + mc.stopConnection(conn) + } + + return conn +} + +// addConnection adds a new connection. +func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { + mutexWaitStart := time.Now() + + mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id). + Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") + + mc.mutex.Lock() + + mutexWaitDur := time.Since(mutexWaitStart) + + defer mc.mutex.Unlock() + + mc.connections = append(mc.connections, entry) + mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id). 
+ Int("total_connections", len(mc.connections)). + Dur("mutex_wait_time", mutexWaitDur). + Msg("successfully added connection after mutex wait") +} + +// removeConnectionByChannel removes a connection by matching channel pointer. +func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for i, entry := range mc.connections { + if entry.c == c { + mc.removeConnectionAtIndexLocked(i, false) + mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)). + Int("remaining_connections", len(mc.connections)). + Msg("successfully removed connection") + + return true + } + } + + return false +} + +// hasActiveConnections checks if the node has any active connections. +func (mc *multiChannelNodeConn) hasActiveConnections() bool { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + return len(mc.connections) > 0 +} + +// getActiveConnectionCount returns the number of active connections. +func (mc *multiChannelNodeConn) getActiveConnectionCount() int { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + return len(mc.connections) +} + +// appendPending appends changes to this node's pending change list. +// Thread-safe via pendingMu; does not contend with the connection mutex. +func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) { + mc.pendingMu.Lock() + mc.pending = append(mc.pending, changes...) + mc.pendingMu.Unlock() +} + +// drainPending atomically removes and returns all pending changes. +// Returns nil if there are no pending changes. +func (mc *multiChannelNodeConn) drainPending() []change.Change { + mc.pendingMu.Lock() + p := mc.pending + mc.pending = nil + mc.pendingMu.Unlock() + + return p +} + +// send broadcasts data to all active connections for the node. 
+// +// To avoid holding the write lock during potentially slow sends (each stale +// connection can block for up to 50ms), the method snapshots connections under +// a read lock, sends without any lock held, then write-locks only to remove +// failures. New connections added between the snapshot and cleanup are safe: +// they receive a full initial map via AddNode, so missing this update causes +// no data loss. +func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + + // Snapshot connections under read lock. + mc.mutex.RLock() + + if len(mc.connections) == 0 { + mc.mutex.RUnlock() + mc.log.Debug().Caller(). + Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + + return nil + } + + // Copy the slice so we can release the read lock before sending. + snapshot := make([]*connectionEntry, len(mc.connections)) + copy(snapshot, mc.connections) + mc.mutex.RUnlock() + + mc.log.Debug().Caller(). + Int("total_connections", len(snapshot)). + Msg("send: broadcasting to all connections") + + // Send to all connections without holding any lock. + // Stale connection timeouts (50ms each) happen here without blocking + // other goroutines that need the mutex. + var ( + lastErr error + successCount int + failed []*connectionEntry + ) + + for _, conn := range snapshot { + mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)). + Str(zf.ConnID, conn.id). + Msg("send: attempting to send to connection") + + err := conn.send(data) + if err != nil { + lastErr = err + + failed = append(failed, conn) + + mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)). + Str(zf.ConnID, conn.id). + Msg("send: connection send failed") + } else { + successCount++ + + mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)). + Str(zf.ConnID, conn.id). + Msg("send: successfully sent to connection") + } + } + + // Write-lock only to remove failed connections. 
+ if len(failed) > 0 { + mc.mutex.Lock() + // Remove by pointer identity: only remove entries that still exist + // in the current connections slice and match a failed pointer. + // New connections added since the snapshot are not affected. + failedSet := make(map[*connectionEntry]struct{}, len(failed)) + for _, f := range failed { + failedSet[f] = struct{}{} + } + + clean := mc.connections[:0] + for _, conn := range mc.connections { + if _, isFailed := failedSet[conn]; !isFailed { + clean = append(clean, conn) + } else { + mc.log.Debug().Caller(). + Str(zf.ConnID, conn.id). + Msg("send: removing failed connection") + // Tear down the owning session so the old serveLongPoll + // goroutine exits instead of lingering as a stale session. + mc.stopConnection(conn) + } + } + + mc.connections = clean + mc.mutex.Unlock() + } + + mc.updateCount.Add(1) + + mc.log.Debug(). + Int("successful_sends", successCount). + Int("failed_connections", len(failed)). + Msg("send: completed broadcast") + + // Success if at least one send succeeded + if successCount > 0 { + return nil + } + + return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr) +} + +// send sends data to a single connection entry with timeout-based stale connection detection. +func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + + // Check if the connection has been closed to prevent send on closed channel panic. + // This can happen during shutdown when Close() is called while workers are still processing. + if entry.closed.Load() { + return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed) + } + + // Use a short timeout to detect stale connections where the client isn't reading the channel. + // This is critical for detecting Docker containers that are forcefully terminated + // but still have channels that appear open. 
+ select { + case entry.c <- data: + // Update last used timestamp on successful send + entry.lastUsed.Store(time.Now().Unix()) + return nil + case <-time.After(50 * time.Millisecond): + // Connection is likely stale - client isn't reading from channel + // This catches the case where Docker containers are killed but channels remain open + return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout) + } +} + +// nodeID returns the node ID. +func (mc *multiChannelNodeConn) nodeID() types.NodeID { + return mc.id +} + +// version returns the capability version from the first active connection. +// All connections for a node should have the same version in practice. +func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + if len(mc.connections) == 0 { + return 0 + } + + return mc.connections[0].version +} + +// updateSentPeers updates the tracked peer state based on a sent MapResponse. +// This must be called after successfully sending a response to keep track of +// what the client knows about, enabling accurate diffs for future updates. +func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) { + if resp == nil { + return + } + + // Full peer list replaces tracked state entirely + if resp.Peers != nil { + mc.lastSentPeers.Clear() + + for _, peer := range resp.Peers { + mc.lastSentPeers.Store(peer.ID, struct{}{}) + } + } + + // Incremental additions + for _, peer := range resp.PeersChanged { + mc.lastSentPeers.Store(peer.ID, struct{}{}) + } + + // Incremental removals + for _, id := range resp.PeersRemoved { + mc.lastSentPeers.Delete(id) + } +} + +// computePeerDiff compares the current peer list against what was last sent +// and returns the peers that were removed (in lastSentPeers but not in current). 
+func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID { + currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers)) + for _, id := range currentPeers { + currentSet[id] = struct{}{} + } + + var removed []tailcfg.NodeID + + // Find removed: in lastSentPeers but not in current + mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool { + if _, exists := currentSet[id]; !exists { + removed = append(removed, id) + } + + return true + }) + + return removed +} + +// change applies a change to all active connections for the node. +func (mc *multiChannelNodeConn) change(r change.Change) error { + return handleNodeChange(mc, mc.mapper, r) +} + +// DebugNodeInfo contains debug information about a node's connections. +type DebugNodeInfo struct { + Connected bool `json:"connected"` + ActiveConnections int `json:"active_connections"` +} + +// Debug returns a pre-baked map of node debug information for the debug interface. +func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo { + result := make(map[types.NodeID]DebugNodeInfo) + + // Get all nodes with their connection status using immediate connection logic + // (no grace period) for debug purposes + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + if nodeConn == nil { + return true + } + + nodeConn.mutex.RLock() + activeConnCount := len(nodeConn.connections) + nodeConn.mutex.RUnlock() + + // Use immediate connection status: if active connections exist, node is connected + // If not, check the connected map for nil (connected) vs timestamp (disconnected) + connected := false + if activeConnCount > 0 { + connected = true + } else { + // Check connected map for immediate status + if val, ok := b.connected.Load(id); ok && val == nil { + connected = true + } + } + + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: activeConnCount, + } + + return true + }) + + // Add all entries from the connected map to capture both 
connected and disconnected nodes + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // Only add if not already processed above + if _, exists := result[id]; !exists { + // Use immediate connection status for debug (no grace period) + connected := (val == nil) // nil means connected, timestamp means disconnected + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: 0, + } + } + + return true + }) + + return result +} + +func (b *Batcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + return b.mapper.debugMapResponses() +} + +// WorkErrors returns the count of work errors encountered. +// This is primarily useful for testing and debugging. +func (b *Batcher) WorkErrors() int64 { + return b.workErrors.Load() +} diff --git a/hscontrol/mapper/batcher_bench_test.go b/hscontrol/mapper/batcher_bench_test.go index 2d30110b..e229faa5 100644 --- a/hscontrol/mapper/batcher_bench_test.go +++ b/hscontrol/mapper/batcher_bench_test.go @@ -148,8 +148,8 @@ func BenchmarkUpdateSentPeers(b *testing.B) { // benchBatcher creates a lightweight batcher for benchmarks. Unlike the test // helper, it doesn't register cleanup and suppresses logging. -func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) { - b := &LockFreeBatcher{ +func benchBatcher(nodeCount, bufferSize int) (*Batcher, map[types.NodeID]chan *tailcfg.MapResponse) { + b := &Batcher{ tick: time.NewTicker(1 * time.Hour), // never fires during bench workers: 4, workCh: make(chan work, 4*200), diff --git a/hscontrol/mapper/batcher_concurrency_test.go b/hscontrol/mapper/batcher_concurrency_test.go index b15175b1..d6cb0f83 100644 --- a/hscontrol/mapper/batcher_concurrency_test.go +++ b/hscontrol/mapper/batcher_concurrency_test.go @@ -35,7 +35,7 @@ import ( // lightweightBatcher provides a batcher with pre-populated nodes for testing // the batching, channel, and concurrency mechanics without database overhead. 
type lightweightBatcher struct { - b *LockFreeBatcher + b *Batcher channels map[types.NodeID]chan *tailcfg.MapResponse } @@ -46,7 +46,7 @@ type lightweightBatcher struct { func setupLightweightBatcher(t *testing.T, nodeCount, bufferSize int) *lightweightBatcher { t.Helper() - b := &LockFreeBatcher{ + b := &Batcher{ tick: time.NewTicker(10 * time.Millisecond), workers: 4, workCh: make(chan work, 4*200), @@ -86,7 +86,7 @@ func (lb *lightweightBatcher) cleanup() { } // countTotalPending counts total pending change entries across all nodes. -func countTotalPending(b *LockFreeBatcher) int { +func countTotalPending(b *Batcher) int { count := 0 b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { @@ -101,7 +101,7 @@ func countTotalPending(b *LockFreeBatcher) int { } // countNodesPending counts how many nodes have pending changes. -func countNodesPending(b *LockFreeBatcher) int { +func countNodesPending(b *Batcher) int { count := 0 b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { @@ -120,7 +120,7 @@ func countNodesPending(b *LockFreeBatcher) int { } // getPendingForNode returns pending changes for a specific node. 
-func getPendingForNode(b *LockFreeBatcher, id types.NodeID) []change.Change { +func getPendingForNode(b *Batcher, id types.NodeID) []change.Change { nc, ok := b.nodes.Load(id) if !ok { return nil @@ -1167,7 +1167,7 @@ func TestScale1000_MultiChannelBroadcast(t *testing.T) { ) // Create nodes with varying connection counts - b := &LockFreeBatcher{ + b := &Batcher{ tick: time.NewTicker(10 * time.Millisecond), workers: 4, workCh: make(chan work, 4*200), @@ -1569,7 +1569,7 @@ func TestScale1000_WorkChannelSaturation(t *testing.T) { defer zerolog.SetGlobalLevel(zerolog.DebugLevel) // Create batcher with SMALL work channel to force saturation - b := &LockFreeBatcher{ + b := &Batcher{ tick: time.NewTicker(10 * time.Millisecond), workers: 2, workCh: make(chan work, 10), // Very small - will saturate diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go deleted file mode 100644 index 37a748bd..00000000 --- a/hscontrol/mapper/batcher_lockfree.go +++ /dev/null @@ -1,988 +0,0 @@ -package mapper - -import ( - "crypto/rand" - "encoding/hex" - "errors" - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/types/change" - "github.com/juanfont/headscale/hscontrol/util/zlog/zf" - "github.com/puzpuzpuz/xsync/v4" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "tailscale.com/tailcfg" -) - -// LockFreeBatcher errors. -var ( - errConnectionClosed = errors.New("connection channel already closed") - ErrInitialMapSendTimeout = errors.New("sending initial map: timeout") - ErrBatcherShuttingDown = errors.New("batcher shutting down") - ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)") -) - -// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. 
-type LockFreeBatcher struct { - tick *time.Ticker - mapper *mapper - workers int - - nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] - connected *xsync.Map[types.NodeID, *time.Time] - - // Work queue channel - workCh chan work - workChOnce sync.Once // Ensures workCh is only closed once - done chan struct{} - doneOnce sync.Once // Ensures done is only closed once - - started atomic.Bool // Ensures Start() is only called once - - // Metrics - totalNodes atomic.Int64 - workQueuedCount atomic.Int64 - workProcessed atomic.Int64 - workErrors atomic.Int64 -} - -// AddNode registers a new node connection with the batcher and sends an initial map response. -// It creates or updates the node's connection data, validates the initial map generation, -// and notifies other nodes that this node has come online. -// The stop function tears down the owning session if this connection is later declared stale. -func (b *LockFreeBatcher) AddNode( - id types.NodeID, - c chan<- *tailcfg.MapResponse, - version tailcfg.CapabilityVersion, - stop func(), -) error { - addNodeStart := time.Now() - nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() - - // Generate connection ID - connID := generateConnectionID() - - // Create new connection entry - now := time.Now() - newEntry := &connectionEntry{ - id: connID, - c: c, - version: version, - created: now, - stop: stop, - } - // Initialize last used timestamp - newEntry.lastUsed.Store(now.Unix()) - - // Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection - nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper)) - - if !loaded { - b.totalNodes.Add(1) - } - - // Add connection to the list (lock-free) - nodeConn.addConnection(newEntry) - - // Use the worker pool for controlled concurrency instead of direct generation - initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id)) - if err != nil { - nlog.Error().Err(err).Msg("initial map generation failed") 
- nodeConn.removeConnectionByChannel(c) - - return fmt.Errorf("generating initial map for node %d: %w", id, err) - } - - // Use a blocking send with timeout for initial map since the channel should be ready - // and we want to avoid the race condition where the receiver isn't ready yet - select { - case c <- initialMap: - // Success - case <-time.After(5 * time.Second): //nolint:mnd - nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout") - nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd - Msg("initial map send timed out because channel was blocked or receiver not ready") - nodeConn.removeConnectionByChannel(c) - - return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id) - } - - // Update connection status - b.connected.Store(id, nil) // nil = connected - - // Node will automatically receive updates through the normal flow - // The initial full map already contains all current state - - nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)). - Int("active.connections", nodeConn.getActiveConnectionCount()). - Msg("node connection established in batcher") - - return nil -} - -// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. -// It validates the connection channel matches one of the current connections, closes that specific connection, -// and keeps the node entry alive for rapid reconnections instead of aggressive deletion. -// Reports if the node still has active connections after removal. 
-func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { - nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger() - - nodeConn, exists := b.nodes.Load(id) - if !exists || nodeConn == nil { - nlog.Debug().Caller().Msg("removeNode called for non-existent node") - return false - } - - // Remove specific connection - removed := nodeConn.removeConnectionByChannel(c) - if !removed { - nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid") - } - - // Check if node has any remaining active connections - if nodeConn.hasActiveConnections() { - nlog.Debug().Caller(). - Int("active.connections", nodeConn.getActiveConnectionCount()). - Msg("node connection removed but keeping online, other connections remain") - - return true // Node still has active connections - } - - // No active connections - keep the node entry alive for rapid reconnections - // The node will get a fresh full map when it reconnects - nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection") - b.connected.Store(id, new(time.Now())) - - return false -} - -// AddWork queues a change to be processed by the batcher. -func (b *LockFreeBatcher) AddWork(r ...change.Change) { - b.addWork(r...) -} - -func (b *LockFreeBatcher) Start() { - if !b.started.CompareAndSwap(false, true) { - return - } - - b.done = make(chan struct{}) - - go b.doWork() -} - -func (b *LockFreeBatcher) Close() { - // Signal shutdown to all goroutines, only once - b.doneOnce.Do(func() { - if b.done != nil { - close(b.done) - } - }) - - // Only close workCh once using sync.Once to prevent races - b.workChOnce.Do(func() { - close(b.workCh) - }) - - // Close the underlying channels supplying the data to the clients. 
- b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { - if conn == nil { - return true - } - conn.close() - return true - }) -} - -func (b *LockFreeBatcher) doWork() { - for i := range b.workers { - go b.worker(i + 1) - } - - // Create a cleanup ticker for removing truly disconnected nodes - cleanupTicker := time.NewTicker(5 * time.Minute) - defer cleanupTicker.Stop() - - for { - select { - case <-b.tick.C: - // Process batched changes - b.processBatchedChanges() - case <-cleanupTicker.C: - // Clean up nodes that have been offline for too long - b.cleanupOfflineNodes() - case <-b.done: - log.Info().Msg("batcher done channel closed, stopping to feed workers") - return - } - } -} - -func (b *LockFreeBatcher) worker(workerID int) { - wlog := log.With().Int(zf.WorkerID, workerID).Logger() - - for { - select { - case w, ok := <-b.workCh: - if !ok { - wlog.Debug().Msg("worker channel closing, shutting down") - return - } - - b.workProcessed.Add(1) - - // If the resultCh is set, it means that this is a work request - // where there is a blocking function waiting for the map that - // is being generated. - // This is used for synchronous map generation. - if w.resultCh != nil { - var result workResult - - if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { - var err error - - result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) - - result.err = err - if result.err != nil { - b.workErrors.Add(1) - wlog.Error().Err(result.err). - Uint64(zf.NodeID, w.nodeID.Uint64()). - Str(zf.Reason, w.c.Reason). - Msg("failed to generate map response for synchronous work") - } else if result.mapResponse != nil { - // Update peer tracking for synchronous responses too - nc.updateSentPeers(result.mapResponse) - } - } else { - result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID) - - b.workErrors.Add(1) - wlog.Error().Err(result.err). - Uint64(zf.NodeID, w.nodeID.Uint64()). 
- Msg("node not found for synchronous work") - } - - // Send result - select { - case w.resultCh <- result: - case <-b.done: - return - } - - continue - } - - // If resultCh is nil, this is an asynchronous work request - // that should be processed and sent to the node instead of - // returned to the caller. - if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil { - // Apply change to node - this will handle offline nodes gracefully - // and queue work for when they reconnect - err := nc.change(w.c) - if err != nil { - b.workErrors.Add(1) - wlog.Error().Err(err). - Uint64(zf.NodeID, w.nodeID.Uint64()). - Str(zf.Reason, w.c.Reason). - Msg("failed to apply change") - } - } - case <-b.done: - wlog.Debug().Msg("batcher shutting down, exiting worker") - return - } - } -} - -func (b *LockFreeBatcher) addWork(r ...change.Change) { - b.addToBatch(r...) -} - -// queueWork safely queues work. -func (b *LockFreeBatcher) queueWork(w work) { - b.workQueuedCount.Add(1) - - select { - case b.workCh <- w: - // Successfully queued - case <-b.done: - // Batcher is shutting down - return - } -} - -// addToBatch adds changes to the pending batch. -func (b *LockFreeBatcher) addToBatch(changes ...change.Change) { - // Clean up any nodes being permanently removed from the system. - // - // This handles the case where a node is deleted from state but the batcher - // still has it registered. By cleaning up here, we prevent "node not found" - // errors when workers try to generate map responses for deleted nodes. - // - // Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually - // deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy - // changes that affect peer visibility do NOT use this field - they set - // RequiresRuntimePeerComputation=true and compute removed peers at runtime, - // putting them in tailcfg.MapResponse.PeersRemoved (a different struct). 
- // Therefore, this cleanup only removes nodes that are truly being deleted, - // not nodes that are still connected but have lost visibility of certain peers. - // - // See: https://github.com/juanfont/headscale/issues/2924 - for _, ch := range changes { - for _, removedID := range ch.PeersRemoved { - if _, existed := b.nodes.LoadAndDelete(removedID); existed { - b.totalNodes.Add(-1) - log.Debug(). - Uint64(zf.NodeID, removedID.Uint64()). - Msg("removed deleted node from batcher") - } - - b.connected.Delete(removedID) - } - } - - // Short circuit if any of the changes is a full update, which - // means we can skip sending individual changes. - if change.HasFull(changes) { - b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool { - if nc == nil { - return true - } - - nc.pendingMu.Lock() - nc.pending = []change.Change{change.FullUpdate()} - nc.pendingMu.Unlock() - - return true - }) - - return - } - - broadcast, targeted := change.SplitTargetedAndBroadcast(changes) - - // Handle targeted changes - send only to the specific node - for _, ch := range targeted { - if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil { - nc.appendPending(ch) - } - } - - // Handle broadcast changes - send to all nodes, filtering as needed - if len(broadcast) > 0 { - b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { - if nc == nil { - return true - } - filtered := change.FilterForNode(nodeID, broadcast) - - if len(filtered) > 0 { - nc.appendPending(filtered...) - } - - return true - }) - } -} - -// processBatchedChanges processes all pending batched changes. 
-func (b *LockFreeBatcher) processBatchedChanges() { - b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool { - if nc == nil { - return true - } - - pending := nc.drainPending() - if len(pending) == 0 { - return true - } - - // Send all batched changes for this node - for _, ch := range pending { - b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) - } - - return true - }) -} - -// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. -// Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node -// reconnects between the hasActiveConnections() check and the Delete() call. -// TODO(kradalby): reevaluate if we want to keep this. -func (b *LockFreeBatcher) cleanupOfflineNodes() { - cleanupThreshold := 15 * time.Minute - now := time.Now() - - var nodesToCleanup []types.NodeID - - // Find nodes that have been offline for too long - b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { - if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold { - nodesToCleanup = append(nodesToCleanup, nodeID) - } - - return true - }) - - // Clean up the identified nodes using Compute() for atomic check-and-delete. - // This prevents a TOCTOU race where a node reconnects (adding an active - // connection) between the hasActiveConnections() check and the Delete() call. - cleaned := 0 - for _, nodeID := range nodesToCleanup { - deleted := false - - b.nodes.Compute( - nodeID, - func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) { - if !loaded || conn == nil || conn.hasActiveConnections() { - return conn, xsync.CancelOp - } - - deleted = true - - return conn, xsync.DeleteOp - }, - ) - - if deleted { - log.Info().Uint64(zf.NodeID, nodeID.Uint64()). - Dur("offline_duration", cleanupThreshold). 
- Msg("cleaning up node that has been offline for too long") - - b.connected.Delete(nodeID) - b.totalNodes.Add(-1) - - cleaned++ - } - } - - if cleaned > 0 { - log.Info().Int(zf.CleanedNodes, cleaned). - Msg("completed cleanup of long-offline nodes") - } -} - -// IsConnected is lock-free read that checks if a node has any active connections. -func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { - // First check if we have active connections for this node - if nodeConn, exists := b.nodes.Load(id); exists && nodeConn != nil { - if nodeConn.hasActiveConnections() { - return true - } - } - - // Check disconnected timestamp with grace period - val, ok := b.connected.Load(id) - if !ok { - return false - } - - // nil means connected - if val == nil { - return true - } - - return false -} - -// ConnectedMap returns a lock-free map of all connected nodes. -func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { - ret := xsync.NewMap[types.NodeID, bool]() - - // First, add all nodes with active connections - b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { - if nodeConn == nil { - return true - } - if nodeConn.hasActiveConnections() { - ret.Store(id, true) - } - - return true - }) - - // Then add all entries from the connected map - b.connected.Range(func(id types.NodeID, val *time.Time) bool { - // Only add if not already added as connected above - if _, exists := ret.Load(id); !exists { - if val == nil { - // nil means connected - ret.Store(id, true) - } else { - // timestamp means disconnected - ret.Store(id, false) - } - } - - return true - }) - - return ret -} - -// MapResponseFromChange queues work to generate a map response and waits for the result. -// This allows synchronous map generation using the same worker pool. 
-func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { - resultCh := make(chan workResult, 1) - - // Queue the work with a result channel using the safe queueing method - b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh}) - - // Wait for the result - select { - case result := <-resultCh: - return result.mapResponse, result.err - case <-b.done: - return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id) - } -} - -// connectionEntry represents a single connection to a node. -type connectionEntry struct { - id string // unique connection ID - c chan<- *tailcfg.MapResponse - version tailcfg.CapabilityVersion - created time.Time - stop func() - lastUsed atomic.Int64 // Unix timestamp of last successful send - closed atomic.Bool // Indicates if this connection has been closed -} - -// multiChannelNodeConn manages multiple concurrent connections for a single node. -type multiChannelNodeConn struct { - id types.NodeID - mapper *mapper - log zerolog.Logger - - mutex sync.RWMutex - connections []*connectionEntry - - // pendingMu protects pending changes independently of the connection mutex. - // This avoids contention between addToBatch (which appends changes) and - // send() (which sends data to connections). - pendingMu sync.Mutex - pending []change.Change - - closeOnce sync.Once - updateCount atomic.Int64 - - // lastSentPeers tracks which peers were last sent to this node. - // This enables computing diffs for policy changes instead of sending - // full peer lists (which clients interpret as "no change" when empty). - // Using xsync.Map for lock-free concurrent access. - lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}] -} - -// generateConnectionID generates a unique connection identifier. 
-func generateConnectionID() string { - bytes := make([]byte, 8) - _, _ = rand.Read(bytes) - - return hex.EncodeToString(bytes) -} - -// newMultiChannelNodeConn creates a new multi-channel node connection. -func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn { - return &multiChannelNodeConn{ - id: id, - mapper: mapper, - lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](), - log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(), - } -} - -func (mc *multiChannelNodeConn) close() { - mc.closeOnce.Do(func() { - mc.mutex.Lock() - defer mc.mutex.Unlock() - - for _, conn := range mc.connections { - mc.stopConnection(conn) - } - }) -} - -// stopConnection marks a connection as closed and tears down the owning session -// at most once, even if multiple cleanup paths race to remove it. -func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) { - if conn.closed.CompareAndSwap(false, true) { - if conn.stop != nil { - conn.stop() - } - } -} - -// removeConnectionAtIndexLocked removes the active connection at index. -// If stopConnection is true, it also stops that session. -// Caller must hold mc.mutex. -func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry { - conn := mc.connections[i] - mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) - - if stopConnection { - mc.stopConnection(conn) - } - - return conn -} - -// addConnection adds a new connection. -func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { - mutexWaitStart := time.Now() - - mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id). - Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") - - mc.mutex.Lock() - - mutexWaitDur := time.Since(mutexWaitStart) - - defer mc.mutex.Unlock() - - mc.connections = append(mc.connections, entry) - mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id). 
- Int("total_connections", len(mc.connections)). - Dur("mutex_wait_time", mutexWaitDur). - Msg("successfully added connection after mutex wait") -} - -// removeConnectionByChannel removes a connection by matching channel pointer. -func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool { - mc.mutex.Lock() - defer mc.mutex.Unlock() - - for i, entry := range mc.connections { - if entry.c == c { - mc.removeConnectionAtIndexLocked(i, false) - mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)). - Int("remaining_connections", len(mc.connections)). - Msg("successfully removed connection") - - return true - } - } - - return false -} - -// hasActiveConnections checks if the node has any active connections. -func (mc *multiChannelNodeConn) hasActiveConnections() bool { - mc.mutex.RLock() - defer mc.mutex.RUnlock() - - return len(mc.connections) > 0 -} - -// getActiveConnectionCount returns the number of active connections. -func (mc *multiChannelNodeConn) getActiveConnectionCount() int { - mc.mutex.RLock() - defer mc.mutex.RUnlock() - - return len(mc.connections) -} - -// appendPending appends changes to this node's pending change list. -// Thread-safe via pendingMu; does not contend with the connection mutex. -func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) { - mc.pendingMu.Lock() - mc.pending = append(mc.pending, changes...) - mc.pendingMu.Unlock() -} - -// drainPending atomically removes and returns all pending changes. -// Returns nil if there are no pending changes. -func (mc *multiChannelNodeConn) drainPending() []change.Change { - mc.pendingMu.Lock() - p := mc.pending - mc.pending = nil - mc.pendingMu.Unlock() - - return p -} - -// send broadcasts data to all active connections for the node. -// send broadcasts data to all connections using a two-phase approach to avoid -// holding the write lock during potentially slow sends. 
Each stale connection -// can block for up to 50ms (see connectionEntry.send), so N stale connections -// under a single write lock would block for N*50ms. The two-phase approach: -// -// 1. RLock: snapshot the connections slice (cheap pointer copy) -// 2. Unlock: send to all connections without any lock held (timeouts happen here) -// 3. Lock: remove only the failed connections by pointer identity -// -// New connections added during step 2 are safe: they receive a full initial -// map via AddNode, so missing this particular update causes no data loss. -func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { - if data == nil { - return nil - } - - // Phase 1: snapshot connections under read lock. - mc.mutex.RLock() - if len(mc.connections) == 0 { - mc.mutex.RUnlock() - mc.log.Debug().Caller(). - Msg("send: skipping send to node with no active connections (likely rapid reconnection)") - - return nil - } - - // Copy the slice header (shares underlying array, but that's fine since - // we only read; writes go through the write lock in phase 3). - snapshot := make([]*connectionEntry, len(mc.connections)) - copy(snapshot, mc.connections) - mc.mutex.RUnlock() - - mc.log.Debug().Caller(). - Int("total_connections", len(snapshot)). - Msg("send: broadcasting to all connections") - - // Phase 2: send to all connections without holding any lock. - // Stale connection timeouts (50ms each) happen here without blocking - // other goroutines that need the mutex. - var ( - lastErr error - successCount int - failed []*connectionEntry - ) - - for _, conn := range snapshot { - mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)). - Str(zf.ConnID, conn.id). - Msg("send: attempting to send to connection") - - err := conn.send(data) - if err != nil { - lastErr = err - - failed = append(failed, conn) - - mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)). - Str(zf.ConnID, conn.id). 
- Msg("send: connection send failed") - } else { - successCount++ - - mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)). - Str(zf.ConnID, conn.id). - Msg("send: successfully sent to connection") - } - } - - // Phase 3: write-lock only to remove failed connections. - if len(failed) > 0 { - mc.mutex.Lock() - // Remove by pointer identity: only remove entries that still exist - // in the current connections slice and match a failed pointer. - // New connections added between phase 1 and 3 are not affected. - failedSet := make(map[*connectionEntry]struct{}, len(failed)) - for _, f := range failed { - failedSet[f] = struct{}{} - } - - clean := mc.connections[:0] - for _, conn := range mc.connections { - if _, isFailed := failedSet[conn]; !isFailed { - clean = append(clean, conn) - } else { - mc.log.Debug().Caller(). - Str(zf.ConnID, conn.id). - Msg("send: removing failed connection") - // Tear down the owning session so the old serveLongPoll - // goroutine exits instead of lingering as a stale session. - mc.stopConnection(conn) - } - } - - mc.connections = clean - mc.mutex.Unlock() - } - - mc.updateCount.Add(1) - - mc.log.Debug(). - Int("successful_sends", successCount). - Int("failed_connections", len(failed)). - Msg("send: completed broadcast") - - // Success if at least one send succeeded - if successCount > 0 { - return nil - } - - return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr) -} - -// send sends data to a single connection entry with timeout-based stale connection detection. -func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { - if data == nil { - return nil - } - - // Check if the connection has been closed to prevent send on closed channel panic. - // This can happen during shutdown when Close() is called while workers are still processing. 
- if entry.closed.Load() { - return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed) - } - - // Use a short timeout to detect stale connections where the client isn't reading the channel. - // This is critical for detecting Docker containers that are forcefully terminated - // but still have channels that appear open. - select { - case entry.c <- data: - // Update last used timestamp on successful send - entry.lastUsed.Store(time.Now().Unix()) - return nil - case <-time.After(50 * time.Millisecond): - // Connection is likely stale - client isn't reading from channel - // This catches the case where Docker containers are killed but channels remain open - return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout) - } -} - -// nodeID returns the node ID. -func (mc *multiChannelNodeConn) nodeID() types.NodeID { - return mc.id -} - -// version returns the capability version from the first active connection. -// All connections for a node should have the same version in practice. -func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion { - mc.mutex.RLock() - defer mc.mutex.RUnlock() - - if len(mc.connections) == 0 { - return 0 - } - - return mc.connections[0].version -} - -// updateSentPeers updates the tracked peer state based on a sent MapResponse. -// This must be called after successfully sending a response to keep track of -// what the client knows about, enabling accurate diffs for future updates. 
-func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) { - if resp == nil { - return - } - - // Full peer list replaces tracked state entirely - if resp.Peers != nil { - mc.lastSentPeers.Clear() - - for _, peer := range resp.Peers { - mc.lastSentPeers.Store(peer.ID, struct{}{}) - } - } - - // Incremental additions - for _, peer := range resp.PeersChanged { - mc.lastSentPeers.Store(peer.ID, struct{}{}) - } - - // Incremental removals - for _, id := range resp.PeersRemoved { - mc.lastSentPeers.Delete(id) - } -} - -// computePeerDiff compares the current peer list against what was last sent -// and returns the peers that were removed (in lastSentPeers but not in current). -func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID { - currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers)) - for _, id := range currentPeers { - currentSet[id] = struct{}{} - } - - var removed []tailcfg.NodeID - - // Find removed: in lastSentPeers but not in current - mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool { - if _, exists := currentSet[id]; !exists { - removed = append(removed, id) - } - - return true - }) - - return removed -} - -// change applies a change to all active connections for the node. -func (mc *multiChannelNodeConn) change(r change.Change) error { - return handleNodeChange(mc, mc.mapper, r) -} - -// DebugNodeInfo contains debug information about a node's connections. -type DebugNodeInfo struct { - Connected bool `json:"connected"` - ActiveConnections int `json:"active_connections"` -} - -// Debug returns a pre-baked map of node debug information for the debug interface. 
-func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { - result := make(map[types.NodeID]DebugNodeInfo) - - // Get all nodes with their connection status using immediate connection logic - // (no grace period) for debug purposes - b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { - if nodeConn == nil { - return true - } - nodeConn.mutex.RLock() - activeConnCount := len(nodeConn.connections) - nodeConn.mutex.RUnlock() - - // Use immediate connection status: if active connections exist, node is connected - // If not, check the connected map for nil (connected) vs timestamp (disconnected) - connected := false - if activeConnCount > 0 { - connected = true - } else { - // Check connected map for immediate status - if val, ok := b.connected.Load(id); ok && val == nil { - connected = true - } - } - - result[id] = DebugNodeInfo{ - Connected: connected, - ActiveConnections: activeConnCount, - } - - return true - }) - - // Add all entries from the connected map to capture both connected and disconnected nodes - b.connected.Range(func(id types.NodeID, val *time.Time) bool { - // Only add if not already processed above - if _, exists := result[id]; !exists { - // Use immediate connection status for debug (no grace period) - connected := (val == nil) // nil means connected, timestamp means disconnected - result[id] = DebugNodeInfo{ - Connected: connected, - ActiveConnections: 0, - } - } - - return true - }) - - return result -} - -func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { - return b.mapper.debugMapResponses() -} - -// WorkErrors returns the count of work errors encountered. -// This is primarily useful for testing and debugging. 
-func (b *LockFreeBatcher) WorkErrors() int64 { - return b.workErrors.Load() -} diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 281bdb02..46c4d1d7 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -25,6 +25,8 @@ import ( var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher") +type batcherFunc func(cfg *types.Config, state *state.State) *Batcher + // batcherTestCase defines a batcher function with a descriptive name for testing. type batcherTestCase struct { name string @@ -34,7 +36,7 @@ type batcherTestCase struct { // testBatcherWrapper wraps a real batcher to add online/offline notifications // that would normally be sent by poll.go in production. type testBatcherWrapper struct { - Batcher + *Batcher state *state.State } @@ -85,13 +87,13 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe } // wrapBatcherForTest wraps a batcher with test-specific behavior. -func wrapBatcherForTest(b Batcher, state *state.State) Batcher { +func wrapBatcherForTest(b *Batcher, state *state.State) *testBatcherWrapper { return &testBatcherWrapper{Batcher: b, state: state} } // allBatcherFunctions contains all batcher implementations to test. var allBatcherFunctions = []batcherTestCase{ - {"LockFree", NewBatcherAndMapper}, + {"Default", NewBatcherAndMapper}, } // emptyCache creates an empty registration cache for testing. @@ -134,7 +136,7 @@ type TestData struct { Nodes []node State *state.State Config *types.Config - Batcher Batcher + Batcher *testBatcherWrapper } type node struct { @@ -2354,46 +2356,35 @@ func TestBatcherRapidReconnection(t *testing.T) { // Check debug status after reconnection. 
t.Logf("Checking debug status...") - if debugBatcher, ok := batcher.(interface { - Debug() map[types.NodeID]any - }); ok { - debugInfo := debugBatcher.Debug() - disconnectedCount := 0 + debugInfo := batcher.Debug() + disconnectedCount := 0 - for i := range allNodes { - node := &allNodes[i] - if info, exists := debugInfo[node.n.ID]; exists { - t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info) + for i := range allNodes { + node := &allNodes[i] + if info, exists := debugInfo[node.n.ID]; exists { + t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info) - // Check if the debug info shows the node as connected - if infoMap, ok := info.(map[string]any); ok { - if connected, ok := infoMap["connected"].(bool); ok && !connected { - disconnectedCount++ - - t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) - } - } - } else { + if !info.Connected { disconnectedCount++ - t.Logf("Node %d missing from debug info entirely", i) + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) } - - // Also check IsConnected method - if !batcher.IsConnected(node.n.ID) { - t.Logf("Node %d IsConnected() returns false", i) - } - } - - if disconnectedCount > 0 { - t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes)) - // This is expected behavior for multi-channel batcher according to user - // "it has never worked with the multi" } else { - t.Logf("All nodes show as connected - working correctly") + disconnectedCount++ + + t.Logf("Node %d missing from debug info entirely", i) } + + // Also check IsConnected method + if !batcher.IsConnected(node.n.ID) { + t.Logf("Node %d IsConnected() returns false", i) + } + } + + if disconnectedCount > 0 { + t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes)) } else { - t.Logf("Batcher does not implement Debug() method") + t.Logf("All nodes show as connected - 
working correctly") } // Test if "disconnected" nodes can actually receive updates. @@ -2491,37 +2482,25 @@ func TestBatcherMultiConnection(t *testing.T) { // Verify debug status shows correct connection count. t.Logf("Verifying debug status shows multiple connections...") - if debugBatcher, ok := batcher.(interface { - Debug() map[types.NodeID]any - }); ok { - debugInfo := debugBatcher.Debug() + debugInfo := batcher.Debug() - if info, exists := debugInfo[node1.n.ID]; exists { - t.Logf("Node1 debug info: %+v", info) + if info, exists := debugInfo[node1.n.ID]; exists { + t.Logf("Node1 debug info: %+v", info) - if infoMap, ok := info.(map[string]any); ok { - if activeConnections, ok := infoMap["active_connections"].(int); ok { - if activeConnections != 3 { - t.Errorf("Node1 should have 3 active connections, got %d", activeConnections) - } else { - t.Logf("SUCCESS: Node1 correctly shows 3 active connections") - } - } - - if connected, ok := infoMap["connected"].(bool); ok && !connected { - t.Errorf("Node1 should show as connected with 3 active connections") - } - } + if info.ActiveConnections != 3 { + t.Errorf("Node1 should have 3 active connections, got %d", info.ActiveConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 3 active connections") } - if info, exists := debugInfo[node2.n.ID]; exists { - if infoMap, ok := info.(map[string]any); ok { - if activeConnections, ok := infoMap["active_connections"].(int); ok { - if activeConnections != 1 { - t.Errorf("Node2 should have 1 active connection, got %d", activeConnections) - } - } - } + if !info.Connected { + t.Errorf("Node1 should show as connected with 3 active connections") + } + } + + if info, exists := debugInfo[node2.n.ID]; exists { + if info.ActiveConnections != 1 { + t.Errorf("Node2 should have 1 active connection, got %d", info.ActiveConnections) } } @@ -2604,20 +2583,12 @@ func TestBatcherMultiConnection(t *testing.T) { runtime.Gosched() // Verify debug status shows 2 connections now - if 
debugBatcher, ok := batcher.(interface { - Debug() map[types.NodeID]any - }); ok { - debugInfo := debugBatcher.Debug() - if info, exists := debugInfo[node1.n.ID]; exists { - if infoMap, ok := info.(map[string]any); ok { - if activeConnections, ok := infoMap["active_connections"].(int); ok { - if activeConnections != 2 { - t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections) - } else { - t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal") - } - } - } + debugInfo2 := batcher.Debug() + if info, exists := debugInfo2[node1.n.ID]; exists { + if info.ActiveConnections != 2 { + t.Errorf("Node1 should have 2 active connections after removal, got %d", info.ActiveConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal") } } @@ -2731,11 +2702,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { }, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect") // Get initial work errors count - var initialWorkErrors int64 - if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok { - initialWorkErrors = lfb.WorkErrors() - t.Logf("Initial work errors: %d", initialWorkErrors) - } + lfb := unwrapBatcher(batcher) + initialWorkErrors := lfb.WorkErrors() + t.Logf("Initial work errors: %d", initialWorkErrors) // Clear channels to prepare for the test drainCh(node1.ch) @@ -2777,11 +2746,7 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { // With the fix, no new errors should occur because the deleted node // was cleaned up from batcher state when NodeRemoved was processed assert.EventuallyWithT(t, func(c *assert.CollectT) { - var finalWorkErrors int64 - if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok { - finalWorkErrors = lfb.WorkErrors() - } - + finalWorkErrors := lfb.WorkErrors() newErrors := finalWorkErrors - initialWorkErrors assert.Zero(c, newErrors, "Fix for #2924: should have no work errors after node deletion") }, 5*time.Second, 
100*time.Millisecond, "waiting for work processing to complete without errors") @@ -2809,8 +2774,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, NORMAL_BUFFER_SIZE) defer cleanup() - lfb, ok := unwrapBatcher(testData.Batcher).(*LockFreeBatcher) - require.True(t, ok, "expected LockFreeBatcher") + lfb := unwrapBatcher(testData.Batcher) nodeID := testData.Nodes[0].n.ID ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) @@ -2838,8 +2802,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, NORMAL_BUFFER_SIZE) defer cleanup() - lfb, ok := unwrapBatcher(testData.Batcher).(*LockFreeBatcher) - require.True(t, ok, "expected LockFreeBatcher") + lfb := unwrapBatcher(testData.Batcher) nodeID := testData.Nodes[0].n.ID ch1 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) @@ -2867,11 +2830,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) { } } -// unwrapBatcher extracts the underlying batcher from wrapper types. -func unwrapBatcher(b Batcher) Batcher { - if wrapper, ok := b.(*testBatcherWrapper); ok { - return unwrapBatcher(wrapper.Batcher) - } - - return b +// unwrapBatcher extracts the underlying *Batcher from the test wrapper. +func unwrapBatcher(b *testBatcherWrapper) *Batcher { + return b.Batcher } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 4505f765..0aba4175 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -44,7 +44,7 @@ type mapper struct { // Configuration state *state.State cfg *types.Config - batcher Batcher + batcher *Batcher created time.Time }