change: smarter change notifications

This commit replaces the ChangeSet with a simpler bool based
change model that can be directly used in the map builder to
build the appropriate map response based on the change that
has occurred. Previously, we fell back to sending full maps
for a lot of changes as that was considered "the safe" thing to
do to ensure no updates were missed.

This was slightly problematic as a node that already has a list
of peers will only do full replacement of the peers if the list
is non-empty, meaning that it was not possible to remove all
nodes (if for example policy changed).

Now we keep track of the last seen nodes, so we can send remove
IDs, and we are also much smarter about how we send smaller,
partial maps when needed.

Fixes #2389

Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
This commit is contained in:
Kristoffer Dalby
2025-12-15 14:36:21 +00:00
parent f67ed36fe2
commit 5767ca5085
12 changed files with 1280 additions and 616 deletions

View File

@@ -13,18 +13,13 @@ import (
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var (
mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type and change type",
}, []string{"response_type", "change_type"})
errNodeNotFoundInNodeStore = errors.New("node not found in NodeStore")
)
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type",
}, []string{"response_type"})
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
@@ -36,8 +31,8 @@ type Batcher interface {
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c ...change.ChangeSet)
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
AddWork(r ...change.Change)
MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error)
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
}
@@ -51,7 +46,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
}
}
@@ -69,15 +64,21 @@ type nodeConnection interface {
nodeID() types.NodeID
version() tailcfg.CapabilityVersion
send(data *tailcfg.MapResponse) error
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
// updateSentPeers updates the tracking of which peers have been sent to this node.
updateSentPeers(resp *tailcfg.MapResponse)
}
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
if c.Empty() {
return nil, nil
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change].
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
nodeID := nc.nodeID()
version := nc.version()
if r.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send
}
// Validate inputs before processing
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
}
@@ -86,141 +87,58 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
}
// Handle self-only responses
if r.IsSelfOnly() && r.TargetNode != nodeID {
return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
}
var (
mapResp *tailcfg.MapResponse
err error
responseType string
mapResp *tailcfg.MapResponse
err error
)
// Record metric when function exits
defer func() {
if err == nil && mapResp != nil && responseType != "" {
mapResponseGenerated.WithLabelValues(responseType, c.Change.String()).Inc()
}
}()
// Track metric using categorized type, not free-form reason
mapResponseGenerated.WithLabelValues(r.Type()).Inc()
switch c.Change {
case change.DERP:
responseType = "derp"
mapResp, err = mapper.derpMapResponse(nodeID)
// Check if this requires runtime peer visibility computation (e.g., policy changes)
if r.RequiresRuntimePeerComputation {
currentPeers := mapper.state.ListPeers(nodeID)
case change.NodeCameOnline, change.NodeWentOffline:
if c.IsSubnetRouter {
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
responseType = "full"
mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
// Trust the change type for online/offline status to avoid race conditions
// between NodeStore updates and change processing
responseType = string(patchResponseDebug)
onlineStatus := c.Change == change.NodeCameOnline
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
Online: ptr.To(onlineStatus),
},
})
currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len())
for _, peer := range currentPeers.All() {
currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID())
}
case change.NodeNewOrUpdate:
// If the node is the one being updated, we send a self update that preserves peer information
// to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes)
// without losing its view of peer status during rapid reconnection cycles
if c.IsSelfUpdate(nodeID) {
responseType = "self"
mapResp, err = mapper.selfMapResponse(nodeID, version)
} else {
responseType = "change"
mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID)
}
case change.NodeRemove:
responseType = "remove"
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
case change.NodeKeyExpiry:
// If the node is the one whose key is expiring, we send a "full" self update
// as nodes will ignore patch updates about themselves (?).
if c.IsSelfUpdate(nodeID) {
responseType = "self"
mapResp, err = mapper.selfMapResponse(nodeID, version)
// mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
responseType = "patch"
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
KeyExpiry: c.NodeExpiry,
},
})
}
case change.NodeEndpoint, change.NodeDERP:
// Endpoint or DERP changes can be sent as lightweight patches.
// Query the NodeStore for the current peer state to construct the PeerChange.
// Even if only endpoint or only DERP changed, we include both in the patch
// since they're often updated together and it's minimal overhead.
responseType = "patch"
peer, found := mapper.state.GetNodeByID(c.NodeID)
if !found {
return nil, fmt.Errorf("%w: %d", errNodeNotFoundInNodeStore, c.NodeID)
}
peerChange := &tailcfg.PeerChange{
NodeID: c.NodeID.NodeID(),
Endpoints: peer.Endpoints().AsSlice(),
DERPRegion: 0, // Will be set below if available
}
// Extract DERP region from Hostinfo if available
if hi := peer.AsStruct().Hostinfo; hi != nil && hi.NetInfo != nil {
peerChange.DERPRegion = hi.NetInfo.PreferredDERP
}
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{peerChange})
default:
// The following will always hit this:
// change.Full, change.Policy
responseType = "full"
mapResp, err = mapper.fullMapResponse(nodeID, version)
removedPeers := nc.computePeerDiff(currentPeerIDs)
mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers)
} else {
mapResp, err = mapper.buildFromChange(nodeID, version, &r)
}
if err != nil {
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
}
// TODO(kradalby): Is this necessary?
// Validate the generated map response - only check for nil response
// Note: mapResp.Node can be nil for peer updates, which is valid
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
}
return mapResp, nil
}
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
if nc == nil {
return errors.New("nodeConnection is nil")
}
nodeID := nc.nodeID()
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
var data *tailcfg.MapResponse
var err error
data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
data, err := generateMapResponse(nc, mapper, r)
if err != nil {
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
}
if data == nil {
// No data to send is valid for some change types
// No data to send is valid for some response types
return nil
}
@@ -230,6 +148,9 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
}
// Update peer tracking after successful send
nc.updateSentPeers(data)
return nil
}
@@ -241,7 +162,7 @@ type workResult struct {
// work represents a unit of work to be processed by workers.
type work struct {
c change.ChangeSet
r change.Change
nodeID types.NodeID
resultCh chan<- workResult // optional channel for synchronous operations
}

View File

@@ -33,7 +33,7 @@ type LockFreeBatcher struct {
done chan struct{}
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
pendingChanges *xsync.Map[types.NodeID, []change.Change]
// Metrics
totalNodes atomic.Int64
@@ -141,8 +141,8 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
}
// AddWork queues a change to be processed by the batcher.
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
b.addWork(c...)
func (b *LockFreeBatcher) AddWork(r ...change.Change) {
b.addWork(r...)
}
func (b *LockFreeBatcher) Start() {
@@ -211,15 +211,19 @@ func (b *LockFreeBatcher) worker(workerID int) {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.r)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("worker.id", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Str("reason", w.r.Reason).
Msg("failed to generate map response for synchronous work")
} else if result.mapResponse != nil {
// Update peer tracking for synchronous responses too
nc.updateSentPeers(result.mapResponse)
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
@@ -247,13 +251,13 @@ func (b *LockFreeBatcher) worker(workerID int) {
if nc, exists := b.nodes.Load(w.nodeID); exists {
// Apply change to node - this will handle offline nodes gracefully
// and queue work for when they reconnect
err := nc.change(w.c)
err := nc.change(w.r)
if err != nil {
b.workErrors.Add(1)
log.Error().Err(err).
Int("worker.id", workerID).
Uint64("node.id", w.c.NodeID.Uint64()).
Str("change", w.c.Change.String()).
Uint64("node.id", w.nodeID.Uint64()).
Str("reason", w.r.Reason).
Msg("failed to apply change")
}
}
@@ -264,8 +268,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
}
}
func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
b.addToBatch(c...)
func (b *LockFreeBatcher) addWork(r ...change.Change) {
b.addToBatch(r...)
}
// queueWork safely queues work.
@@ -281,38 +285,43 @@ func (b *LockFreeBatcher) queueWork(w work) {
}
}
// addToBatch adds a change to the pending batch.
func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
// Short circuit if any of the changes is a full update, which
// addToBatch adds a response to the pending batch.
func (b *LockFreeBatcher) addToBatch(responses ...change.Change) {
// Short circuit if any of the responses is a full update, which
// means we can skip sending individual changes.
if change.HasFull(c) {
if change.HasFull(responses) {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()})
return true
})
return
}
all, self := change.SplitAllAndSelf(c)
for _, changeSet := range self {
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
changes = append(changes, changeSet)
b.pendingChanges.Store(changeSet.NodeID, changes)
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
rel := change.RemoveUpdatesForSelf(nodeID, all)
broadcast, targeted := change.SplitTargetedAndBroadcast(responses)
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, rel...)
b.pendingChanges.Store(nodeID, changes)
// Handle targeted responses - send only to the specific node
for _, resp := range targeted {
changes, _ := b.pendingChanges.LoadOrStore(resp.TargetNode, []change.Change{})
changes = append(changes, resp)
b.pendingChanges.Store(resp.TargetNode, changes)
}
return true
})
// Handle broadcast responses - send to all nodes, filtering as needed
if len(broadcast) > 0 {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
filtered := change.FilterForNode(nodeID, broadcast)
if len(filtered) > 0 {
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{})
changes = append(changes, filtered...)
b.pendingChanges.Store(nodeID, changes)
}
return true
})
}
}
// processBatchedChanges processes all pending batched changes.
@@ -322,14 +331,14 @@ func (b *LockFreeBatcher) processBatchedChanges() {
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
if len(changes) == 0 {
b.pendingChanges.Range(func(nodeID types.NodeID, responses []change.Change) bool {
if len(responses) == 0 {
return true
}
// Send all batched changes for this node
for _, c := range changes {
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
// Send all batched responses for this node
for _, r := range responses {
b.queueWork(work{r: r, nodeID: nodeID, resultCh: nil})
}
// Clear the pending changes for this node
@@ -432,11 +441,11 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) {
resultCh := make(chan workResult, 1)
// Queue the work with a result channel using the safe queueing method
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
b.queueWork(work{r: r, nodeID: id, resultCh: resultCh})
// Wait for the result
select {
@@ -466,6 +475,12 @@ type multiChannelNodeConn struct {
connections []*connectionEntry
updateCount atomic.Int64
// lastSentPeers tracks which peers were last sent to this node.
// This enables computing diffs for policy changes instead of sending
// full peer lists (which clients interpret as "no change" when empty).
// Using xsync.Map for lock-free concurrent access.
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
}
// generateConnectionID generates a unique connection identifier.
@@ -478,8 +493,9 @@ func generateConnectionID() string {
// newMultiChannelNodeConn creates a new multi-channel node connection.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
return &multiChannelNodeConn{
id: id,
mapper: mapper,
id: id,
mapper: mapper,
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
}
}
@@ -662,9 +678,59 @@ func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
return mc.connections[0].version
}
// updateSentPeers records which peers the client now knows about, based on
// the MapResponse that was just delivered. It must be called after every
// successful send so lastSentPeers stays in sync with the client's view,
// enabling accurate removal diffs on future updates.
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
	if resp == nil {
		return
	}

	// A non-nil Peers slice is a full replacement: drop all tracked
	// state and rebuild it from the response.
	if resp.Peers != nil {
		mc.lastSentPeers.Clear()
		for _, p := range resp.Peers {
			mc.lastSentPeers.Store(p.ID, struct{}{})
		}
	}

	// Apply incremental updates: additions/changes first, then removals.
	for _, p := range resp.PeersChanged {
		mc.lastSentPeers.Store(p.ID, struct{}{})
	}

	for _, removedID := range resp.PeersRemoved {
		mc.lastSentPeers.Delete(removedID)
	}
}
// computePeerDiff reports which peers have disappeared since the last send:
// every ID tracked in lastSentPeers that is absent from currentPeers. The
// result is suitable for a MapResponse.PeersRemoved field.
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
	present := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
	for _, peerID := range currentPeers {
		present[peerID] = struct{}{}
	}

	// Collect IDs previously sent to the client that are no longer present.
	var gone []tailcfg.NodeID
	mc.lastSentPeers.Range(func(peerID tailcfg.NodeID, _ struct{}) bool {
		if _, stillVisible := present[peerID]; !stillVisible {
			gone = append(gone, peerID)
		}

		return true
	})

	return gone
}
// change applies a change to all active connections for the node.
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
return handleNodeChange(mc, mc.mapper, c)
func (mc *multiChannelNodeConn) change(r change.Change) error {
return handleNodeChange(mc, mc.mapper, r)
}
// DebugNodeInfo contains debug information about a node's connections.

View File

@@ -59,7 +59,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return fmt.Errorf("%w: %d", errNodeNotFoundAfterAdd, id)
}
t.AddWork(change.NodeOnline(node))
t.AddWork(change.NodeOnlineFor(node))
return nil
}
@@ -76,7 +76,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
// Do this BEFORE removing from batcher so the change can be processed
node, ok := t.state.GetNodeByID(id)
if ok {
t.AddWork(change.NodeOffline(node))
t.AddWork(change.NodeOfflineFor(node))
}
// Finally remove from the real batcher
@@ -557,9 +557,9 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
}, time.Second, 10*time.Millisecond, "waiting for node connection")
// Generate work and wait for updates to be processed
batcher.AddWork(change.FullSet)
batcher.AddWork(change.PolicySet)
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.FullUpdate())
batcher.AddWork(change.PolicyChange())
batcher.AddWork(change.DERPMap())
// Wait for updates to be processed (at least 1 update received)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -661,7 +661,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Yield to scheduler for large node counts to prevent overwhelming the work queue
if tc.nodeCount > 100 && i%50 == 49 {
@@ -832,7 +832,7 @@ func TestBatcherBasicOperations(t *testing.T) {
}
// Test work processing with DERP change
batcher.AddWork(change.DERPChange())
batcher.AddWork(change.DERPMap())
// Wait for update and validate content
select {
@@ -959,31 +959,31 @@ func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout ti
// }{
// {
// name: "DERP change",
// changeSet: change.DERPSet,
// changeSet: change.DERPMapResponse(),
// expectData: true,
// description: "DERP changes should generate map updates",
// },
// {
// name: "Node key expiry",
// changeSet: change.KeyExpiry(testNodes[1].n.ID),
// changeSet: change.KeyExpiryFor(testNodes[1].n.ID),
// expectData: true,
// description: "Node key expiry with real node data",
// },
// {
// name: "Node new registration",
// changeSet: change.NodeAdded(testNodes[1].n.ID),
// changeSet: change.NodeAddedResponse(testNodes[1].n.ID),
// expectData: true,
// description: "New node registration with real data",
// },
// {
// name: "Full update",
// changeSet: change.FullSet,
// changeSet: change.FullUpdateResponse(),
// expectData: true,
// description: "Full updates with real node data",
// },
// {
// name: "Policy change",
// changeSet: change.PolicySet,
// changeSet: change.PolicyChangeResponse(),
// expectData: true,
// description: "Policy updates with real node data",
// },
@@ -1057,13 +1057,13 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
var receivedUpdates []*tailcfg.MapResponse
// Add multiple changes rapidly to test batching
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.KeyExpiryFor(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPMap())
batcher.AddWork(change.NodeAdded(testNodes[1].n.ID))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Collect updates with timeout
updateCount := 0
@@ -1087,8 +1087,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
t.Logf("Update %d: nil update", updateCount)
}
case <-timeout:
// Expected: 5 changes should generate 6 updates (no batching in current implementation)
expectedUpdates := 6
// Expected: 5 explicit changes + 1 initial from AddNode + 1 NodeOnline from wrapper = 7 updates
expectedUpdates := 7
t.Logf("Received %d updates from %d changes (expected %d)",
updateCount, 5, expectedUpdates)
@@ -1160,7 +1160,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
// Add real work during connection chaos
if i%10 == 0 {
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
}
// Rapid second connection - should replace ch1
@@ -1260,7 +1260,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Consumer goroutine to validate data and detect channel issues
go func() {
@@ -1302,7 +1302,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
if i%10 == 0 {
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry))
batcher.AddWork(change.KeyExpiryFor(testNode.n.ID, testExpiry))
}
// Rapid removal creates race between worker and removal
@@ -1510,12 +1510,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Generate various types of work during racing
if i%3 == 0 {
// DERP changes
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
}
if i%5 == 0 {
// Full updates using real node data
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
if i%7 == 0 && len(allNodes) > 0 {
@@ -1523,7 +1523,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
node := allNodes[i%len(allNodes)]
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry))
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
}
// Yield to allow some batching
@@ -1778,7 +1778,7 @@ func XTestBatcherScalability(t *testing.T) {
}
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect")
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Wait for initial update to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -1887,7 +1887,7 @@ func XTestBatcherScalability(t *testing.T) {
// Add work to create load
if index%5 == 0 {
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
}(
node.n.ID,
@@ -1914,11 +1914,11 @@ func XTestBatcherScalability(t *testing.T) {
// Generate different types of work to ensure updates are sent
switch index % 4 {
case 0:
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
case 1:
batcher.AddWork(change.PolicySet)
batcher.AddWork(change.PolicyChange())
case 2:
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
default:
// Pick a random node and generate a node change
if len(testNodes) > 0 {
@@ -1927,7 +1927,7 @@ func XTestBatcherScalability(t *testing.T) {
change.NodeAdded(testNodes[nodeIdx].n.ID),
)
} else {
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
}
}(i)
@@ -2165,7 +2165,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Send a full update - this should generate full peer lists
t.Logf("Sending FullSet update...")
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Wait for FullSet work items to be processed
t.Logf("Waiting for FullSet to be processed...")
@@ -2261,7 +2261,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
t.Logf("Total updates received across all nodes: %d", totalUpdates)
if !foundFullUpdate {
t.Errorf("CRITICAL: No FULL updates received despite sending change.FullSet!")
t.Errorf("CRITICAL: No FULL updates received despite sending change.FullUpdateResponse()!")
t.Errorf(
"This confirms the bug - FullSet updates are not generating full peer responses",
)
@@ -2372,7 +2372,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes
batcher.AddWork(change.DERPChange())
batcher.AddWork(change.DERPMap())
receivedCount := 0
timeout := time.After(500 * time.Millisecond)
@@ -2508,11 +2508,7 @@ func TestBatcherMultiConnection(t *testing.T) {
clearChannel(node2.ch)
// Send a change notification from node2 (so node1 should receive it on all connections)
testChangeSet := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
testChangeSet := change.NodeAdded(node2.n.ID)
batcher.AddWork(testChangeSet)
@@ -2591,11 +2587,7 @@ func TestBatcherMultiConnection(t *testing.T) {
clearChannel(node1.ch)
clearChannel(thirdChannel)
testChangeSet2 := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
testChangeSet2 := change.NodeAdded(node2.n.ID)
batcher.AddWork(testChangeSet2)
@@ -2629,7 +2621,11 @@ func TestBatcherMultiConnection(t *testing.T) {
remaining1Received, remaining3Received)
}
// Verify second channel no longer receives updates (should be closed/removed)
// Drain secondChannel of any messages received before removal
// (the test wrapper sends NodeOffline before removal, which may have reached this channel)
clearChannel(secondChannel)
// Verify second channel no longer receives new updates after being removed
select {
case <-secondChannel:
t.Errorf("Removed connection still received update - this should not happen")

View File

@@ -29,10 +29,8 @@ type debugType string
const (
fullResponseDebug debugType = "full"
selfResponseDebug debugType = "self"
patchResponseDebug debugType = "patch"
removeResponseDebug debugType = "remove"
changeResponseDebug debugType = "change"
derpResponseDebug debugType = "derp"
policyResponseDebug debugType = "policy"
)
// NewMapResponseBuilder creates a new builder with basic fields set.

View File

@@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/rs/zerolog/log"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
@@ -179,52 +180,108 @@ func (m *mapper) selfMapResponse(
return ma, err
}
func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(derpResponseDebug).
WithDERPMap().
Build()
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *mapper) peerChangedPatchResponse(
nodeID types.NodeID,
changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(patchResponseDebug).
WithPeerChangedPatch(changed).
Build()
}
// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
// policyChangeResponse creates a MapResponse for policy changes.
// It sends:
// - PeersRemoved for peers that are no longer visible after the policy change
// - PeersChanged for remaining peers (their AllowedIPs may have changed due to policy)
// - Updated PacketFilters
// - Updated SSHPolicy (SSH rules may reference users/groups that changed)
// This avoids the issue where an empty Peers slice is interpreted by Tailscale
// clients as "no change" rather than "no peers".
func (m *mapper) policyChangeResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
removedPeers []tailcfg.NodeID,
currentPeers views.Slice[types.NodeView],
) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID, changedNodeID)
return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug).
builder := m.NewMapResponseBuilder(nodeID).
WithDebugType(policyResponseDebug).
WithCapabilityVersion(capVer).
WithUserProfiles(peers).
WithPeerChanges(peers).
Build()
WithPacketFilters().
WithSSHPolicy()
if len(removedPeers) > 0 {
// Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved
removedIDs := make([]types.NodeID, len(removedPeers))
for i, id := range removedPeers {
removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent
}
builder.WithPeersRemoved(removedIDs...)
}
// Send remaining peers in PeersChanged - their AllowedIPs may have
// changed due to the policy update (e.g., different routes allowed).
if currentPeers.Len() > 0 {
builder.WithPeerChanges(currentPeers)
}
return builder.Build()
}
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
// buildFromChange builds a MapResponse from a change.Change specification.
// This provides fine-grained control over what gets included in the response.
func (m *mapper) buildFromChange(
nodeID types.NodeID,
removedNodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
resp *change.Change,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(removeResponseDebug).
WithPeersRemoved(removedNodeID).
Build()
if resp.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error
}
// If this is a self-update (the changed node is the receiving node),
// send a self-update response to ensure the node sees its own changes.
if resp.OriginNode != 0 && resp.OriginNode == nodeID {
return m.selfMapResponse(nodeID, capVer)
}
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDebugType(changeResponseDebug)
if resp.IncludeSelf {
builder.WithSelfNode()
}
if resp.IncludeDERPMap {
builder.WithDERPMap()
}
if resp.IncludeDNS {
builder.WithDNSConfig()
}
if resp.IncludeDomain {
builder.WithDomain()
}
if resp.IncludePolicy {
builder.WithPacketFilters()
builder.WithSSHPolicy()
}
if resp.SendAllPeers {
peers := m.state.ListPeers(nodeID)
builder.WithUserProfiles(peers)
builder.WithPeers(peers)
} else {
if len(resp.PeersChanged) > 0 {
peers := m.state.ListPeers(nodeID, resp.PeersChanged...)
builder.WithUserProfiles(peers)
builder.WithPeerChanges(peers)
}
if len(resp.PeersRemoved) > 0 {
builder.WithPeersRemoved(resp.PeersRemoved...)
}
}
if len(resp.PeerPatches) > 0 {
builder.WithPeerChangedPatch(resp.PeerPatches)
}
return builder.Build()
}
func writeDebugMapResponse(