mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-17 14:29:57 +02:00
M7: Nil out trailing *connectionEntry pointers in the backing array after slice removal in removeConnectionAtIndexLocked and send(). Without this, the GC cannot collect removed entries until the slice is reallocated. M1: Initialize the done channel in NewBatcher instead of Start(). Previously, calling Close() or queueWork before Start() would select on a nil channel, blocking forever. Moving the make() to the constructor ensures the channel is always usable. M2: Move b.connected.Delete and b.totalNodes decrement inside the Compute callback in cleanupOfflineNodes. Previously these ran after the Compute returned, allowing a concurrent AddNode to reconnect between the delete and the bookkeeping update, which would wipe the fresh connected state. M3: Call markDisconnectedIfNoConns on AddNode error paths. Previously, when initial map generation or send timed out, the connection was removed but b.connected retained its old nil (= connected) value, making IsConnected return true for a node with zero connections. Updates #2545
1178 lines
35 KiB
Go
1178 lines
35 KiB
Go
package mapper
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/state"
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
|
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"tailscale.com/tailcfg"
|
|
)
|
|
|
|
// Mapper errors.
var (
	// ErrInvalidNodeID is returned by generateMapResponse for node ID 0.
	ErrInvalidNodeID = errors.New("invalid nodeID")
	// ErrMapperNil is returned by generateMapResponse when no mapper is supplied.
	ErrMapperNil = errors.New("mapper is nil")
	// ErrNodeConnectionNil is returned by handleNodeChange for a nil connection.
	ErrNodeConnectionNil = errors.New("nodeConnection is nil")
	// ErrNodeNotFoundMapper is reported by a worker when synchronous work targets
	// a node that is not registered with the batcher.
	ErrNodeNotFoundMapper = errors.New("node not found")
)
|
|
|
|
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Namespace: "headscale",
|
|
Name: "mapresponse_generated_total",
|
|
Help: "total count of mapresponses generated by response type",
|
|
}, []string{"response_type"})
|
|
|
|
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *Batcher {
|
|
return &Batcher{
|
|
mapper: mapper,
|
|
workers: workers,
|
|
tick: time.NewTicker(batchTime),
|
|
|
|
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
|
workCh: make(chan work, workers*200),
|
|
done: make(chan struct{}),
|
|
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
|
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
|
}
|
|
}
|
|
|
|
// NewBatcherAndMapper creates a new Batcher with its mapper.
|
|
func NewBatcherAndMapper(cfg *types.Config, state *state.State) *Batcher {
|
|
m := newMapper(cfg, state)
|
|
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
|
m.batcher = b
|
|
|
|
return b
|
|
}
|
|
|
|
// nodeConnection interface for different connection implementations.
|
|
type nodeConnection interface {
|
|
nodeID() types.NodeID
|
|
version() tailcfg.CapabilityVersion
|
|
send(data *tailcfg.MapResponse) error
|
|
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
|
|
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
|
|
// updateSentPeers updates the tracking of which peers have been sent to this node.
|
|
updateSentPeers(resp *tailcfg.MapResponse)
|
|
}
|
|
|
|
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change].
|
|
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
|
|
nodeID := nc.nodeID()
|
|
version := nc.version()
|
|
|
|
if r.IsEmpty() {
|
|
return nil, nil //nolint:nilnil // Empty response means nothing to send
|
|
}
|
|
|
|
if nodeID == 0 {
|
|
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
|
}
|
|
|
|
if mapper == nil {
|
|
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
|
}
|
|
|
|
// Handle self-only responses
|
|
if r.IsSelfOnly() && r.TargetNode != nodeID {
|
|
return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
|
|
}
|
|
|
|
// Check if this is a self-update (the changed node is the receiving node).
|
|
// When true, ensure the response includes the node's self info so it sees
|
|
// its own attribute changes (e.g., tags changed via admin API).
|
|
isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID
|
|
|
|
var (
|
|
mapResp *tailcfg.MapResponse
|
|
err error
|
|
)
|
|
|
|
// Track metric using categorized type, not free-form reason
|
|
mapResponseGenerated.WithLabelValues(r.Type()).Inc()
|
|
|
|
// Check if this requires runtime peer visibility computation (e.g., policy changes)
|
|
if r.RequiresRuntimePeerComputation {
|
|
currentPeers := mapper.state.ListPeers(nodeID)
|
|
|
|
currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len())
|
|
for _, peer := range currentPeers.All() {
|
|
currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID())
|
|
}
|
|
|
|
removedPeers := nc.computePeerDiff(currentPeerIDs)
|
|
// Include self node when this is a self-update (e.g., node's own tags changed)
|
|
// so the node sees its updated self info along with new packet filters.
|
|
mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers, isSelfUpdate)
|
|
} else if isSelfUpdate {
|
|
// Non-policy self-update: just send the self node info
|
|
mapResp, err = mapper.selfMapResponse(nodeID, version)
|
|
} else {
|
|
mapResp, err = mapper.buildFromChange(nodeID, version, &r)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
|
|
}
|
|
|
|
return mapResp, nil
|
|
}
|
|
|
|
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
|
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
|
if nc == nil {
|
|
return ErrNodeConnectionNil
|
|
}
|
|
|
|
nodeID := nc.nodeID()
|
|
|
|
log.Debug().Caller().Uint64(zf.NodeID, nodeID.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started")
|
|
|
|
data, err := generateMapResponse(nc, mapper, r)
|
|
if err != nil {
|
|
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
|
}
|
|
|
|
if data == nil {
|
|
// No data to send is valid for some response types
|
|
return nil
|
|
}
|
|
|
|
// Send the map response
|
|
err = nc.send(data)
|
|
if err != nil {
|
|
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
|
}
|
|
|
|
// Update peer tracking after successful send
|
|
nc.updateSentPeers(data)
|
|
|
|
return nil
|
|
}
|
|
|
|
// workResult represents the result of processing a change.
|
|
type workResult struct {
|
|
mapResponse *tailcfg.MapResponse
|
|
err error
|
|
}
|
|
|
|
// work represents a unit of work to be processed by workers.
|
|
type work struct {
|
|
c change.Change
|
|
nodeID types.NodeID
|
|
resultCh chan<- workResult // optional channel for synchronous operations
|
|
}
|
|
|
|
// Batcher errors.
var (
	// errConnectionClosed is returned when sending on a connection already marked closed.
	errConnectionClosed = errors.New("connection channel already closed")
	// ErrInitialMapSendTimeout is returned by AddNode when the initial map cannot be delivered in time.
	ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
	// ErrBatcherShuttingDown is returned by MapResponseFromChange once Close has been called.
	ErrBatcherShuttingDown = errors.New("batcher shutting down")
	// ErrConnectionSendTimeout signals a send to a connection channel that timed out,
	// which usually means the connection is stale.
	ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
)
|
|
|
|
// Batcher batches and distributes map responses to connected nodes.
|
|
// It uses concurrent maps, per-node mutexes, and a worker pool.
|
|
//
|
|
// Lifecycle: Call Start() to spawn workers, then Close() to shut down.
|
|
// Close() blocks until all workers have exited. A Batcher must not
|
|
// be reused after Close().
|
|
type Batcher struct {
|
|
tick *time.Ticker
|
|
mapper *mapper
|
|
workers int
|
|
|
|
nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
|
|
connected *xsync.Map[types.NodeID, *time.Time]
|
|
|
|
// Work queue channel
|
|
workCh chan work
|
|
done chan struct{}
|
|
doneOnce sync.Once // Ensures done is only closed once
|
|
|
|
// wg tracks the doWork and all worker goroutines so that Close()
|
|
// can block until they have fully exited.
|
|
wg sync.WaitGroup
|
|
|
|
started atomic.Bool // Ensures Start() is only called once
|
|
|
|
// Metrics
|
|
totalNodes atomic.Int64
|
|
workQueuedCount atomic.Int64
|
|
workProcessed atomic.Int64
|
|
workErrors atomic.Int64
|
|
}
|
|
|
|
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
|
// It creates or updates the node's connection data, validates the initial map generation,
|
|
// and notifies other nodes that this node has come online.
|
|
// The stop function tears down the owning session if this connection is later declared stale.
|
|
func (b *Batcher) AddNode(
|
|
id types.NodeID,
|
|
c chan<- *tailcfg.MapResponse,
|
|
version tailcfg.CapabilityVersion,
|
|
stop func(),
|
|
) error {
|
|
addNodeStart := time.Now()
|
|
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
|
|
|
|
// Generate connection ID
|
|
connID := generateConnectionID()
|
|
|
|
// Create new connection entry
|
|
now := time.Now()
|
|
newEntry := &connectionEntry{
|
|
id: connID,
|
|
c: c,
|
|
version: version,
|
|
created: now,
|
|
stop: stop,
|
|
}
|
|
// Initialize last used timestamp
|
|
newEntry.lastUsed.Store(now.Unix())
|
|
|
|
// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
|
|
nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
|
|
|
|
if !loaded {
|
|
b.totalNodes.Add(1)
|
|
}
|
|
|
|
// Add connection to the list (lock-free)
|
|
nodeConn.addConnection(newEntry)
|
|
|
|
// Use the worker pool for controlled concurrency instead of direct generation
|
|
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
|
if err != nil {
|
|
nlog.Error().Err(err).Msg("initial map generation failed")
|
|
nodeConn.removeConnectionByChannel(c)
|
|
b.markDisconnectedIfNoConns(id, nodeConn)
|
|
|
|
return fmt.Errorf("generating initial map for node %d: %w", id, err)
|
|
}
|
|
|
|
// Use a blocking send with timeout for initial map since the channel should be ready
|
|
// and we want to avoid the race condition where the receiver isn't ready yet
|
|
select {
|
|
case c <- initialMap:
|
|
// Success
|
|
case <-time.After(5 * time.Second): //nolint:mnd
|
|
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
|
|
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
|
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
|
nodeConn.removeConnectionByChannel(c)
|
|
b.markDisconnectedIfNoConns(id, nodeConn)
|
|
|
|
return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
|
|
}
|
|
|
|
// Update connection status
|
|
b.connected.Store(id, nil) // nil = connected
|
|
|
|
// Node will automatically receive updates through the normal flow
|
|
// The initial full map already contains all current state
|
|
|
|
nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
|
|
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
|
Msg("node connection established in batcher")
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
|
// It validates the connection channel matches one of the current connections, closes that specific connection,
|
|
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
|
|
// Reports if the node still has active connections after removal.
|
|
func (b *Batcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
|
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
|
|
|
|
nodeConn, exists := b.nodes.Load(id)
|
|
if !exists || nodeConn == nil {
|
|
nlog.Debug().Caller().Msg("removeNode called for non-existent node")
|
|
return false
|
|
}
|
|
|
|
// Remove specific connection
|
|
removed := nodeConn.removeConnectionByChannel(c)
|
|
if !removed {
|
|
nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
|
|
}
|
|
|
|
// Check if node has any remaining active connections
|
|
if nodeConn.hasActiveConnections() {
|
|
nlog.Debug().Caller().
|
|
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
|
Msg("node connection removed but keeping online, other connections remain")
|
|
|
|
return true // Node still has active connections
|
|
}
|
|
|
|
// No active connections - keep the node entry alive for rapid reconnections
|
|
// The node will get a fresh full map when it reconnects
|
|
nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
|
|
b.connected.Store(id, new(time.Now()))
|
|
|
|
return false
|
|
}
|
|
|
|
// AddWork queues a change to be processed by the batcher.
|
|
func (b *Batcher) AddWork(r ...change.Change) {
|
|
b.addWork(r...)
|
|
}
|
|
|
|
func (b *Batcher) Start() {
|
|
if !b.started.CompareAndSwap(false, true) {
|
|
return
|
|
}
|
|
|
|
b.wg.Add(1)
|
|
|
|
go b.doWork()
|
|
}
|
|
|
|
func (b *Batcher) Close() {
|
|
// Signal shutdown to all goroutines, only once.
|
|
// Workers and queueWork both select on done, so closing it
|
|
// is sufficient for graceful shutdown. We intentionally do NOT
|
|
// close workCh here because processBatchedChanges or
|
|
// MapResponseFromChange may still be sending on it concurrently.
|
|
b.doneOnce.Do(func() {
|
|
close(b.done)
|
|
})
|
|
|
|
// Wait for all worker goroutines (and doWork) to exit before
|
|
// tearing down node connections. This prevents workers from
|
|
// sending on connections that are being closed concurrently.
|
|
b.wg.Wait()
|
|
|
|
// Stop the ticker to prevent resource leaks.
|
|
b.tick.Stop()
|
|
|
|
// Close the underlying channels supplying the data to the clients.
|
|
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
|
|
if conn == nil {
|
|
return true
|
|
}
|
|
|
|
conn.close()
|
|
|
|
return true
|
|
})
|
|
}
|
|
|
|
func (b *Batcher) doWork() {
|
|
defer b.wg.Done()
|
|
|
|
for i := range b.workers {
|
|
b.wg.Add(1)
|
|
|
|
go b.worker(i + 1)
|
|
}
|
|
|
|
// Create a cleanup ticker for removing truly disconnected nodes
|
|
cleanupTicker := time.NewTicker(5 * time.Minute)
|
|
defer cleanupTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-b.tick.C:
|
|
// Process batched changes
|
|
b.processBatchedChanges()
|
|
case <-cleanupTicker.C:
|
|
// Clean up nodes that have been offline for too long
|
|
b.cleanupOfflineNodes()
|
|
case <-b.done:
|
|
log.Info().Msg("batcher done channel closed, stopping to feed workers")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Batcher) worker(workerID int) {
|
|
defer b.wg.Done()
|
|
|
|
wlog := log.With().Int(zf.WorkerID, workerID).Logger()
|
|
|
|
for {
|
|
select {
|
|
case w, ok := <-b.workCh:
|
|
if !ok {
|
|
wlog.Debug().Msg("worker channel closing, shutting down")
|
|
return
|
|
}
|
|
|
|
b.workProcessed.Add(1)
|
|
|
|
// If the resultCh is set, it means that this is a work request
|
|
// where there is a blocking function waiting for the map that
|
|
// is being generated.
|
|
// This is used for synchronous map generation.
|
|
if w.resultCh != nil {
|
|
var result workResult
|
|
|
|
if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
|
|
var err error
|
|
|
|
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
|
|
|
result.err = err
|
|
if result.err != nil {
|
|
b.workErrors.Add(1)
|
|
wlog.Error().Err(result.err).
|
|
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
|
Str(zf.Reason, w.c.Reason).
|
|
Msg("failed to generate map response for synchronous work")
|
|
} else if result.mapResponse != nil {
|
|
// Update peer tracking for synchronous responses too
|
|
nc.updateSentPeers(result.mapResponse)
|
|
}
|
|
} else {
|
|
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
|
|
|
|
b.workErrors.Add(1)
|
|
wlog.Error().Err(result.err).
|
|
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
|
Msg("node not found for synchronous work")
|
|
}
|
|
|
|
// Send result
|
|
select {
|
|
case w.resultCh <- result:
|
|
case <-b.done:
|
|
return
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
// If resultCh is nil, this is an asynchronous work request
|
|
// that should be processed and sent to the node instead of
|
|
// returned to the caller.
|
|
if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
|
|
// Apply change to node - this will handle offline nodes gracefully
|
|
// and queue work for when they reconnect
|
|
err := nc.change(w.c)
|
|
if err != nil {
|
|
b.workErrors.Add(1)
|
|
wlog.Error().Err(err).
|
|
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
|
Str(zf.Reason, w.c.Reason).
|
|
Msg("failed to apply change")
|
|
}
|
|
}
|
|
case <-b.done:
|
|
wlog.Debug().Msg("batcher shutting down, exiting worker")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Batcher) addWork(r ...change.Change) {
|
|
b.addToBatch(r...)
|
|
}
|
|
|
|
// queueWork safely queues work.
|
|
func (b *Batcher) queueWork(w work) {
|
|
b.workQueuedCount.Add(1)
|
|
|
|
select {
|
|
case b.workCh <- w:
|
|
// Successfully queued
|
|
case <-b.done:
|
|
// Batcher is shutting down
|
|
return
|
|
}
|
|
}
|
|
|
|
// addToBatch adds changes to the pending batch.
|
|
func (b *Batcher) addToBatch(changes ...change.Change) {
|
|
// Clean up any nodes being permanently removed from the system.
|
|
//
|
|
// This handles the case where a node is deleted from state but the batcher
|
|
// still has it registered. By cleaning up here, we prevent "node not found"
|
|
// errors when workers try to generate map responses for deleted nodes.
|
|
//
|
|
// Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually
|
|
// deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy
|
|
// changes that affect peer visibility do NOT use this field - they set
|
|
// RequiresRuntimePeerComputation=true and compute removed peers at runtime,
|
|
// putting them in tailcfg.MapResponse.PeersRemoved (a different struct).
|
|
// Therefore, this cleanup only removes nodes that are truly being deleted,
|
|
// not nodes that are still connected but have lost visibility of certain peers.
|
|
//
|
|
// See: https://github.com/juanfont/headscale/issues/2924
|
|
for _, ch := range changes {
|
|
for _, removedID := range ch.PeersRemoved {
|
|
if _, existed := b.nodes.LoadAndDelete(removedID); existed {
|
|
b.totalNodes.Add(-1)
|
|
log.Debug().
|
|
Uint64(zf.NodeID, removedID.Uint64()).
|
|
Msg("removed deleted node from batcher")
|
|
}
|
|
|
|
b.connected.Delete(removedID)
|
|
}
|
|
}
|
|
|
|
// Short circuit if any of the changes is a full update, which
|
|
// means we can skip sending individual changes.
|
|
if change.HasFull(changes) {
|
|
b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
|
if nc == nil {
|
|
return true
|
|
}
|
|
|
|
nc.pendingMu.Lock()
|
|
nc.pending = []change.Change{change.FullUpdate()}
|
|
nc.pendingMu.Unlock()
|
|
|
|
return true
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
broadcast, targeted := change.SplitTargetedAndBroadcast(changes)
|
|
|
|
// Handle targeted changes - send only to the specific node
|
|
for _, ch := range targeted {
|
|
if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil {
|
|
nc.appendPending(ch)
|
|
}
|
|
}
|
|
|
|
// Handle broadcast changes - send to all nodes, filtering as needed
|
|
if len(broadcast) > 0 {
|
|
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
|
|
if nc == nil {
|
|
return true
|
|
}
|
|
|
|
filtered := change.FilterForNode(nodeID, broadcast)
|
|
|
|
if len(filtered) > 0 {
|
|
nc.appendPending(filtered...)
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// processBatchedChanges processes all pending batched changes.
|
|
func (b *Batcher) processBatchedChanges() {
|
|
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
|
|
if nc == nil {
|
|
return true
|
|
}
|
|
|
|
pending := nc.drainPending()
|
|
if len(pending) == 0 {
|
|
return true
|
|
}
|
|
|
|
// Send all batched changes for this node
|
|
for _, ch := range pending {
|
|
b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|
|
|
|
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
|
|
// Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node
|
|
// reconnects between the hasActiveConnections() check and the Delete() call.
|
|
// TODO(kradalby): reevaluate if we want to keep this.
|
|
func (b *Batcher) cleanupOfflineNodes() {
|
|
cleanupThreshold := 15 * time.Minute
|
|
now := time.Now()
|
|
|
|
var nodesToCleanup []types.NodeID
|
|
|
|
// Find nodes that have been offline for too long
|
|
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
|
|
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
|
|
nodesToCleanup = append(nodesToCleanup, nodeID)
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
// Clean up the identified nodes using Compute() for atomic check-and-delete.
|
|
// This prevents a TOCTOU race where a node reconnects (adding an active
|
|
// connection) between the hasActiveConnections() check and the Delete() call.
|
|
cleaned := 0
|
|
|
|
for _, nodeID := range nodesToCleanup {
|
|
b.nodes.Compute(
|
|
nodeID,
|
|
func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) {
|
|
if !loaded || conn == nil || conn.hasActiveConnections() {
|
|
return conn, xsync.CancelOp
|
|
}
|
|
|
|
// Perform all bookkeeping inside the Compute callback so
|
|
// that a concurrent AddNode (which calls LoadOrStore on
|
|
// b.nodes) cannot slip in between the delete and the
|
|
// connected/counter updates.
|
|
b.connected.Delete(nodeID)
|
|
b.totalNodes.Add(-1)
|
|
|
|
cleaned++
|
|
|
|
log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
|
|
Dur("offline_duration", cleanupThreshold).
|
|
Msg("cleaning up node that has been offline for too long")
|
|
|
|
return conn, xsync.DeleteOp
|
|
},
|
|
)
|
|
}
|
|
|
|
if cleaned > 0 {
|
|
log.Info().Int(zf.CleanedNodes, cleaned).
|
|
Msg("completed cleanup of long-offline nodes")
|
|
}
|
|
}
|
|
|
|
// IsConnected is lock-free read that checks if a node has any active connections.
|
|
func (b *Batcher) IsConnected(id types.NodeID) bool {
|
|
// First check if we have active connections for this node
|
|
if nodeConn, exists := b.nodes.Load(id); exists && nodeConn != nil {
|
|
if nodeConn.hasActiveConnections() {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check disconnected timestamp with grace period
|
|
val, ok := b.connected.Load(id)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
// nil means connected
|
|
if val == nil {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// ConnectedMap returns a lock-free map of all connected nodes.
|
|
func (b *Batcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|
ret := xsync.NewMap[types.NodeID, bool]()
|
|
|
|
// First, add all nodes with active connections
|
|
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
|
if nodeConn == nil {
|
|
return true
|
|
}
|
|
|
|
if nodeConn.hasActiveConnections() {
|
|
ret.Store(id, true)
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
// Then add all entries from the connected map
|
|
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
|
// Only add if not already added as connected above
|
|
if _, exists := ret.Load(id); !exists {
|
|
if val == nil {
|
|
// nil means connected
|
|
ret.Store(id, true)
|
|
} else {
|
|
// timestamp means disconnected
|
|
ret.Store(id, false)
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return ret
|
|
}
|
|
|
|
// markDisconnectedIfNoConns stores a disconnect timestamp in b.connected
|
|
// when the node has no remaining active connections. This prevents
|
|
// IsConnected from returning a stale true after all connections have been
|
|
// removed on an error path (e.g. AddNode failure).
|
|
func (b *Batcher) markDisconnectedIfNoConns(id types.NodeID, nc *multiChannelNodeConn) {
|
|
if !nc.hasActiveConnections() {
|
|
now := time.Now()
|
|
b.connected.Store(id, &now)
|
|
}
|
|
}
|
|
|
|
// MapResponseFromChange queues work to generate a map response and waits for the result.
|
|
// This allows synchronous map generation using the same worker pool.
|
|
func (b *Batcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) {
|
|
resultCh := make(chan workResult, 1)
|
|
|
|
// Queue the work with a result channel using the safe queueing method
|
|
b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh})
|
|
|
|
// Wait for the result
|
|
select {
|
|
case result := <-resultCh:
|
|
return result.mapResponse, result.err
|
|
case <-b.done:
|
|
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
|
|
}
|
|
}
|
|
|
|
// connectionEntry represents a single connection to a node.
|
|
type connectionEntry struct {
|
|
id string // unique connection ID
|
|
c chan<- *tailcfg.MapResponse
|
|
version tailcfg.CapabilityVersion
|
|
created time.Time
|
|
stop func()
|
|
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
|
closed atomic.Bool // Indicates if this connection has been closed
|
|
}
|
|
|
|
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
|
type multiChannelNodeConn struct {
|
|
id types.NodeID
|
|
mapper *mapper
|
|
log zerolog.Logger
|
|
|
|
mutex sync.RWMutex
|
|
connections []*connectionEntry
|
|
|
|
// pendingMu protects pending changes independently of the connection mutex.
|
|
// This avoids contention between addToBatch (which appends changes) and
|
|
// send() (which sends data to connections).
|
|
pendingMu sync.Mutex
|
|
pending []change.Change
|
|
|
|
closeOnce sync.Once
|
|
updateCount atomic.Int64
|
|
|
|
// lastSentPeers tracks which peers were last sent to this node.
|
|
// This enables computing diffs for policy changes instead of sending
|
|
// full peer lists (which clients interpret as "no change" when empty).
|
|
// Using xsync.Map for lock-free concurrent access.
|
|
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
|
|
}
|
|
|
|
// generateConnectionID returns a 16-character lowercase hex string made from
// 8 bytes of crypto/rand output, used as a connectionEntry's id.
func generateConnectionID() string {
	var raw [8]byte

	// The error is ignored on purpose: since Go 1.24, crypto/rand.Read is
	// documented to never return an error and to always fill the buffer.
	_, _ = rand.Read(raw[:])

	return hex.EncodeToString(raw[:])
}
|
|
|
|
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
|
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
|
|
return &multiChannelNodeConn{
|
|
id: id,
|
|
mapper: mapper,
|
|
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
|
log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(),
|
|
}
|
|
}
|
|
|
|
func (mc *multiChannelNodeConn) close() {
|
|
mc.closeOnce.Do(func() {
|
|
mc.mutex.Lock()
|
|
defer mc.mutex.Unlock()
|
|
|
|
for _, conn := range mc.connections {
|
|
mc.stopConnection(conn)
|
|
}
|
|
})
|
|
}
|
|
|
|
// stopConnection marks a connection as closed and tears down the owning session
|
|
// at most once, even if multiple cleanup paths race to remove it.
|
|
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
|
|
if conn.closed.CompareAndSwap(false, true) {
|
|
if conn.stop != nil {
|
|
conn.stop()
|
|
}
|
|
}
|
|
}
|
|
|
|
// removeConnectionAtIndexLocked removes the active connection at index.
|
|
// If stopConnection is true, it also stops that session.
|
|
// Caller must hold mc.mutex.
|
|
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
|
|
conn := mc.connections[i]
|
|
copy(mc.connections[i:], mc.connections[i+1:])
|
|
mc.connections[len(mc.connections)-1] = nil // release pointer for GC
|
|
mc.connections = mc.connections[:len(mc.connections)-1]
|
|
|
|
if stopConnection {
|
|
mc.stopConnection(conn)
|
|
}
|
|
|
|
return conn
|
|
}
|
|
|
|
// addConnection adds a new connection.
|
|
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
|
mc.mutex.Lock()
|
|
defer mc.mutex.Unlock()
|
|
|
|
mc.connections = append(mc.connections, entry)
|
|
mc.log.Debug().Str(zf.ConnID, entry.id).
|
|
Int("total_connections", len(mc.connections)).
|
|
Msg("connection added")
|
|
}
|
|
|
|
// removeConnectionByChannel removes a connection by matching channel pointer.
|
|
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
|
|
mc.mutex.Lock()
|
|
defer mc.mutex.Unlock()
|
|
|
|
for i, entry := range mc.connections {
|
|
if entry.c == c {
|
|
mc.removeConnectionAtIndexLocked(i, false)
|
|
mc.log.Debug().Str(zf.ConnID, entry.id).
|
|
Int("remaining_connections", len(mc.connections)).
|
|
Msg("connection removed")
|
|
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// hasActiveConnections checks if the node has any active connections.
|
|
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
|
|
mc.mutex.RLock()
|
|
defer mc.mutex.RUnlock()
|
|
|
|
return len(mc.connections) > 0
|
|
}
|
|
|
|
// getActiveConnectionCount returns the number of active connections.
|
|
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
|
mc.mutex.RLock()
|
|
defer mc.mutex.RUnlock()
|
|
|
|
return len(mc.connections)
|
|
}
|
|
|
|
// appendPending appends changes to this node's pending change list.
|
|
// Thread-safe via pendingMu; does not contend with the connection mutex.
|
|
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
|
|
mc.pendingMu.Lock()
|
|
mc.pending = append(mc.pending, changes...)
|
|
mc.pendingMu.Unlock()
|
|
}
|
|
|
|
// drainPending atomically removes and returns all pending changes.
|
|
// Returns nil if there are no pending changes.
|
|
func (mc *multiChannelNodeConn) drainPending() []change.Change {
|
|
mc.pendingMu.Lock()
|
|
p := mc.pending
|
|
mc.pending = nil
|
|
mc.pendingMu.Unlock()
|
|
|
|
return p
|
|
}
|
|
|
|
// send broadcasts data to all active connections for the node.
|
|
//
|
|
// To avoid holding the write lock during potentially slow sends (each stale
|
|
// connection can block for up to 50ms), the method snapshots connections under
|
|
// a read lock, sends without any lock held, then write-locks only to remove
|
|
// failures. New connections added between the snapshot and cleanup are safe:
|
|
// they receive a full initial map via AddNode, so missing this update causes
|
|
// no data loss.
|
|
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|
if data == nil {
|
|
return nil
|
|
}
|
|
|
|
// Snapshot connections under read lock.
|
|
mc.mutex.RLock()
|
|
|
|
if len(mc.connections) == 0 {
|
|
mc.mutex.RUnlock()
|
|
mc.log.Trace().
|
|
Msg("send: no active connections, skipping")
|
|
|
|
return nil
|
|
}
|
|
|
|
// Copy the slice so we can release the read lock before sending.
|
|
snapshot := make([]*connectionEntry, len(mc.connections))
|
|
copy(snapshot, mc.connections)
|
|
mc.mutex.RUnlock()
|
|
|
|
mc.log.Trace().
|
|
Int("total_connections", len(snapshot)).
|
|
Msg("send: broadcasting")
|
|
|
|
// Send to all connections without holding any lock.
|
|
// Stale connection timeouts (50ms each) happen here without blocking
|
|
// other goroutines that need the mutex.
|
|
var (
|
|
lastErr error
|
|
successCount int
|
|
failed []*connectionEntry
|
|
)
|
|
|
|
for _, conn := range snapshot {
|
|
err := conn.send(data)
|
|
if err != nil {
|
|
lastErr = err
|
|
|
|
failed = append(failed, conn)
|
|
|
|
mc.log.Warn().Err(err).
|
|
Str(zf.ConnID, conn.id).
|
|
Msg("send: connection failed")
|
|
} else {
|
|
successCount++
|
|
}
|
|
}
|
|
|
|
// Write-lock only to remove failed connections.
|
|
if len(failed) > 0 {
|
|
mc.mutex.Lock()
|
|
// Remove by pointer identity: only remove entries that still exist
|
|
// in the current connections slice and match a failed pointer.
|
|
// New connections added since the snapshot are not affected.
|
|
failedSet := make(map[*connectionEntry]struct{}, len(failed))
|
|
for _, f := range failed {
|
|
failedSet[f] = struct{}{}
|
|
}
|
|
|
|
clean := mc.connections[:0]
|
|
for _, conn := range mc.connections {
|
|
if _, isFailed := failedSet[conn]; !isFailed {
|
|
clean = append(clean, conn)
|
|
} else {
|
|
mc.log.Debug().
|
|
Str(zf.ConnID, conn.id).
|
|
Msg("send: removing failed connection")
|
|
// Tear down the owning session so the old serveLongPoll
|
|
// goroutine exits instead of lingering as a stale session.
|
|
mc.stopConnection(conn)
|
|
}
|
|
}
|
|
|
|
// Nil out trailing slots so removed *connectionEntry values
|
|
// are not retained by the backing array.
|
|
for i := len(clean); i < len(mc.connections); i++ {
|
|
mc.connections[i] = nil
|
|
}
|
|
|
|
mc.connections = clean
|
|
mc.mutex.Unlock()
|
|
}
|
|
|
|
mc.updateCount.Add(1)
|
|
|
|
mc.log.Trace().
|
|
Int("successful_sends", successCount).
|
|
Int("failed_connections", len(failed)).
|
|
Msg("send: broadcast complete")
|
|
|
|
// Success if at least one send succeeded
|
|
if successCount > 0 {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
|
|
}
|
|
|
|
// send sends data to a single connection entry with timeout-based stale connection detection.
|
|
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
|
if data == nil {
|
|
return nil
|
|
}
|
|
|
|
// Check if the connection has been closed to prevent send on closed channel panic.
|
|
// This can happen during shutdown when Close() is called while workers are still processing.
|
|
if entry.closed.Load() {
|
|
return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
|
|
}
|
|
|
|
// Use a short timeout to detect stale connections where the client isn't reading the channel.
|
|
// This is critical for detecting Docker containers that are forcefully terminated
|
|
// but still have channels that appear open.
|
|
//
|
|
// We use time.NewTimer + Stop instead of time.After to avoid leaking timers.
|
|
// time.After creates a timer that lives in the runtime's timer heap until it fires,
|
|
// even when the send succeeds immediately. On the hot path (1000+ nodes per tick),
|
|
// this leaks thousands of timers per second.
|
|
timer := time.NewTimer(50 * time.Millisecond) //nolint:mnd
|
|
defer timer.Stop()
|
|
|
|
select {
|
|
case entry.c <- data:
|
|
// Update last used timestamp on successful send
|
|
entry.lastUsed.Store(time.Now().Unix())
|
|
return nil
|
|
case <-timer.C:
|
|
// Connection is likely stale - client isn't reading from channel
|
|
// This catches the case where Docker containers are killed but channels remain open
|
|
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
|
|
}
|
|
}
|
|
|
|
// nodeID returns the node ID.
|
|
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
|
|
return mc.id
|
|
}
|
|
|
|
// version returns the capability version from the first active connection.
|
|
// All connections for a node should have the same version in practice.
|
|
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
|
|
mc.mutex.RLock()
|
|
defer mc.mutex.RUnlock()
|
|
|
|
if len(mc.connections) == 0 {
|
|
return 0
|
|
}
|
|
|
|
return mc.connections[0].version
|
|
}
|
|
|
|
// updateSentPeers updates the tracked peer state based on a sent MapResponse.
|
|
// This must be called after successfully sending a response to keep track of
|
|
// what the client knows about, enabling accurate diffs for future updates.
|
|
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
// Full peer list replaces tracked state entirely
|
|
if resp.Peers != nil {
|
|
mc.lastSentPeers.Clear()
|
|
|
|
for _, peer := range resp.Peers {
|
|
mc.lastSentPeers.Store(peer.ID, struct{}{})
|
|
}
|
|
}
|
|
|
|
// Incremental additions
|
|
for _, peer := range resp.PeersChanged {
|
|
mc.lastSentPeers.Store(peer.ID, struct{}{})
|
|
}
|
|
|
|
// Incremental removals
|
|
for _, id := range resp.PeersRemoved {
|
|
mc.lastSentPeers.Delete(id)
|
|
}
|
|
}
|
|
|
|
// computePeerDiff compares the current peer list against what was last sent
|
|
// and returns the peers that were removed (in lastSentPeers but not in current).
|
|
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
|
|
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
|
|
for _, id := range currentPeers {
|
|
currentSet[id] = struct{}{}
|
|
}
|
|
|
|
var removed []tailcfg.NodeID
|
|
|
|
// Find removed: in lastSentPeers but not in current
|
|
mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
|
|
if _, exists := currentSet[id]; !exists {
|
|
removed = append(removed, id)
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return removed
|
|
}
|
|
|
|
// change applies a change to all active connections for the node.
|
|
func (mc *multiChannelNodeConn) change(r change.Change) error {
|
|
return handleNodeChange(mc, mc.mapper, r)
|
|
}
|
|
|
|
// DebugNodeInfo is the per-node connection summary produced by Batcher.Debug.
type DebugNodeInfo struct {
	// Connected reports whether the node is considered connected right now,
	// with no grace period applied.
	Connected bool `json:"connected"`

	// ActiveConnections is the number of connections currently registered
	// for the node.
	ActiveConnections int `json:"active_connections"`
}
|
|
|
|
// Debug returns a pre-baked map of node debug information for the debug interface.
|
|
func (b *Batcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|
result := make(map[types.NodeID]DebugNodeInfo)
|
|
|
|
// Get all nodes with their connection status using immediate connection logic
|
|
// (no grace period) for debug purposes
|
|
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
|
if nodeConn == nil {
|
|
return true
|
|
}
|
|
|
|
nodeConn.mutex.RLock()
|
|
activeConnCount := len(nodeConn.connections)
|
|
nodeConn.mutex.RUnlock()
|
|
|
|
// Use immediate connection status: if active connections exist, node is connected
|
|
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
|
|
connected := false
|
|
if activeConnCount > 0 {
|
|
connected = true
|
|
} else {
|
|
// Check connected map for immediate status
|
|
if val, ok := b.connected.Load(id); ok && val == nil {
|
|
connected = true
|
|
}
|
|
}
|
|
|
|
result[id] = DebugNodeInfo{
|
|
Connected: connected,
|
|
ActiveConnections: activeConnCount,
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
// Add all entries from the connected map to capture both connected and disconnected nodes
|
|
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
|
// Only add if not already processed above
|
|
if _, exists := result[id]; !exists {
|
|
// Use immediate connection status for debug (no grace period)
|
|
connected := (val == nil) // nil means connected, timestamp means disconnected
|
|
result[id] = DebugNodeInfo{
|
|
Connected: connected,
|
|
ActiveConnections: 0,
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
return result
|
|
}
|
|
|
|
func (b *Batcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
|
return b.mapper.debugMapResponses()
|
|
}
|
|
|
|
// WorkErrors returns the count of work errors encountered.
|
|
// This is primarily useful for testing and debugging.
|
|
func (b *Batcher) WorkErrors() int64 {
|
|
return b.workErrors.Load()
|
|
}
|