mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-24 09:38:45 +02:00
hscontrol/poll,state: fix grace period disconnect TOCTOU race
When a node disconnects, serveLongPoll defers a cleanup that starts a grace period goroutine. This goroutine polls batcher.IsConnected() and, if the node has not reconnected within ~10 seconds, calls state.Disconnect() to mark it offline. A TOCTOU race exists: the node can reconnect (calling Connect()) between the IsConnected check and the Disconnect() call, causing the stale Disconnect() to overwrite the new session's online status. Fix with a monotonic per-node generation counter: - State.Connect() increments the counter and returns the current generation alongside the change list. - State.Disconnect() accepts the generation from the caller and rejects the call if a newer generation exists, making stale disconnects from old sessions a no-op. - serveLongPoll captures the generation at Connect() time and passes it to Disconnect() in the deferred cleanup. - RemoveNode's return value is now checked: if another session already owns the batcher slot (reconnect happened), the old session skips the grace period entirely. Update batcher_test.go to track per-node connect generations and pass them through to Disconnect(), matching production behavior. Fixes the following test failures: - server_state_online_after_reconnect_within_grace - update_history_no_false_offline - nodestore_correct_after_rapid_reconnect - rapid_reconnect_peer_never_sees_offline
This commit is contained in:
@@ -99,6 +99,12 @@ type State struct {
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
|
||||
// connectGen tracks a per-node monotonic generation counter so stale
|
||||
// Disconnect() calls from old poll sessions are rejected. Connect()
|
||||
// increments the counter and returns the current value; Disconnect()
|
||||
// only proceeds when the generation it carries matches the latest.
|
||||
connectGen sync.Map // types.NodeID → *atomic.Uint64
|
||||
|
||||
// sshCheckAuth tracks when source nodes last completed SSH check auth.
|
||||
//
|
||||
// For rules without explicit checkPeriod (default 12h), auth covers any
|
||||
@@ -508,7 +514,15 @@ func (s *State) DeleteNode(node types.NodeView) (change.Change, error) {
|
||||
}
|
||||
|
||||
// Connect marks a node as connected and updates its primary routes in the state.
|
||||
func (s *State) Connect(id types.NodeID) []change.Change {
|
||||
// It returns the list of changes and a generation number. The generation number
|
||||
// must be passed to Disconnect() so that stale disconnects from old poll sessions
|
||||
// are rejected (see the grace period logic in poll.go).
|
||||
func (s *State) Connect(id types.NodeID) ([]change.Change, uint64) {
|
||||
// Increment the connect generation for this node. This ensures that any
|
||||
// in-flight Disconnect() from a previous session will see a stale generation
|
||||
// and become a no-op.
|
||||
gen := s.nextConnectGen(id)
|
||||
|
||||
// Update online status in NodeStore before creating change notification
|
||||
// so the NodeStore already reflects the correct state when other nodes
|
||||
// process the NodeCameOnline change for full map generation.
|
||||
@@ -517,7 +531,7 @@ func (s *State) Connect(id types.NodeID) []change.Change {
|
||||
// n.LastSeen = ptr.To(now)
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, gen
|
||||
}
|
||||
|
||||
c := []change.Change{change.NodeOnlineFor(node)}
|
||||
@@ -532,11 +546,53 @@ func (s *State) Connect(id types.NodeID) []change.Change {
|
||||
c = append(c, change.NodeAdded(id))
|
||||
}
|
||||
|
||||
return c
|
||||
return c, gen
|
||||
}
|
||||
|
||||
// nextConnectGen atomically increments and returns the connect generation for a node.
|
||||
func (s *State) nextConnectGen(id types.NodeID) uint64 {
|
||||
val, _ := s.connectGen.LoadOrStore(id, &atomic.Uint64{})
|
||||
|
||||
counter, ok := val.(*atomic.Uint64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
return counter.Add(1)
|
||||
}
|
||||
|
||||
// connectGeneration returns the current connect generation for a node.
|
||||
func (s *State) connectGeneration(id types.NodeID) uint64 {
|
||||
val, ok := s.connectGen.Load(id)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
counter, ok := val.(*atomic.Uint64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
return counter.Load()
|
||||
}
|
||||
|
||||
// Disconnect marks a node as disconnected and updates its primary routes in the state.
|
||||
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
|
||||
// The gen parameter is the generation returned by Connect(). If a newer Connect() has
|
||||
// been called since the session that is disconnecting, the generation will not match
|
||||
// and this call becomes a no-op, preventing stale disconnects from overwriting the
|
||||
// online status set by a newer session.
|
||||
func (s *State) Disconnect(id types.NodeID, gen uint64) ([]change.Change, error) {
|
||||
// Check if this disconnect is stale. A newer Connect() will have incremented
|
||||
// the generation, so if ours doesn't match, a newer session owns this node.
|
||||
if current := s.connectGeneration(id); current != gen {
|
||||
log.Debug().
|
||||
Uint64("disconnect_gen", gen).
|
||||
Uint64("current_gen", current).
|
||||
Msg("stale disconnect rejected, newer session active")
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
now := time.Now()
|
||||
n.LastSeen = &now
|
||||
|
||||
Reference in New Issue
Block a user