From b09af3846b9362dd99f210d77e3afad38b6ef5a3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 17 Mar 2026 14:35:18 +0000 Subject: [PATCH] hscontrol/poll,state: fix grace period disconnect TOCTOU race When a node disconnects, serveLongPoll defers a cleanup that starts a grace period goroutine. This goroutine polls batcher.IsConnected() and, if the node has not reconnected within ~10 seconds, calls state.Disconnect() to mark it offline. A TOCTOU race exists: the node can reconnect (calling Connect()) between the IsConnected check and the Disconnect() call, causing the stale Disconnect() to overwrite the new session's online status. Fix with a monotonic per-node generation counter: - State.Connect() increments the counter and returns the current generation alongside the change list. - State.Disconnect() accepts the generation from the caller and rejects the call if a newer generation exists, making stale disconnects from old sessions a no-op. - serveLongPoll captures the generation at Connect() time and passes it to Disconnect() in the deferred cleanup. - RemoveNode's return value is now checked: if another session already owns the batcher slot (reconnect happened), the old session skips the grace period entirely. Update batcher_test.go to track per-node connect generations and pass them through to Disconnect(), matching production behavior. Fixes the following test failures: - server_state_online_after_reconnect_within_grace - update_history_no_false_offline - nodestore_correct_after_rapid_reconnect - rapid_reconnect_peer_never_sees_offline --- hscontrol/mapper/batcher_test.go | 21 +++++++++-- hscontrol/poll.go | 26 +++++++++++-- hscontrol/state/state.go | 64 ++++++++++++++++++++++++++++++-- 3 files changed, 100 insertions(+), 11 deletions(-) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 13d18bb2..1b89c214 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -39,14 +39,20 @@ type testBatcherWrapper struct { *Batcher state *state.State + + // connectGens tracks per-node connect generations so RemoveNode can pass + // the correct generation to State.Disconnect(), matching production behavior. + connectGens sync.Map // types.NodeID → uint64 } func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error { // Mark node as online in state before AddNode to match production behavior // This ensures the NodeStore has correct online status for change processing if t.state != nil { - // Use Connect to properly mark node online in NodeStore but don't send its changes - _ = t.state.Connect(id) + // Use Connect to properly mark node online in NodeStore and track the + // generation so RemoveNode can pass it to Disconnect(). + _, gen := t.state.Connect(id) + t.connectGens.Store(id, gen) } // First add the node to the real batcher @@ -71,8 +77,15 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe // Mark node as offline in state BEFORE removing from batcher // This ensures the NodeStore has correct offline status when the change is processed if t.state != nil { - // Use Disconnect to properly mark node offline in NodeStore but don't send its changes - _, _ = t.state.Disconnect(id) + var gen uint64 + + if v, ok := t.connectGens.LoadAndDelete(id); ok { + if g, ok := v.(uint64); ok { + gen = g + } + } + + _, _ = t.state.Disconnect(id, gen) } // Send the offline notification that poll.go would normally send diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 3179eb78..7dd4051f 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -11,6 +11,7 @@ import ( "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util/zlog/zf" "github.com/rs/zerolog" @@ -147,11 +148,24 @@ func (m *mapSession) serveLongPoll() { m.log.Trace().Caller().Msg("long poll session started") + // connectGen is set by Connect() below and captured by the deferred cleanup closure. + // It allows Disconnect() to reject stale calls from old sessions — if a newer session + // has called Connect() (incrementing the generation), the old session's Disconnect() + // sees a mismatched generation and becomes a no-op. + var connectGen uint64 + // Clean up the session when the client disconnects defer func() { m.stopFromBatcher() - _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) + stillConnected := m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) + + // If another session already exists for this node (reconnect + // happened before this cleanup ran), skip the grace period + // entirely — the node is not actually disconnecting. + if stillConnected { + return + } // When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather). // Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects. @@ -176,7 +190,11 @@ func (m *mapSession) serveLongPoll() { } if disconnected { - disconnectChanges, err := m.h.state.Disconnect(m.node.ID) + // Pass the generation from our Connect() call. If a newer session has + // connected since (bumping the generation), Disconnect() will detect + // the mismatch and skip the state update, preventing the race where + // an old grace period goroutine overwrites a newer session's online status. + disconnectChanges, err := m.h.state.Disconnect(m.node.ID, connectGen) if err != nil { m.log.Error().Caller().Err(err).Msg("failed to disconnect node") } @@ -215,7 +233,9 @@ func (m *mapSession) serveLongPoll() { // 2. Connect: marks the node online and recalculates primary routes based on the updated state // While this results in two notifications, it ensures route data is synchronized before // primary route selection occurs, which is critical for proper HA subnet router failover. - connectChanges := m.h.state.Connect(m.node.ID) + var connectChanges []change.Change + + connectChanges, connectGen = m.h.state.Connect(m.node.ID) m.log.Info().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("node has connected") diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index c5c917fa..39589228 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -99,6 +99,12 @@ type State struct { // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes + // connectGen tracks a per-node monotonic generation counter so stale + // Disconnect() calls from old poll sessions are rejected. Connect() + // increments the counter and returns the current value; Disconnect() + // only proceeds when the generation it carries matches the latest. + connectGen sync.Map // types.NodeID → *atomic.Uint64 + // sshCheckAuth tracks when source nodes last completed SSH check auth. // // For rules without explicit checkPeriod (default 12h), auth covers any @@ -508,7 +514,15 @@ func (s *State) DeleteNode(node types.NodeView) (change.Change, error) { } // Connect marks a node as connected and updates its primary routes in the state. -func (s *State) Connect(id types.NodeID) []change.Change { +// It returns the list of changes and a generation number. The generation number +// must be passed to Disconnect() so that stale disconnects from old poll sessions +// are rejected (see the grace period logic in poll.go). +func (s *State) Connect(id types.NodeID) ([]change.Change, uint64) { + // Increment the connect generation for this node. This ensures that any + // in-flight Disconnect() from a previous session will see a stale generation + // and become a no-op. + gen := s.nextConnectGen(id) + // Update online status in NodeStore before creating change notification // so the NodeStore already reflects the correct state when other nodes // process the NodeCameOnline change for full map generation. @@ -517,7 +531,7 @@ func (s *State) Connect(id types.NodeID) []change.Change { // n.LastSeen = ptr.To(now) }) if !ok { - return nil + return nil, gen } c := []change.Change{change.NodeOnlineFor(node)} @@ -532,11 +546,53 @@ func (s *State) Connect(id types.NodeID) []change.Change { c = append(c, change.NodeAdded(id)) } - return c + return c, gen +} + +// nextConnectGen atomically increments and returns the connect generation for a node. +func (s *State) nextConnectGen(id types.NodeID) uint64 { + val, _ := s.connectGen.LoadOrStore(id, &atomic.Uint64{}) + + counter, ok := val.(*atomic.Uint64) + if !ok { + return 0 + } + + return counter.Add(1) +} + +// connectGeneration returns the current connect generation for a node. +func (s *State) connectGeneration(id types.NodeID) uint64 { + val, ok := s.connectGen.Load(id) + if !ok { + return 0 + } + + counter, ok := val.(*atomic.Uint64) + if !ok { + return 0 + } + + return counter.Load() } // Disconnect marks a node as disconnected and updates its primary routes in the state. -func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { +// The gen parameter is the generation returned by Connect(). If a newer Connect() has +// been called since the session that is disconnecting, the generation will not match +// and this call becomes a no-op, preventing stale disconnects from overwriting the +// online status set by a newer session. +func (s *State) Disconnect(id types.NodeID, gen uint64) ([]change.Change, error) { + // Check if this disconnect is stale. A newer Connect() will have incremented + // the generation, so if ours doesn't match, a newer session owns this node. + if current := s.connectGeneration(id); current != gen { + log.Debug(). + Uint64("disconnect_gen", gen). + Uint64("current_gen", current). + Msg("stale disconnect rejected, newer session active") + + return nil, nil + } + node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { now := time.Now() n.LastSeen = &now