mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-24 01:28:49 +02:00
all: fix golangci-lint issues (#3064)
This commit is contained in:
@@ -16,6 +16,14 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Mapper errors.
|
||||
var (
|
||||
ErrInvalidNodeID = errors.New("invalid nodeID")
|
||||
ErrMapperNil = errors.New("mapper is nil")
|
||||
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
|
||||
ErrNodeNotFoundMapper = errors.New("node not found")
|
||||
)
|
||||
|
||||
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "headscale",
|
||||
Name: "mapresponse_generated_total",
|
||||
@@ -81,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
||||
}
|
||||
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
||||
}
|
||||
|
||||
// Handle self-only responses
|
||||
@@ -136,7 +144,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
||||
if nc == nil {
|
||||
return errors.New("nodeConnection is nil")
|
||||
return ErrNodeConnectionNil
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
|
||||
@@ -2,6 +2,7 @@ package mapper
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
@@ -18,7 +19,13 @@ import (
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var errConnectionClosed = errors.New("connection channel already closed")
|
||||
// LockFreeBatcher errors.
|
||||
var (
|
||||
errConnectionClosed = errors.New("connection channel already closed")
|
||||
ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
|
||||
ErrBatcherShuttingDown = errors.New("batcher shutting down")
|
||||
ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
|
||||
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
@@ -81,6 +88,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
if err != nil {
|
||||
nlog.Error().Err(err).Msg("initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("generating initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
@@ -90,11 +98,12 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
||||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second): //nolint:mnd
|
||||
nlog.Error().Err(errors.New("timeout")).Msg("initial map send timeout") //nolint:err113
|
||||
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
||||
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
||||
nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
|
||||
nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
|
||||
Msg("initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("sending initial map to node %d: timeout", id)
|
||||
|
||||
return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
@@ -135,6 +144,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
||||
nlog.Debug().Caller().
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("node connection removed but keeping online, other connections remain")
|
||||
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
@@ -219,10 +229,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
var err error
|
||||
|
||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
@@ -235,7 +247,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
||||
nc.updateSentPeers(result.mapResponse)
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
wlog.Error().Err(result.err).
|
||||
@@ -402,6 +414,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -454,6 +467,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -469,6 +483,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -488,7 +503,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.done:
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,8 +538,9 @@ type multiChannelNodeConn struct {
|
||||
// generateConnectionID generates a unique connection identifier.
|
||||
func generateConnectionID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return fmt.Sprintf("%x", bytes)
|
||||
_, _ = rand.Read(bytes)
|
||||
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
@@ -557,7 +573,9 @@ func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
@@ -579,9 +597,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
|
||||
mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("successfully removed connection")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -615,6 +635,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
// This is not an error - the node will receive a full map when it reconnects
|
||||
mc.log.Debug().Caller().
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
@@ -623,7 +644,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
|
||||
successCount := 0
|
||||
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
@@ -632,8 +655,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
err := conn.send(data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
|
||||
failedConnections = append(failedConnections, i)
|
||||
mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
|
||||
Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
|
||||
@@ -695,7 +720,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
||||
return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -805,6 +830,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -819,6 +845,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
||||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ type batcherTestCase struct {
|
||||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
|
||||
state *state.State
|
||||
}
|
||||
|
||||
@@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
||||
}
|
||||
|
||||
// Finally remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return t.Batcher.RemoveNode(id, c)
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
@@ -129,8 +125,6 @@ const (
|
||||
SMALL_BUFFER_SIZE = 3
|
||||
TINY_BUFFER_SIZE = 1 // For maximum contention
|
||||
LARGE_BUFFER_SIZE = 200
|
||||
|
||||
reservedResponseHeaderSize = 4
|
||||
)
|
||||
|
||||
// TestData contains all test entities created for a test scenario.
|
||||
@@ -241,8 +235,8 @@ func setupBatcherWithTestData(
|
||||
}
|
||||
|
||||
derpMap, err := derp.GetDERPMap(cfg.DERP)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, derpMap)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, derpMap)
|
||||
|
||||
state.SetDERPMap(derpMap)
|
||||
|
||||
@@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
|
||||
}
|
||||
|
||||
// getStats returns a copy of the statistics for a node.
|
||||
//
|
||||
//nolint:unused
|
||||
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
@@ -386,16 +382,14 @@ type UpdateInfo struct {
|
||||
}
|
||||
|
||||
// parseUpdateAndAnalyze parses an update and returns detailed information.
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
|
||||
info := UpdateInfo{
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
|
||||
return UpdateInfo{
|
||||
PeerCount: len(resp.Peers),
|
||||
PatchCount: len(resp.PeersChangedPatch),
|
||||
IsFull: len(resp.Peers) > 0,
|
||||
IsPatch: len(resp.PeersChangedPatch) > 0,
|
||||
IsDERP: resp.DERPMap != nil,
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// start begins consuming updates from the node's channel and tracking stats.
|
||||
@@ -417,7 +411,8 @@ func (n *node) start() {
|
||||
atomic.AddInt64(&n.updateCount, 1)
|
||||
|
||||
// Parse update and track detailed stats
|
||||
if info, err := parseUpdateAndAnalyze(data); err == nil {
|
||||
info := parseUpdateAndAnalyze(data)
|
||||
{
|
||||
// Track update types
|
||||
if info.IsFull {
|
||||
atomic.AddInt64(&n.fullCount, 1)
|
||||
@@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
||||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Wait for connection to be established
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
@@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullUpdate())
|
||||
@@ -676,6 +671,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
connectedCount := 0
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
@@ -693,6 +689,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
||||
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
||||
|
||||
t.Logf("✅ All nodes achieved full connectivity!")
|
||||
|
||||
totalTime := time.Since(startTime)
|
||||
|
||||
// Disconnect all nodes
|
||||
@@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
tn := testData.Nodes[0]
|
||||
tn2 := testData.Nodes[1]
|
||||
tn := &testData.Nodes[0]
|
||||
tn2 := &testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
@@ -842,10 +839,10 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// Drain any initial messages from first node
|
||||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
drainChannelTimeout(tn.ch, 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
@@ -911,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
||||
count := 0
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case data := <-ch:
|
||||
count++
|
||||
// Optional: add debug output if needed
|
||||
_ = data
|
||||
case <-ch:
|
||||
// Drain message
|
||||
case <-timer.C:
|
||||
return
|
||||
}
|
||||
@@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
@@ -1131,6 +1124,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
||||
// even when real node updates are being processed, ensuring no race conditions
|
||||
// occur during channel replacement with actual workload.
|
||||
func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
// Create test environment with real database and nodes
|
||||
@@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
testNode := &testData.Nodes[0]
|
||||
|
||||
var (
|
||||
channelIssues int
|
||||
@@ -1154,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Go(func() {
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Add real work during connection chaos
|
||||
@@ -1167,7 +1162,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
|
||||
wg.Go(func() {
|
||||
runtime.Gosched() // Yield to introduce timing variability
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
|
||||
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Remove second connection
|
||||
@@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := testData.Nodes[0]
|
||||
testNode := &testData.Nodes[0]
|
||||
|
||||
var (
|
||||
panics int
|
||||
@@ -1258,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPMap())
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
@@ -1308,6 +1304,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
for range i % 3 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Yield to allow workers to process and close channels
|
||||
@@ -1350,6 +1347,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
||||
// real node data. The test validates that stable clients continue to function
|
||||
// normally and receive proper updates despite the connection churn from other clients,
|
||||
// ensuring system stability under concurrent load.
|
||||
//
|
||||
//nolint:gocyclo // complex concurrent test scenario
|
||||
func TestBatcherConcurrentClients(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent client test in short mode")
|
||||
@@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
|
||||
stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
|
||||
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
@@ -1391,6 +1391,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, reason := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
@@ -1427,7 +1428,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
// Connection churn cycles - rapidly connect/disconnect to test concurrency safety
|
||||
for i := range numCycles {
|
||||
for _, node := range churningNodes {
|
||||
for j := range churningNodes {
|
||||
node := &churningNodes[j]
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// Connect churning node
|
||||
@@ -1448,10 +1451,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
@@ -1462,6 +1467,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
@@ -1494,6 +1500,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
for range i % 5 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
@@ -1519,7 +1526,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
|
||||
if i%7 == 0 && len(allNodes) > 0 {
|
||||
// Node-specific changes using real nodes
|
||||
node := allNodes[i%len(allNodes)]
|
||||
node := &allNodes[i%len(allNodes)]
|
||||
// Use a valid expiry time for testing since test nodes don't have expiry set
|
||||
testExpiry := time.Now().Add(24 * time.Hour)
|
||||
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
|
||||
@@ -1567,7 +1574,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
|
||||
expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
|
||||
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
if stats, exists := allStats[node.n.ID]; exists {
|
||||
stableUpdateCount += stats.TotalUpdates
|
||||
t.Logf("Stable node %d: %d updates",
|
||||
@@ -1580,7 +1588,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, node := range churningNodes {
|
||||
for i := range churningNodes {
|
||||
node := &churningNodes[i]
|
||||
if stats, exists := allStats[node.n.ID]; exists {
|
||||
churningUpdateCount += stats.TotalUpdates
|
||||
}
|
||||
@@ -1605,7 +1614,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify all stable clients are still functional
|
||||
for _, node := range stableNodes {
|
||||
for i := range stableNodes {
|
||||
node := &stableNodes[i]
|
||||
if !batcher.IsConnected(node.n.ID) {
|
||||
t.Errorf("Stable node %d lost connection during racing", node.n.ID)
|
||||
}
|
||||
@@ -1623,6 +1633,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
||||
// It validates that the system remains stable with no deadlocks, panics, or
|
||||
// missed updates under sustained high load. The test uses real node data to
|
||||
// generate authentic update scenarios and tracks comprehensive statistics.
|
||||
//
|
||||
//nolint:gocyclo,thelper // complex scalability test scenario
|
||||
func XTestBatcherScalability(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping scalability test in short mode")
|
||||
@@ -1651,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
description string
|
||||
}
|
||||
|
||||
var testCases []testCase
|
||||
testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes))
|
||||
|
||||
// Generate all combinations of the test matrix
|
||||
for _, nodeCount := range nodes {
|
||||
@@ -1762,7 +1774,8 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
@@ -1824,7 +1837,8 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
}
|
||||
|
||||
// Connection/disconnection cycles for subset of nodes
|
||||
for i, node := range chaosNodes {
|
||||
for i := range chaosNodes {
|
||||
node := &chaosNodes[i]
|
||||
// Only add work if this is connection chaos or mixed
|
||||
if tc.chaosType == "connection" || tc.chaosType == "mixed" {
|
||||
wg.Add(2)
|
||||
@@ -1878,6 +1892,7 @@ func XTestBatcherScalability(t *testing.T) {
|
||||
channel,
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
@@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
t.Logf("Created %d nodes in database", len(allNodes))
|
||||
|
||||
// Connect nodes one at a time and wait for each to be connected
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
|
||||
// Wait for node to be connected
|
||||
@@ -2157,7 +2173,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
||||
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
|
||||
|
||||
// Check how many peers each node should see
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
peers := testData.State.ListPeers(node.n.ID)
|
||||
t.Logf("Node %d should see %d peers from state", i, peers.Len())
|
||||
}
|
||||
@@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
for i, node := range allNodes {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node %d: %v", i, err)
|
||||
@@ -2302,16 +2322,21 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
for i, node := range allNodes {
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
}
|
||||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
@@ -2334,7 +2359,8 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
debugInfo := debugBatcher.Debug()
|
||||
disconnectedCount := 0
|
||||
|
||||
for i, node := range allNodes {
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
if info, exists := debugInfo[node.n.ID]; exists {
|
||||
t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
|
||||
|
||||
@@ -2342,11 +2368,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
@@ -2381,6 +2409,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
@@ -2399,6 +2428,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo // complex multi-connection test scenario
|
||||
func TestBatcherMultiConnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
@@ -2406,13 +2436,14 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
node1 := &testData.Nodes[0]
|
||||
node2 := &testData.Nodes[1]
|
||||
|
||||
t.Logf("=== MULTI-CONNECTION TEST ===")
|
||||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
@@ -2432,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
@@ -2443,7 +2476,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
@@ -2454,6 +2489,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
@@ -2461,6 +2497,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
@@ -2469,6 +2506,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
||||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -36,6 +35,7 @@ const (
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
@@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFoundMapper)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
return nil, ErrNodeNotFoundMapper
|
||||
}
|
||||
|
||||
// Get unreduced matchers for peer relationship determination.
|
||||
@@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs.
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs))
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
// Build finalizes the response and returns marshaled bytes.
|
||||
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||
}
|
||||
|
||||
@@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build()
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
require.Nil(t, data)
|
||||
require.Error(t, err)
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
)
|
||||
|
||||
@@ -50,6 +49,7 @@ type mapper struct {
|
||||
created time.Time
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
type patch struct {
|
||||
timestamp time.Time
|
||||
change *tailcfg.PeerChange
|
||||
@@ -60,7 +60,6 @@ func newMapper(
|
||||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
@@ -76,6 +75,7 @@ func generateUserProfiles(
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.UserView)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
|
||||
user := node.Owner()
|
||||
if !user.Valid() {
|
||||
log.Error().
|
||||
@@ -84,14 +84,17 @@ func generateUserProfiles(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
userID := user.Model().ID
|
||||
userMap[userID] = &user
|
||||
ids = append(ids, userID)
|
||||
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.Owner()
|
||||
if !peerUser.Valid() {
|
||||
continue
|
||||
}
|
||||
|
||||
peerUserID := peerUser.Model().ID
|
||||
userMap[peerUserID] = &peerUser
|
||||
ids = append(ids, peerUserID)
|
||||
@@ -99,7 +102,9 @@ func generateUserProfiles(
|
||||
|
||||
slices.Sort(ids)
|
||||
ids = slices.Compact(ids)
|
||||
|
||||
var profiles []tailcfg.UserProfile
|
||||
|
||||
for _, id := range ids {
|
||||
if userMap[id] != nil {
|
||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||
@@ -149,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
//
|
||||
//nolint:unused
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
@@ -316,6 +323,7 @@ func writeDebugMapResponse(
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -329,6 +337,7 @@ func writeDebugMapResponse(
|
||||
)
|
||||
|
||||
log.Trace().Msgf("writing MapResponse to %s", mapResponsePath)
|
||||
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -337,7 +346,7 @@ func writeDebugMapResponse(
|
||||
|
||||
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
if debugDumpMapResponsePath == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: no data when debug path not set
|
||||
}
|
||||
|
||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||
@@ -350,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
}
|
||||
|
||||
result := make(map[types.NodeID][]tailcfg.MapResponse)
|
||||
|
||||
for _, node := range nodes {
|
||||
if !node.IsDir() {
|
||||
continue
|
||||
@@ -385,6 +395,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
||||
}
|
||||
|
||||
var resp tailcfg.MapResponse
|
||||
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("unmarshalling file %s", file.Name())
|
||||
|
||||
@@ -3,14 +3,10 @@ package mapper
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
@@ -81,90 +77,3 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockState is a mock implementation that provides the required methods.
|
||||
type mockState struct {
|
||||
polMan policy.PolicyManager
|
||||
derpMap *tailcfg.DERPMap
|
||||
primary *routes.PrimaryRoutes
|
||||
nodes types.Nodes
|
||||
peers types.Nodes
|
||||
}
|
||||
|
||||
func (m *mockState) DERPMap() *tailcfg.DERPMap {
|
||||
return m.derpMap
|
||||
}
|
||||
|
||||
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
if m.polMan == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
return m.polMan.Filter()
|
||||
}
|
||||
|
||||
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
if m.polMan == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.polMan.SSHPolicy(node)
|
||||
}
|
||||
|
||||
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
if m.polMan == nil {
|
||||
return false
|
||||
}
|
||||
return m.polMan.NodeCanHaveTag(node, tag)
|
||||
}
|
||||
|
||||
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
|
||||
if m.primary == nil {
|
||||
return nil
|
||||
}
|
||||
return m.primary.PrimaryRoutes(nodeID)
|
||||
}
|
||||
|
||||
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(peerIDs) > 0 {
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
// Return all peers except the node itself
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if peer.ID != nodeID {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(nodeIDs) > 0 {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
return m.nodes, nil
|
||||
}
|
||||
|
||||
func Test_fullMapResponse(t *testing.T) {
|
||||
t.Skip("Test needs to be refactored for new state-based architecture")
|
||||
// TODO: Refactor this test to work with the new state-based mapper
|
||||
// The test architecture needs to be updated to work with the state interface
|
||||
// instead of the old direct dependency injection pattern
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
func TestTailNode(t *testing.T) {
|
||||
mustNK := func(str string) key.NodePublic {
|
||||
var k key.NodePublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -26,6 +27,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
mustDK := func(str string) key.DiscoPublic {
|
||||
var k key.DiscoPublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -33,6 +35,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
mustMK := func(str string) key.MachinePublic {
|
||||
var k key.MachinePublic
|
||||
|
||||
_ = k.UnmarshalText([]byte(str))
|
||||
|
||||
return k
|
||||
@@ -255,7 +258,7 @@ func TestNodeExpiry(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "localtime",
|
||||
exp: tp(time.Time{}.Local()),
|
||||
exp: tp(time.Time{}.Local()), //nolint:gosmopolitan
|
||||
wantTimeZero: true,
|
||||
},
|
||||
}
|
||||
@@ -284,7 +287,9 @@ func TestNodeExpiry(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("nodeExpiry() error = %v", err)
|
||||
}
|
||||
|
||||
var deseri tailcfg.Node
|
||||
|
||||
err = json.Unmarshal(seri, &deseri)
|
||||
if err != nil {
|
||||
t.Fatalf("nodeExpiry() error = %v", err)
|
||||
|
||||
Reference in New Issue
Block a user