mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-19 07:54:17 +01:00
mapper: remove Batcher interface, rename to Batcher struct
Remove the Batcher interface since there is only one implementation. Rename LockFreeBatcher to Batcher and merge batcher_lockfree.go into batcher.go. Drop type assertions in debug.go now that mapBatcher is a concrete *mapper.Batcher pointer.
This commit is contained in:
@@ -101,7 +101,7 @@ type Headscale struct {
|
||||
// Things that generate changes
|
||||
extraRecordMan *dns.ExtraRecordsMan
|
||||
authProvider AuthProvider
|
||||
mapBatcher mapper.Batcher
|
||||
mapBatcher *mapper.Batcher
|
||||
|
||||
clientStreamsOpen sync.WaitGroup
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/arl/statsviz"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"tailscale.com/tsweb"
|
||||
@@ -329,38 +328,18 @@ func (h *Headscale) debugBatcher() string {
|
||||
|
||||
var nodes []nodeStatus
|
||||
|
||||
// Try to get detailed debug info if we have a LockFreeBatcher
|
||||
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
|
||||
debugInfo := batcher.Debug()
|
||||
for nodeID, info := range debugInfo {
|
||||
nodes = append(nodes, nodeStatus{
|
||||
id: nodeID,
|
||||
connected: info.Connected,
|
||||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to basic connection info
|
||||
connectedMap := h.mapBatcher.ConnectedMap()
|
||||
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
|
||||
nodes = append(nodes, nodeStatus{
|
||||
id: nodeID,
|
||||
connected: connected,
|
||||
activeConnections: 0,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
||||
return true
|
||||
debugInfo := h.mapBatcher.Debug()
|
||||
for nodeID, info := range debugInfo {
|
||||
nodes = append(nodes, nodeStatus{
|
||||
id: nodeID,
|
||||
connected: info.Connected,
|
||||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by node ID
|
||||
@@ -410,28 +389,13 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
||||
TotalNodes: 0,
|
||||
}
|
||||
|
||||
// Try to get detailed debug info if we have a LockFreeBatcher
|
||||
if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok {
|
||||
debugInfo := batcher.Debug()
|
||||
for nodeID, debugData := range debugInfo {
|
||||
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
|
||||
Connected: debugData.Connected,
|
||||
ActiveConnections: debugData.ActiveConnections,
|
||||
}
|
||||
info.TotalNodes++
|
||||
debugInfo := h.mapBatcher.Debug()
|
||||
for nodeID, debugData := range debugInfo {
|
||||
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
|
||||
Connected: debugData.Connected,
|
||||
ActiveConnections: debugData.ActiveConnections,
|
||||
}
|
||||
} else {
|
||||
// Fallback to basic connection info
|
||||
connectedMap := h.mapBatcher.ConnectedMap()
|
||||
connectedMap.Range(func(nodeID types.NodeID, connected bool) bool {
|
||||
info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
info.TotalNodes++
|
||||
|
||||
return true
|
||||
})
|
||||
info.TotalNodes++
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -148,8 +148,8 @@ func BenchmarkUpdateSentPeers(b *testing.B) {
|
||||
|
||||
// benchBatcher creates a lightweight batcher for benchmarks. Unlike the test
|
||||
// helper, it doesn't register cleanup and suppresses logging.
|
||||
func benchBatcher(nodeCount, bufferSize int) (*LockFreeBatcher, map[types.NodeID]chan *tailcfg.MapResponse) {
|
||||
b := &LockFreeBatcher{
|
||||
func benchBatcher(nodeCount, bufferSize int) (*Batcher, map[types.NodeID]chan *tailcfg.MapResponse) {
|
||||
b := &Batcher{
|
||||
tick: time.NewTicker(1 * time.Hour), // never fires during bench
|
||||
workers: 4,
|
||||
workCh: make(chan work, 4*200),
|
||||
|
||||
@@ -35,7 +35,7 @@ import (
|
||||
// lightweightBatcher provides a batcher with pre-populated nodes for testing
|
||||
// the batching, channel, and concurrency mechanics without database overhead.
|
||||
type lightweightBatcher struct {
|
||||
b *LockFreeBatcher
|
||||
b *Batcher
|
||||
channels map[types.NodeID]chan *tailcfg.MapResponse
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ type lightweightBatcher struct {
|
||||
func setupLightweightBatcher(t *testing.T, nodeCount, bufferSize int) *lightweightBatcher {
|
||||
t.Helper()
|
||||
|
||||
b := &LockFreeBatcher{
|
||||
b := &Batcher{
|
||||
tick: time.NewTicker(10 * time.Millisecond),
|
||||
workers: 4,
|
||||
workCh: make(chan work, 4*200),
|
||||
@@ -86,7 +86,7 @@ func (lb *lightweightBatcher) cleanup() {
|
||||
}
|
||||
|
||||
// countTotalPending counts total pending change entries across all nodes.
|
||||
func countTotalPending(b *LockFreeBatcher) int {
|
||||
func countTotalPending(b *Batcher) int {
|
||||
count := 0
|
||||
|
||||
b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
@@ -101,7 +101,7 @@ func countTotalPending(b *LockFreeBatcher) int {
|
||||
}
|
||||
|
||||
// countNodesPending counts how many nodes have pending changes.
|
||||
func countNodesPending(b *LockFreeBatcher) int {
|
||||
func countNodesPending(b *Batcher) int {
|
||||
count := 0
|
||||
|
||||
b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
@@ -120,7 +120,7 @@ func countNodesPending(b *LockFreeBatcher) int {
|
||||
}
|
||||
|
||||
// getPendingForNode returns pending changes for a specific node.
|
||||
func getPendingForNode(b *LockFreeBatcher, id types.NodeID) []change.Change {
|
||||
func getPendingForNode(b *Batcher, id types.NodeID) []change.Change {
|
||||
nc, ok := b.nodes.Load(id)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -1167,7 +1167,7 @@ func TestScale1000_MultiChannelBroadcast(t *testing.T) {
|
||||
)
|
||||
|
||||
// Create nodes with varying connection counts
|
||||
b := &LockFreeBatcher{
|
||||
b := &Batcher{
|
||||
tick: time.NewTicker(10 * time.Millisecond),
|
||||
workers: 4,
|
||||
workCh: make(chan work, 4*200),
|
||||
@@ -1569,7 +1569,7 @@ func TestScale1000_WorkChannelSaturation(t *testing.T) {
|
||||
defer zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
|
||||
// Create batcher with SMALL work channel to force saturation
|
||||
b := &LockFreeBatcher{
|
||||
b := &Batcher{
|
||||
tick: time.NewTicker(10 * time.Millisecond),
|
||||
workers: 2,
|
||||
workCh: make(chan work, 10), // Very small - will saturate
|
||||
|
||||
@@ -1,988 +0,0 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// LockFreeBatcher errors.
var (
	errConnectionClosed      = errors.New("connection channel already closed")
	ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
	ErrBatcherShuttingDown   = errors.New("batcher shutting down")
	ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
tick *time.Ticker
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
nodes *xsync.Map[types.NodeID, *multiChannelNodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
workCh chan work
|
||||
workChOnce sync.Once // Ensures workCh is only closed once
|
||||
done chan struct{}
|
||||
doneOnce sync.Once // Ensures done is only closed once
|
||||
|
||||
started atomic.Bool // Ensures Start() is only called once
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
workQueuedCount atomic.Int64
|
||||
workProcessed atomic.Int64
|
||||
workErrors atomic.Int64
|
||||
}
|
||||
|
||||
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// The stop function tears down the owning session if this connection is later declared stale.
|
||||
func (b *LockFreeBatcher) AddNode(
|
||||
id types.NodeID,
|
||||
c chan<- *tailcfg.MapResponse,
|
||||
version tailcfg.CapabilityVersion,
|
||||
stop func(),
|
||||
) error {
|
||||
addNodeStart := time.Now()
|
||||
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
|
||||
|
||||
// Generate connection ID
|
||||
connID := generateConnectionID()
|
||||
|
||||
// Create new connection entry
|
||||
now := time.Now()
|
||||
newEntry := &connectionEntry{
|
||||
id: connID,
|
||||
c: c,
|
||||
version: version,
|
||||
created: now,
|
||||
stop: stop,
|
||||
}
|
||||
// Initialize last used timestamp
|
||||
newEntry.lastUsed.Store(now.Unix())
|
||||
|
||||
// Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection
|
||||
nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper))
|
||||
|
||||
if !loaded {
|
||||
b.totalNodes.Add(1)
|
||||
}
|
||||
|
||||
// Add connection to the list (lock-free)
|
||||
nodeConn.addConnection(newEntry)
|
||||
|
||||
// Use the worker pool for controlled concurrency instead of direct generation
|
||||
initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
|
||||
if err != nil {
|
||||
nlog.Error().Err(err).Msg("initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("generating initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Use a blocking send with timeout for initial map since the channel should be ready
|
||||
// and we want to avoid the race condition where the receiver isn't ready yet
|
||||
select {
|
||||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second): //nolint:mnd
|
||||
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
|
||||
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
||||
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
// Node will automatically receive updates through the normal flow
|
||||
// The initial full map already contains all current state
|
||||
|
||||
nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("node connection established in batcher")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||
// It validates the connection channel matches one of the current connections, closes that specific connection,
|
||||
// and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
|
||||
// Reports if the node still has active connections after removal.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
|
||||
nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
|
||||
|
||||
nodeConn, exists := b.nodes.Load(id)
|
||||
if !exists || nodeConn == nil {
|
||||
nlog.Debug().Caller().Msg("removeNode called for non-existent node")
|
||||
return false
|
||||
}
|
||||
|
||||
// Remove specific connection
|
||||
removed := nodeConn.removeConnectionByChannel(c)
|
||||
if !removed {
|
||||
nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
|
||||
}
|
||||
|
||||
// Check if node has any remaining active connections
|
||||
if nodeConn.hasActiveConnections() {
|
||||
nlog.Debug().Caller().
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("node connection removed but keeping online, other connections remain")
|
||||
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// No active connections - keep the node entry alive for rapid reconnections
|
||||
// The node will get a fresh full map when it reconnects
|
||||
nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
|
||||
b.connected.Store(id, new(time.Now()))
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
func (b *LockFreeBatcher) AddWork(r ...change.Change) {
|
||||
b.addWork(r...)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
if !b.started.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
b.done = make(chan struct{})
|
||||
|
||||
go b.doWork()
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
// Signal shutdown to all goroutines, only once
|
||||
b.doneOnce.Do(func() {
|
||||
if b.done != nil {
|
||||
close(b.done)
|
||||
}
|
||||
})
|
||||
|
||||
// Only close workCh once using sync.Once to prevent races
|
||||
b.workChOnce.Do(func() {
|
||||
close(b.workCh)
|
||||
})
|
||||
|
||||
// Close the underlying channels supplying the data to the clients.
|
||||
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
|
||||
if conn == nil {
|
||||
return true
|
||||
}
|
||||
conn.close()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
// Create a cleanup ticker for removing truly disconnected nodes
|
||||
cleanupTicker := time.NewTicker(5 * time.Minute)
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-cleanupTicker.C:
|
||||
// Clean up nodes that have been offline for too long
|
||||
b.cleanupOfflineNodes()
|
||||
case <-b.done:
|
||||
log.Info().Msg("batcher done channel closed, stopping to feed workers")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
wlog := log.With().Int(zf.WorkerID, workerID).Logger()
|
||||
|
||||
for {
|
||||
select {
|
||||
case w, ok := <-b.workCh:
|
||||
if !ok {
|
||||
wlog.Debug().Msg("worker channel closing, shutting down")
|
||||
return
|
||||
}
|
||||
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
// where there is a blocking function waiting for the map that
|
||||
// is being generated.
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
|
||||
var err error
|
||||
|
||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
wlog.Error().Err(result.err).
|
||||
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||
Str(zf.Reason, w.c.Reason).
|
||||
Msg("failed to generate map response for synchronous work")
|
||||
} else if result.mapResponse != nil {
|
||||
// Update peer tracking for synchronous responses too
|
||||
nc.updateSentPeers(result.mapResponse)
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
wlog.Error().Err(result.err).
|
||||
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||
Msg("node not found for synchronous work")
|
||||
}
|
||||
|
||||
// Send result
|
||||
select {
|
||||
case w.resultCh <- result:
|
||||
case <-b.done:
|
||||
return
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// If resultCh is nil, this is an asynchronous work request
|
||||
// that should be processed and sent to the node instead of
|
||||
// returned to the caller.
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists && nc != nil {
|
||||
// Apply change to node - this will handle offline nodes gracefully
|
||||
// and queue work for when they reconnect
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
wlog.Error().Err(err).
|
||||
Uint64(zf.NodeID, w.nodeID.Uint64()).
|
||||
Str(zf.Reason, w.c.Reason).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
}
|
||||
case <-b.done:
|
||||
wlog.Debug().Msg("batcher shutting down, exiting worker")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) addWork(r ...change.Change) {
|
||||
b.addToBatch(r...)
|
||||
}
|
||||
|
||||
// queueWork safely queues work.
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
select {
|
||||
case b.workCh <- w:
|
||||
// Successfully queued
|
||||
case <-b.done:
|
||||
// Batcher is shutting down
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch adds changes to the pending batch.
|
||||
func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
|
||||
// Clean up any nodes being permanently removed from the system.
|
||||
//
|
||||
// This handles the case where a node is deleted from state but the batcher
|
||||
// still has it registered. By cleaning up here, we prevent "node not found"
|
||||
// errors when workers try to generate map responses for deleted nodes.
|
||||
//
|
||||
// Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually
|
||||
// deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy
|
||||
// changes that affect peer visibility do NOT use this field - they set
|
||||
// RequiresRuntimePeerComputation=true and compute removed peers at runtime,
|
||||
// putting them in tailcfg.MapResponse.PeersRemoved (a different struct).
|
||||
// Therefore, this cleanup only removes nodes that are truly being deleted,
|
||||
// not nodes that are still connected but have lost visibility of certain peers.
|
||||
//
|
||||
// See: https://github.com/juanfont/headscale/issues/2924
|
||||
for _, ch := range changes {
|
||||
for _, removedID := range ch.PeersRemoved {
|
||||
if _, existed := b.nodes.LoadAndDelete(removedID); existed {
|
||||
b.totalNodes.Add(-1)
|
||||
log.Debug().
|
||||
Uint64(zf.NodeID, removedID.Uint64()).
|
||||
Msg("removed deleted node from batcher")
|
||||
}
|
||||
|
||||
b.connected.Delete(removedID)
|
||||
}
|
||||
}
|
||||
|
||||
// Short circuit if any of the changes is a full update, which
|
||||
// means we can skip sending individual changes.
|
||||
if change.HasFull(changes) {
|
||||
b.nodes.Range(func(_ types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
if nc == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
nc.pendingMu.Lock()
|
||||
nc.pending = []change.Change{change.FullUpdate()}
|
||||
nc.pendingMu.Unlock()
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
broadcast, targeted := change.SplitTargetedAndBroadcast(changes)
|
||||
|
||||
// Handle targeted changes - send only to the specific node
|
||||
for _, ch := range targeted {
|
||||
if nc, ok := b.nodes.Load(ch.TargetNode); ok && nc != nil {
|
||||
nc.appendPending(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle broadcast changes - send to all nodes, filtering as needed
|
||||
if len(broadcast) > 0 {
|
||||
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
if nc == nil {
|
||||
return true
|
||||
}
|
||||
filtered := change.FilterForNode(nodeID, broadcast)
|
||||
|
||||
if len(filtered) > 0 {
|
||||
nc.appendPending(filtered...)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// processBatchedChanges processes all pending batched changes.
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.nodes.Range(func(nodeID types.NodeID, nc *multiChannelNodeConn) bool {
|
||||
if nc == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
pending := nc.drainPending()
|
||||
if len(pending) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Send all batched changes for this node
|
||||
for _, ch := range pending {
|
||||
b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
|
||||
// Uses Compute() for atomic check-and-delete to prevent TOCTOU races where a node
|
||||
// reconnects between the hasActiveConnections() check and the Delete() call.
|
||||
// TODO(kradalby): reevaluate if we want to keep this.
|
||||
func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
cleanupThreshold := 15 * time.Minute
|
||||
now := time.Now()
|
||||
|
||||
var nodesToCleanup []types.NodeID
|
||||
|
||||
// Find nodes that have been offline for too long
|
||||
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
|
||||
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
|
||||
nodesToCleanup = append(nodesToCleanup, nodeID)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up the identified nodes using Compute() for atomic check-and-delete.
|
||||
// This prevents a TOCTOU race where a node reconnects (adding an active
|
||||
// connection) between the hasActiveConnections() check and the Delete() call.
|
||||
cleaned := 0
|
||||
for _, nodeID := range nodesToCleanup {
|
||||
deleted := false
|
||||
|
||||
b.nodes.Compute(
|
||||
nodeID,
|
||||
func(conn *multiChannelNodeConn, loaded bool) (*multiChannelNodeConn, xsync.ComputeOp) {
|
||||
if !loaded || conn == nil || conn.hasActiveConnections() {
|
||||
return conn, xsync.CancelOp
|
||||
}
|
||||
|
||||
deleted = true
|
||||
|
||||
return conn, xsync.DeleteOp
|
||||
},
|
||||
)
|
||||
|
||||
if deleted {
|
||||
log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
|
||||
Dur("offline_duration", cleanupThreshold).
|
||||
Msg("cleaning up node that has been offline for too long")
|
||||
|
||||
b.connected.Delete(nodeID)
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
cleaned++
|
||||
}
|
||||
}
|
||||
|
||||
if cleaned > 0 {
|
||||
log.Info().Int(zf.CleanedNodes, cleaned).
|
||||
Msg("completed cleanup of long-offline nodes")
|
||||
}
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read that checks if a node has any active connections.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
// First check if we have active connections for this node
|
||||
if nodeConn, exists := b.nodes.Load(id); exists && nodeConn != nil {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check disconnected timestamp with grace period
|
||||
val, ok := b.connected.Load(id)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// nil means connected
|
||||
if val == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ConnectedMap returns a lock-free map of all connected nodes.
|
||||
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret := xsync.NewMap[types.NodeID, bool]()
|
||||
|
||||
// First, add all nodes with active connections
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
if nodeConn == nil {
|
||||
return true
|
||||
}
|
||||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Then add all entries from the connected map
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// Only add if not already added as connected above
|
||||
if _, exists := ret.Load(id); !exists {
|
||||
if val == nil {
|
||||
// nil means connected
|
||||
ret.Store(id, true)
|
||||
} else {
|
||||
// timestamp means disconnected
|
||||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// MapResponseFromChange queues work to generate a map response and waits for the result.
|
||||
// This allows synchronous map generation using the same worker pool.
|
||||
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) {
|
||||
resultCh := make(chan workResult, 1)
|
||||
|
||||
// Queue the work with a result channel using the safe queueing method
|
||||
b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh})
|
||||
|
||||
// Wait for the result
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.done:
|
||||
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
|
||||
}
|
||||
}
|
||||
|
||||
// connectionEntry represents a single connection to a node.
|
||||
type connectionEntry struct {
|
||||
id string // unique connection ID
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
created time.Time
|
||||
stop func()
|
||||
lastUsed atomic.Int64 // Unix timestamp of last successful send
|
||||
closed atomic.Bool // Indicates if this connection has been closed
|
||||
}
|
||||
|
||||
// multiChannelNodeConn manages multiple concurrent connections for a single node.
|
||||
type multiChannelNodeConn struct {
|
||||
id types.NodeID
|
||||
mapper *mapper
|
||||
log zerolog.Logger
|
||||
|
||||
mutex sync.RWMutex
|
||||
connections []*connectionEntry
|
||||
|
||||
// pendingMu protects pending changes independently of the connection mutex.
|
||||
// This avoids contention between addToBatch (which appends changes) and
|
||||
// send() (which sends data to connections).
|
||||
pendingMu sync.Mutex
|
||||
pending []change.Change
|
||||
|
||||
closeOnce sync.Once
|
||||
updateCount atomic.Int64
|
||||
|
||||
// lastSentPeers tracks which peers were last sent to this node.
|
||||
// This enables computing diffs for policy changes instead of sending
|
||||
// full peer lists (which clients interpret as "no change" when empty).
|
||||
// Using xsync.Map for lock-free concurrent access.
|
||||
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
|
||||
}
|
||||
|
||||
// generateConnectionID generates a unique connection identifier: eight random
// bytes hex-encoded into a 16-character string.
func generateConnectionID() string {
	buf := make([]byte, 8)
	// crypto/rand.Read is documented never to fail on supported platforms.
	_, _ = rand.Read(buf)

	return hex.EncodeToString(buf)
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
|
||||
return &multiChannelNodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
|
||||
log: log.With().Uint64(zf.NodeID, id.Uint64()).Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *multiChannelNodeConn) close() {
|
||||
mc.closeOnce.Do(func() {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for _, conn := range mc.connections {
|
||||
mc.stopConnection(conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// stopConnection marks a connection as closed and tears down the owning session
|
||||
// at most once, even if multiple cleanup paths race to remove it.
|
||||
func (mc *multiChannelNodeConn) stopConnection(conn *connectionEntry) {
|
||||
if conn.closed.CompareAndSwap(false, true) {
|
||||
if conn.stop != nil {
|
||||
conn.stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeConnectionAtIndexLocked removes the active connection at index.
|
||||
// If stopConnection is true, it also stops that session.
|
||||
// Caller must hold mc.mutex.
|
||||
func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, stopConnection bool) *connectionEntry {
|
||||
conn := mc.connections[i]
|
||||
mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
|
||||
|
||||
if stopConnection {
|
||||
mc.stopConnection(conn)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
|
||||
Int("total_connections", len(mc.connections)).
|
||||
Dur("mutex_wait_time", mutexWaitDur).
|
||||
Msg("successfully added connection after mutex wait")
|
||||
}
|
||||
|
||||
// removeConnectionByChannel removes a connection by matching channel pointer.
|
||||
func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
mc.mutex.Lock()
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
for i, entry := range mc.connections {
|
||||
if entry.c == c {
|
||||
mc.removeConnectionAtIndexLocked(i, false)
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("successfully removed connection")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// hasActiveConnections checks if the node has any active connections.
|
||||
func (mc *multiChannelNodeConn) hasActiveConnections() bool {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections) > 0
|
||||
}
|
||||
|
||||
// getActiveConnectionCount returns the number of active connections.
|
||||
func (mc *multiChannelNodeConn) getActiveConnectionCount() int {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
return len(mc.connections)
|
||||
}
|
||||
|
||||
// appendPending appends changes to this node's pending change list.
|
||||
// Thread-safe via pendingMu; does not contend with the connection mutex.
|
||||
func (mc *multiChannelNodeConn) appendPending(changes ...change.Change) {
|
||||
mc.pendingMu.Lock()
|
||||
mc.pending = append(mc.pending, changes...)
|
||||
mc.pendingMu.Unlock()
|
||||
}
|
||||
|
||||
// drainPending atomically removes and returns all pending changes.
|
||||
// Returns nil if there are no pending changes.
|
||||
func (mc *multiChannelNodeConn) drainPending() []change.Change {
|
||||
mc.pendingMu.Lock()
|
||||
p := mc.pending
|
||||
mc.pending = nil
|
||||
mc.pendingMu.Unlock()
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// send broadcasts data to all active connections for the node.
|
||||
// send broadcasts data to all connections using a two-phase approach to avoid
|
||||
// holding the write lock during potentially slow sends. Each stale connection
|
||||
// can block for up to 50ms (see connectionEntry.send), so N stale connections
|
||||
// under a single write lock would block for N*50ms. The two-phase approach:
|
||||
//
|
||||
// 1. RLock: snapshot the connections slice (cheap pointer copy)
|
||||
// 2. Unlock: send to all connections without any lock held (timeouts happen here)
|
||||
// 3. Lock: remove only the failed connections by pointer identity
|
||||
//
|
||||
// New connections added during step 2 are safe: they receive a full initial
|
||||
// map via AddNode, so missing this particular update causes no data loss.
|
||||
func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Phase 1: snapshot connections under read lock.
|
||||
mc.mutex.RLock()
|
||||
if len(mc.connections) == 0 {
|
||||
mc.mutex.RUnlock()
|
||||
mc.log.Debug().Caller().
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy the slice header (shares underlying array, but that's fine since
|
||||
// we only read; writes go through the write lock in phase 3).
|
||||
snapshot := make([]*connectionEntry, len(mc.connections))
|
||||
copy(snapshot, mc.connections)
|
||||
mc.mutex.RUnlock()
|
||||
|
||||
mc.log.Debug().Caller().
|
||||
Int("total_connections", len(snapshot)).
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
// Phase 2: send to all connections without holding any lock.
|
||||
// Stale connection timeouts (50ms each) happen here without blocking
|
||||
// other goroutines that need the mutex.
|
||||
var (
|
||||
lastErr error
|
||||
successCount int
|
||||
failed []*connectionEntry
|
||||
)
|
||||
|
||||
for _, conn := range snapshot {
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
err := conn.send(data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
|
||||
failed = append(failed, conn)
|
||||
|
||||
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: successfully sent to connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: write-lock only to remove failed connections.
|
||||
if len(failed) > 0 {
|
||||
mc.mutex.Lock()
|
||||
// Remove by pointer identity: only remove entries that still exist
|
||||
// in the current connections slice and match a failed pointer.
|
||||
// New connections added between phase 1 and 3 are not affected.
|
||||
failedSet := make(map[*connectionEntry]struct{}, len(failed))
|
||||
for _, f := range failed {
|
||||
failedSet[f] = struct{}{}
|
||||
}
|
||||
|
||||
clean := mc.connections[:0]
|
||||
for _, conn := range mc.connections {
|
||||
if _, isFailed := failedSet[conn]; !isFailed {
|
||||
clean = append(clean, conn)
|
||||
} else {
|
||||
mc.log.Debug().Caller().
|
||||
Str(zf.ConnID, conn.id).
|
||||
Msg("send: removing failed connection")
|
||||
// Tear down the owning session so the old serveLongPoll
|
||||
// goroutine exits instead of lingering as a stale session.
|
||||
mc.stopConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
mc.connections = clean
|
||||
mc.mutex.Unlock()
|
||||
}
|
||||
|
||||
mc.updateCount.Add(1)
|
||||
|
||||
mc.log.Debug().
|
||||
Int("successful_sends", successCount).
|
||||
Int("failed_connections", len(failed)).
|
||||
Msg("send: completed broadcast")
|
||||
|
||||
// Success if at least one send succeeded
|
||||
if successCount > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr)
|
||||
}
|
||||
|
||||
// send sends data to a single connection entry with timeout-based stale connection detection.
|
||||
func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the connection has been closed to prevent send on closed channel panic.
|
||||
// This can happen during shutdown when Close() is called while workers are still processing.
|
||||
if entry.closed.Load() {
|
||||
return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
|
||||
}
|
||||
|
||||
// Use a short timeout to detect stale connections where the client isn't reading the channel.
|
||||
// This is critical for detecting Docker containers that are forcefully terminated
|
||||
// but still have channels that appear open.
|
||||
select {
|
||||
case entry.c <- data:
|
||||
// Update last used timestamp on successful send
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
return nil
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// nodeID returns the node ID.
|
||||
func (mc *multiChannelNodeConn) nodeID() types.NodeID {
|
||||
return mc.id
|
||||
}
|
||||
|
||||
// version returns the capability version from the first active connection.
|
||||
// All connections for a node should have the same version in practice.
|
||||
func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
|
||||
mc.mutex.RLock()
|
||||
defer mc.mutex.RUnlock()
|
||||
|
||||
if len(mc.connections) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return mc.connections[0].version
|
||||
}
|
||||
|
||||
// updateSentPeers updates the tracked peer state based on a sent MapResponse.
|
||||
// This must be called after successfully sending a response to keep track of
|
||||
// what the client knows about, enabling accurate diffs for future updates.
|
||||
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Full peer list replaces tracked state entirely
|
||||
if resp.Peers != nil {
|
||||
mc.lastSentPeers.Clear()
|
||||
|
||||
for _, peer := range resp.Peers {
|
||||
mc.lastSentPeers.Store(peer.ID, struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
// Incremental additions
|
||||
for _, peer := range resp.PeersChanged {
|
||||
mc.lastSentPeers.Store(peer.ID, struct{}{})
|
||||
}
|
||||
|
||||
// Incremental removals
|
||||
for _, id := range resp.PeersRemoved {
|
||||
mc.lastSentPeers.Delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
// computePeerDiff compares the current peer list against what was last sent
|
||||
// and returns the peers that were removed (in lastSentPeers but not in current).
|
||||
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
|
||||
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
|
||||
for _, id := range currentPeers {
|
||||
currentSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
var removed []tailcfg.NodeID
|
||||
|
||||
// Find removed: in lastSentPeers but not in current
|
||||
mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
|
||||
if _, exists := currentSet[id]; !exists {
|
||||
removed = append(removed, id)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return removed
|
||||
}
|
||||
|
||||
// change applies a change to all active connections for the node.
|
||||
func (mc *multiChannelNodeConn) change(r change.Change) error {
|
||||
return handleNodeChange(mc, mc.mapper, r)
|
||||
}
|
||||
|
||||
// DebugNodeInfo contains debug information about a node's connections.
type DebugNodeInfo struct {
	// Connected is the node's immediate connection status as seen by Debug
	// (no grace period applied).
	Connected bool `json:"connected"`
	// ActiveConnections is the number of connections currently registered
	// for the node.
	ActiveConnections int `json:"active_connections"`
}
|
||||
|
||||
// Debug returns a pre-baked map of node debug information for the debug interface.
|
||||
func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
result := make(map[types.NodeID]DebugNodeInfo)
|
||||
|
||||
// Get all nodes with their connection status using immediate connection logic
|
||||
// (no grace period) for debug purposes
|
||||
b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool {
|
||||
if nodeConn == nil {
|
||||
return true
|
||||
}
|
||||
nodeConn.mutex.RLock()
|
||||
activeConnCount := len(nodeConn.connections)
|
||||
nodeConn.mutex.RUnlock()
|
||||
|
||||
// Use immediate connection status: if active connections exist, node is connected
|
||||
// If not, check the connected map for nil (connected) vs timestamp (disconnected)
|
||||
connected := false
|
||||
if activeConnCount > 0 {
|
||||
connected = true
|
||||
} else {
|
||||
// Check connected map for immediate status
|
||||
if val, ok := b.connected.Load(id); ok && val == nil {
|
||||
connected = true
|
||||
}
|
||||
}
|
||||
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Add all entries from the connected map to capture both connected and disconnected nodes
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// Only add if not already processed above
|
||||
if _, exists := result[id]; !exists {
|
||||
// Use immediate connection status for debug (no grace period)
|
||||
connected := (val == nil) // nil means connected, timestamp means disconnected
|
||||
result[id] = DebugNodeInfo{
|
||||
Connected: connected,
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
return b.mapper.debugMapResponses()
|
||||
}
|
||||
|
||||
// WorkErrors returns the count of work errors encountered.
|
||||
// This is primarily useful for testing and debugging.
|
||||
func (b *LockFreeBatcher) WorkErrors() int64 {
|
||||
return b.workErrors.Load()
|
||||
}
|
||||
@@ -25,6 +25,8 @@ import (
|
||||
|
||||
var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher")
|
||||
|
||||
type batcherFunc func(cfg *types.Config, state *state.State) *Batcher
|
||||
|
||||
// batcherTestCase defines a batcher function with a descriptive name for testing.
|
||||
type batcherTestCase struct {
|
||||
name string
|
||||
@@ -34,7 +36,7 @@ type batcherTestCase struct {
|
||||
// testBatcherWrapper wraps a real batcher to add online/offline notifications
|
||||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
*Batcher
|
||||
|
||||
state *state.State
|
||||
}
|
||||
@@ -85,13 +87,13 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
func wrapBatcherForTest(b Batcher, state *state.State) Batcher {
|
||||
func wrapBatcherForTest(b *Batcher, state *state.State) *testBatcherWrapper {
|
||||
return &testBatcherWrapper{Batcher: b, state: state}
|
||||
}
|
||||
|
||||
// allBatcherFunctions contains all batcher implementations to test.
|
||||
var allBatcherFunctions = []batcherTestCase{
|
||||
{"LockFree", NewBatcherAndMapper},
|
||||
{"Default", NewBatcherAndMapper},
|
||||
}
|
||||
|
||||
// emptyCache creates an empty registration cache for testing.
|
||||
@@ -134,7 +136,7 @@ type TestData struct {
|
||||
Nodes []node
|
||||
State *state.State
|
||||
Config *types.Config
|
||||
Batcher Batcher
|
||||
Batcher *testBatcherWrapper
|
||||
}
|
||||
|
||||
type node struct {
|
||||
@@ -2354,46 +2356,35 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
// Check debug status after reconnection.
|
||||
t.Logf("Checking debug status...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
disconnectedCount := 0
|
||||
debugInfo := batcher.Debug()
|
||||
disconnectedCount := 0
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
|
||||
// Check if the debug info shows the node as connected
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if !info.Connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
|
||||
// Also check IsConnected method
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Logf("Node %d IsConnected() returns false", i)
|
||||
}
|
||||
}
|
||||
|
||||
if disconnectedCount > 0 {
|
||||
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
|
||||
// This is expected behavior for multi-channel batcher according to user
|
||||
// "it has never worked with the multi"
|
||||
} else {
|
||||
t.Logf("All nodes show as connected - working correctly")
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
// Also check IsConnected method
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Logf("Node %d IsConnected() returns false", i)
|
||||
}
|
||||
}
|
||||
|
||||
if disconnectedCount > 0 {
|
||||
t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes))
|
||||
} else {
|
||||
t.Logf("Batcher does not implement Debug() method")
|
||||
t.Logf("All nodes show as connected - working correctly")
|
||||
}
|
||||
|
||||
// Test if "disconnected" nodes can actually receive updates.
|
||||
@@ -2491,37 +2482,25 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
// Verify debug status shows correct connection count.
|
||||
t.Logf("Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
debugInfo := batcher.Debug()
|
||||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
t.Errorf("Node1 should have 3 active connections, got %d", activeConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
}
|
||||
if info.ActiveConnections != 3 {
|
||||
t.Errorf("Node1 should have 3 active connections, got %d", info.ActiveConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
|
||||
if info, exists := debugInfo[node2.n.ID]; exists {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 1 {
|
||||
t.Errorf("Node2 should have 1 active connection, got %d", activeConnections)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !info.Connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if info, exists := debugInfo[node2.n.ID]; exists {
|
||||
if info.ActiveConnections != 1 {
|
||||
t.Errorf("Node2 should have 1 active connection, got %d", info.ActiveConnections)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2604,20 +2583,12 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
runtime.Gosched()
|
||||
|
||||
// Verify debug status shows 2 connections now
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 2 {
|
||||
t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
|
||||
}
|
||||
}
|
||||
}
|
||||
debugInfo2 := batcher.Debug()
|
||||
if info, exists := debugInfo2[node1.n.ID]; exists {
|
||||
if info.ActiveConnections != 2 {
|
||||
t.Errorf("Node1 should have 2 active connections after removal, got %d", info.ActiveConnections)
|
||||
} else {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2731,11 +2702,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) {
|
||||
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect")
|
||||
|
||||
// Get initial work errors count
|
||||
var initialWorkErrors int64
|
||||
if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok {
|
||||
initialWorkErrors = lfb.WorkErrors()
|
||||
t.Logf("Initial work errors: %d", initialWorkErrors)
|
||||
}
|
||||
lfb := unwrapBatcher(batcher)
|
||||
initialWorkErrors := lfb.WorkErrors()
|
||||
t.Logf("Initial work errors: %d", initialWorkErrors)
|
||||
|
||||
// Clear channels to prepare for the test
|
||||
drainCh(node1.ch)
|
||||
@@ -2777,11 +2746,7 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) {
|
||||
// With the fix, no new errors should occur because the deleted node
|
||||
// was cleaned up from batcher state when NodeRemoved was processed
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var finalWorkErrors int64
|
||||
if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok {
|
||||
finalWorkErrors = lfb.WorkErrors()
|
||||
}
|
||||
|
||||
finalWorkErrors := lfb.WorkErrors()
|
||||
newErrors := finalWorkErrors - initialWorkErrors
|
||||
assert.Zero(c, newErrors, "Fix for #2924: should have no work errors after node deletion")
|
||||
}, 5*time.Second, 100*time.Millisecond, "waiting for work processing to complete without errors")
|
||||
@@ -2809,8 +2774,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) {
|
||||
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, NORMAL_BUFFER_SIZE)
|
||||
defer cleanup()
|
||||
|
||||
lfb, ok := unwrapBatcher(testData.Batcher).(*LockFreeBatcher)
|
||||
require.True(t, ok, "expected LockFreeBatcher")
|
||||
lfb := unwrapBatcher(testData.Batcher)
|
||||
|
||||
nodeID := testData.Nodes[0].n.ID
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
@@ -2838,8 +2802,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) {
|
||||
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, NORMAL_BUFFER_SIZE)
|
||||
defer cleanup()
|
||||
|
||||
lfb, ok := unwrapBatcher(testData.Batcher).(*LockFreeBatcher)
|
||||
require.True(t, ok, "expected LockFreeBatcher")
|
||||
lfb := unwrapBatcher(testData.Batcher)
|
||||
|
||||
nodeID := testData.Nodes[0].n.ID
|
||||
ch1 := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
@@ -2867,11 +2830,7 @@ func TestRemoveNodeChannelAlreadyRemoved(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// unwrapBatcher extracts the underlying batcher from wrapper types.
|
||||
func unwrapBatcher(b Batcher) Batcher {
|
||||
if wrapper, ok := b.(*testBatcherWrapper); ok {
|
||||
return unwrapBatcher(wrapper.Batcher)
|
||||
}
|
||||
|
||||
return b
|
||||
// unwrapBatcher extracts the underlying *Batcher from the test wrapper.
|
||||
func unwrapBatcher(b *testBatcherWrapper) *Batcher {
|
||||
return b.Batcher
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ type mapper struct {
|
||||
// Configuration
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
batcher Batcher
|
||||
batcher *Batcher
|
||||
|
||||
created time.Time
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user