hscontrol/poll,state: fix grace period disconnect TOCTOU race

When a node disconnects, serveLongPoll defers a cleanup that starts a
grace period goroutine. This goroutine polls batcher.IsConnected() and,
if the node has not reconnected within ~10 seconds, calls
state.Disconnect() to mark it offline. A TOCTOU race exists: the node
can reconnect (calling Connect()) between the IsConnected check and
the Disconnect() call, causing the stale Disconnect() to overwrite
the new session's online status.

Fix with a monotonic per-node generation counter:

- State.Connect() increments the counter and returns the current
  generation alongside the change list.
- State.Disconnect() accepts the generation from the caller and
  rejects the call if a newer generation exists, making stale
  disconnects from old sessions a no-op.
- serveLongPoll captures the generation at Connect() time and passes
  it to Disconnect() in the deferred cleanup.
- RemoveNode's return value is now checked: if another session already
  owns the batcher slot (reconnect happened), the old session skips
  the grace period entirely.

Update batcher_test.go to track per-node connect generations and
pass them through to Disconnect(), matching production behavior.

Fixes the following test failures:
- server_state_online_after_reconnect_within_grace
- update_history_no_false_offline
- nodestore_correct_after_rapid_reconnect
- rapid_reconnect_peer_never_sees_offline
This commit is contained in:
Kristoffer Dalby
2026-03-17 14:35:18 +00:00
parent 00c41b6422
commit b09af3846b
3 changed files with 100 additions and 11 deletions

View File

@@ -99,6 +99,12 @@ type State struct {
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
// connectGen tracks a per-node monotonic generation counter so stale
// Disconnect() calls from old poll sessions are rejected. Connect()
// increments the counter and returns the current value; Disconnect()
// only proceeds when the generation it carries matches the latest.
connectGen sync.Map // types.NodeID → *atomic.Uint64
// sshCheckAuth tracks when source nodes last completed SSH check auth.
//
// For rules without explicit checkPeriod (default 12h), auth covers any
@@ -508,7 +514,15 @@ func (s *State) DeleteNode(node types.NodeView) (change.Change, error) {
}
// Connect marks a node as connected and updates its primary routes in the state.
func (s *State) Connect(id types.NodeID) []change.Change {
// It returns the list of changes and a generation number. The generation number
// must be passed to Disconnect() so that stale disconnects from old poll sessions
// are rejected (see the grace period logic in poll.go).
func (s *State) Connect(id types.NodeID) ([]change.Change, uint64) {
// Increment the connect generation for this node. This ensures that any
// in-flight Disconnect() from a previous session will see a stale generation
// and become a no-op.
gen := s.nextConnectGen(id)
// Update online status in NodeStore before creating change notification
// so the NodeStore already reflects the correct state when other nodes
// process the NodeCameOnline change for full map generation.
@@ -517,7 +531,7 @@ func (s *State) Connect(id types.NodeID) []change.Change {
// n.LastSeen = ptr.To(now)
})
if !ok {
return nil
return nil, gen
}
c := []change.Change{change.NodeOnlineFor(node)}
@@ -532,11 +546,53 @@ func (s *State) Connect(id types.NodeID) []change.Change {
c = append(c, change.NodeAdded(id))
}
return c
return c, gen
}
// nextConnectGen atomically increments and returns the connect generation for a node.
//
// The per-node counter is created lazily on first use, so the first generation
// handed out for a node is 1. An existing counter is looked up with a plain Load
// first: LoadOrStore needs a fresh value as its argument, and building one on
// every call would allocate a counter that is thrown away whenever the node
// already has an entry.
//
// It returns 0 if the value stored for id is not a *atomic.Uint64.
func (s *State) nextConnectGen(id types.NodeID) uint64 {
	val, ok := s.connectGen.Load(id)
	if !ok {
		// No counter yet for this node. LoadOrStore returns the already-stored
		// counter if a concurrent caller won the race after the Load above, so
		// all callers still share a single counter per node.
		val, _ = s.connectGen.LoadOrStore(id, new(atomic.Uint64))
	}

	counter, ok := val.(*atomic.Uint64)
	if !ok {
		return 0
	}

	return counter.Add(1)
}
// connectGeneration reports the most recent connect generation recorded for a node.
//
// It returns 0 when no counter has been stored for id, or when the stored
// value is not a *atomic.Uint64.
func (s *State) connectGeneration(id types.NodeID) uint64 {
	if stored, found := s.connectGen.Load(id); found {
		if gen, isCounter := stored.(*atomic.Uint64); isCounter {
			return gen.Load()
		}
	}

	return 0
}
// Disconnect marks a node as disconnected and updates its primary routes in the state.
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
// The gen parameter is the generation returned by Connect() for the session that is
// disconnecting. If Connect() has been called again for this node since then, the
// generations will not match and this call becomes a no-op, so a stale disconnect
// cannot overwrite the online status set by the newer session.
func (s *State) Disconnect(id types.NodeID, gen uint64) ([]change.Change, error) {
// Check if this disconnect is stale. A newer Connect() will have incremented
// the generation, so if ours doesn't match, a newer session owns this node.
if current := s.connectGeneration(id); current != gen {
log.Debug().
Uint64("disconnect_gen", gen).
Uint64("current_gen", current).
Msg("stale disconnect rejected, newer session active")
return nil, nil
}
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
now := time.Now()
n.LastSeen = &now