hscontrol/poll,state: fix grace period disconnect TOCTOU race

When a node disconnects, serveLongPoll defers a cleanup that starts a
grace period goroutine. This goroutine polls batcher.IsConnected() and,
if the node has not reconnected within ~10 seconds, calls
state.Disconnect() to mark it offline. A TOCTOU race exists: the node
can reconnect (calling Connect()) between the IsConnected check and
the Disconnect() call, causing the stale Disconnect() to overwrite
the new session's online status.

Fix with a monotonic per-node generation counter:

- State.Connect() increments the counter and returns the current
  generation alongside the change list.
- State.Disconnect() accepts the generation from the caller and
  rejects the call if a newer generation exists, making stale
  disconnects from old sessions a no-op.
- serveLongPoll captures the generation at Connect() time and passes
  it to Disconnect() in the deferred cleanup.
- RemoveNode's return value is now checked: if another session already
  owns the batcher slot (reconnect happened), the old session skips
  the grace period entirely.

Update batcher_test.go to track per-node connect generations and
pass them through to Disconnect(), matching production behavior.

Fixes the following test failures:
- server_state_online_after_reconnect_within_grace
- update_history_no_false_offline
- nodestore_correct_after_rapid_reconnect
- rapid_reconnect_peer_never_sees_offline
This commit is contained in:
Kristoffer Dalby
2026-03-17 14:35:18 +00:00
parent 00c41b6422
commit b09af3846b
3 changed files with 100 additions and 11 deletions

View File

@@ -39,14 +39,20 @@ type testBatcherWrapper struct {
*Batcher
state *state.State
// connectGens tracks per-node connect generations so RemoveNode can pass
// the correct generation to State.Disconnect(), matching production behavior.
connectGens sync.Map // types.NodeID → uint64
}
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error {
// Mark node as online in state before AddNode to match production behavior
// This ensures the NodeStore has correct online status for change processing
if t.state != nil {
// Use Connect to properly mark node online in NodeStore but don't send its changes
_ = t.state.Connect(id)
// Use Connect to properly mark node online in NodeStore and track the
// generation so RemoveNode can pass it to Disconnect().
_, gen := t.state.Connect(id)
t.connectGens.Store(id, gen)
}
// First add the node to the real batcher
@@ -71,8 +77,15 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
// Mark node as offline in state BEFORE removing from batcher
// This ensures the NodeStore has correct offline status when the change is processed
if t.state != nil {
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
_, _ = t.state.Disconnect(id)
var gen uint64
if v, ok := t.connectGens.LoadAndDelete(id); ok {
if g, ok := v.(uint64); ok {
gen = g
}
}
_, _ = t.state.Disconnect(id, gen)
}
// Send the offline notification that poll.go would normally send

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
"github.com/rs/zerolog"
@@ -147,11 +148,24 @@ func (m *mapSession) serveLongPoll() {
m.log.Trace().Caller().Msg("long poll session started")
// connectGen is set by Connect() below and captured by the deferred cleanup closure.
// It allows Disconnect() to reject stale calls from old sessions — if a newer session
// has called Connect() (incrementing the generation), the old session's Disconnect()
// sees a mismatched generation and becomes a no-op.
var connectGen uint64
// Clean up the session when the client disconnects
defer func() {
m.stopFromBatcher()
_ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
stillConnected := m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
// If another session already exists for this node (reconnect
// happened before this cleanup ran), skip the grace period
// entirely — the node is not actually disconnecting.
if stillConnected {
return
}
// When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather).
// Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects.
@@ -176,7 +190,11 @@ func (m *mapSession) serveLongPoll() {
}
if disconnected {
disconnectChanges, err := m.h.state.Disconnect(m.node.ID)
// Pass the generation from our Connect() call. If a newer session has
// connected since (bumping the generation), Disconnect() will detect
// the mismatch and skip the state update, preventing the race where
// an old grace period goroutine overwrites a newer session's online status.
disconnectChanges, err := m.h.state.Disconnect(m.node.ID, connectGen)
if err != nil {
m.log.Error().Caller().Err(err).Msg("failed to disconnect node")
}
@@ -215,7 +233,9 @@ func (m *mapSession) serveLongPoll() {
// 2. Connect: marks the node online and recalculates primary routes based on the updated state
// While this results in two notifications, it ensures route data is synchronized before
// primary route selection occurs, which is critical for proper HA subnet router failover.
connectChanges := m.h.state.Connect(m.node.ID)
var connectChanges []change.Change
connectChanges, connectGen = m.h.state.Connect(m.node.ID)
m.log.Info().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("node has connected")

View File

@@ -99,6 +99,12 @@ type State struct {
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
// connectGen tracks a per-node monotonic generation counter so stale
// Disconnect() calls from old poll sessions are rejected. Connect()
// increments the counter and returns the current value; Disconnect()
// only proceeds when the generation it carries matches the latest.
connectGen sync.Map // types.NodeID → *atomic.Uint64
// sshCheckAuth tracks when source nodes last completed SSH check auth.
//
// For rules without explicit checkPeriod (default 12h), auth covers any
@@ -508,7 +514,15 @@ func (s *State) DeleteNode(node types.NodeView) (change.Change, error) {
}
// Connect marks a node as connected and updates its primary routes in the state.
func (s *State) Connect(id types.NodeID) []change.Change {
// It returns the list of changes and a generation number. The generation number
// must be passed to Disconnect() so that stale disconnects from old poll sessions
// are rejected (see the grace period logic in poll.go).
func (s *State) Connect(id types.NodeID) ([]change.Change, uint64) {
// Increment the connect generation for this node. This ensures that any
// in-flight Disconnect() from a previous session will see a stale generation
// and become a no-op.
gen := s.nextConnectGen(id)
// Update online status in NodeStore before creating change notification
// so the NodeStore already reflects the correct state when other nodes
// process the NodeCameOnline change for full map generation.
@@ -517,7 +531,7 @@ func (s *State) Connect(id types.NodeID) []change.Change {
// n.LastSeen = ptr.To(now)
})
if !ok {
return nil
return nil, gen
}
c := []change.Change{change.NodeOnlineFor(node)}
@@ -532,11 +546,53 @@ func (s *State) Connect(id types.NodeID) []change.Change {
c = append(c, change.NodeAdded(id))
}
return c
return c, gen
}
// nextConnectGen bumps the per-node connect generation counter and returns
// the new value. The counter is lazily created on first use; concurrent
// callers race safely through LoadOrStore and the atomic increment, so each
// Connect() observes a strictly increasing generation for the same node.
func (s *State) nextConnectGen(id types.NodeID) uint64 {
	entry, _ := s.connectGen.LoadOrStore(id, &atomic.Uint64{})
	if counter, ok := entry.(*atomic.Uint64); ok {
		return counter.Add(1)
	}
	// Defensive fallback: the map only ever stores *atomic.Uint64, so this
	// branch should be unreachable.
	return 0
}
// connectGeneration reports the most recent connect generation recorded for
// the node, or zero when the node has never had Connect() called for it
// (no counter exists yet in the map).
func (s *State) connectGeneration(id types.NodeID) uint64 {
	if entry, ok := s.connectGen.Load(id); ok {
		if counter, ok := entry.(*atomic.Uint64); ok {
			return counter.Load()
		}
	}
	return 0
}
// Disconnect marks a node as disconnected and updates its primary routes in the state.
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
// The gen parameter is the generation returned by Connect(). If a newer Connect() has
// been called since the session that is disconnecting, the generation will not match
// and this call becomes a no-op, preventing stale disconnects from overwriting the
// online status set by a newer session.
func (s *State) Disconnect(id types.NodeID, gen uint64) ([]change.Change, error) {
// Check if this disconnect is stale. A newer Connect() will have incremented
// the generation, so if ours doesn't match, a newer session owns this node.
if current := s.connectGeneration(id); current != gen {
log.Debug().
Uint64("disconnect_gen", gen).
Uint64("current_gen", current).
Msg("stale disconnect rejected, newer session active")
return nil, nil
}
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
now := time.Now()
n.LastSeen = &now