mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-19 16:21:23 +01:00
hscontrol/poll,state: fix grace period disconnect TOCTOU race
When a node disconnects, serveLongPoll defers a cleanup that starts a grace period goroutine. This goroutine polls batcher.IsConnected() and, if the node has not reconnected within ~10 seconds, calls state.Disconnect() to mark it offline.

A time-of-check-to-time-of-use (TOCTOU) race exists: the node can reconnect (calling Connect()) between the IsConnected check and the Disconnect() call, causing the stale Disconnect() to overwrite the new session's online status.

Fix with a monotonic per-node generation counter:
- State.Connect() increments the counter and returns the current generation alongside the change list.
- State.Disconnect() accepts the generation from the caller and rejects the call if a newer generation exists, making stale disconnects from old sessions a no-op.
- serveLongPoll captures the generation at Connect() time and passes it to Disconnect() in the deferred cleanup.
- RemoveNode's return value is now checked: if another session already owns the batcher slot (a reconnect happened), the old session skips the grace period entirely.

Update batcher_test.go to track per-node connect generations and pass them through to Disconnect(), matching production behavior.

Fixes the following test failures:
- server_state_online_after_reconnect_within_grace
- update_history_no_false_offline
- nodestore_correct_after_rapid_reconnect
- rapid_reconnect_peer_never_sees_offline
This commit is contained in:
@@ -39,14 +39,20 @@ type testBatcherWrapper struct {
|
||||
*Batcher
|
||||
|
||||
state *state.State
|
||||
|
||||
// connectGens tracks per-node connect generations so RemoveNode can pass
|
||||
// the correct generation to State.Disconnect(), matching production behavior.
|
||||
connectGens sync.Map // types.NodeID → uint64
|
||||
}
|
||||
|
||||
func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, stop func()) error {
|
||||
// Mark node as online in state before AddNode to match production behavior
|
||||
// This ensures the NodeStore has correct online status for change processing
|
||||
if t.state != nil {
|
||||
// Use Connect to properly mark node online in NodeStore but don't send its changes
|
||||
_ = t.state.Connect(id)
|
||||
// Use Connect to properly mark node online in NodeStore and track the
|
||||
// generation so RemoveNode can pass it to Disconnect().
|
||||
_, gen := t.state.Connect(id)
|
||||
t.connectGens.Store(id, gen)
|
||||
}
|
||||
|
||||
// First add the node to the real batcher
|
||||
@@ -71,8 +77,15 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
||||
// Mark node as offline in state BEFORE removing from batcher
|
||||
// This ensures the NodeStore has correct offline status when the change is processed
|
||||
if t.state != nil {
|
||||
// Use Disconnect to properly mark node offline in NodeStore but don't send its changes
|
||||
_, _ = t.state.Disconnect(id)
|
||||
var gen uint64
|
||||
|
||||
if v, ok := t.connectGens.LoadAndDelete(id); ok {
|
||||
if g, ok := v.(uint64); ok {
|
||||
gen = g
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = t.state.Disconnect(id, gen)
|
||||
}
|
||||
|
||||
// Send the offline notification that poll.go would normally send
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -147,11 +148,24 @@ func (m *mapSession) serveLongPoll() {
|
||||
|
||||
m.log.Trace().Caller().Msg("long poll session started")
|
||||
|
||||
// connectGen is set by Connect() below and captured by the deferred cleanup closure.
|
||||
// It allows Disconnect() to reject stale calls from old sessions — if a newer session
|
||||
// has called Connect() (incrementing the generation), the old session's Disconnect()
|
||||
// sees a mismatched generation and becomes a no-op.
|
||||
var connectGen uint64
|
||||
|
||||
// Clean up the session when the client disconnects
|
||||
defer func() {
|
||||
m.stopFromBatcher()
|
||||
|
||||
_ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
|
||||
stillConnected := m.h.mapBatcher.RemoveNode(m.node.ID, m.ch)
|
||||
|
||||
// If another session already exists for this node (reconnect
|
||||
// happened before this cleanup ran), skip the grace period
|
||||
// entirely — the node is not actually disconnecting.
|
||||
if stillConnected {
|
||||
return
|
||||
}
|
||||
|
||||
// When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather).
|
||||
// Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects.
|
||||
@@ -176,7 +190,11 @@ func (m *mapSession) serveLongPoll() {
|
||||
}
|
||||
|
||||
if disconnected {
|
||||
disconnectChanges, err := m.h.state.Disconnect(m.node.ID)
|
||||
// Pass the generation from our Connect() call. If a newer session has
|
||||
// connected since (bumping the generation), Disconnect() will detect
|
||||
// the mismatch and skip the state update, preventing the race where
|
||||
// an old grace period goroutine overwrites a newer session's online status.
|
||||
disconnectChanges, err := m.h.state.Disconnect(m.node.ID, connectGen)
|
||||
if err != nil {
|
||||
m.log.Error().Caller().Err(err).Msg("failed to disconnect node")
|
||||
}
|
||||
@@ -215,7 +233,9 @@ func (m *mapSession) serveLongPoll() {
|
||||
// 2. Connect: marks the node online and recalculates primary routes based on the updated state
|
||||
// While this results in two notifications, it ensures route data is synchronized before
|
||||
// primary route selection occurs, which is critical for proper HA subnet router failover.
|
||||
connectChanges := m.h.state.Connect(m.node.ID)
|
||||
var connectChanges []change.Change
|
||||
|
||||
connectChanges, connectGen = m.h.state.Connect(m.node.ID)
|
||||
|
||||
m.log.Info().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("node has connected")
|
||||
|
||||
|
||||
@@ -99,6 +99,12 @@ type State struct {
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
|
||||
// connectGen tracks a per-node monotonic generation counter so stale
|
||||
// Disconnect() calls from old poll sessions are rejected. Connect()
|
||||
// increments the counter and returns the current value; Disconnect()
|
||||
// only proceeds when the generation it carries matches the latest.
|
||||
connectGen sync.Map // types.NodeID → *atomic.Uint64
|
||||
|
||||
// sshCheckAuth tracks when source nodes last completed SSH check auth.
|
||||
//
|
||||
// For rules without explicit checkPeriod (default 12h), auth covers any
|
||||
@@ -508,7 +514,15 @@ func (s *State) DeleteNode(node types.NodeView) (change.Change, error) {
|
||||
}
|
||||
|
||||
// Connect marks a node as connected and updates its primary routes in the state.
|
||||
func (s *State) Connect(id types.NodeID) []change.Change {
|
||||
// It returns the list of changes and a generation number. The generation number
|
||||
// must be passed to Disconnect() so that stale disconnects from old poll sessions
|
||||
// are rejected (see the grace period logic in poll.go).
|
||||
func (s *State) Connect(id types.NodeID) ([]change.Change, uint64) {
|
||||
// Increment the connect generation for this node. This ensures that any
|
||||
// in-flight Disconnect() from a previous session will see a stale generation
|
||||
// and become a no-op.
|
||||
gen := s.nextConnectGen(id)
|
||||
|
||||
// Update online status in NodeStore before creating change notification
|
||||
// so the NodeStore already reflects the correct state when other nodes
|
||||
// process the NodeCameOnline change for full map generation.
|
||||
@@ -517,7 +531,7 @@ func (s *State) Connect(id types.NodeID) []change.Change {
|
||||
// n.LastSeen = ptr.To(now)
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, gen
|
||||
}
|
||||
|
||||
c := []change.Change{change.NodeOnlineFor(node)}
|
||||
@@ -532,11 +546,53 @@ func (s *State) Connect(id types.NodeID) []change.Change {
|
||||
c = append(c, change.NodeAdded(id))
|
||||
}
|
||||
|
||||
return c
|
||||
return c, gen
|
||||
}
|
||||
|
||||
// nextConnectGen atomically increments and returns the connect generation for a node.
|
||||
func (s *State) nextConnectGen(id types.NodeID) uint64 {
|
||||
val, _ := s.connectGen.LoadOrStore(id, &atomic.Uint64{})
|
||||
|
||||
counter, ok := val.(*atomic.Uint64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
return counter.Add(1)
|
||||
}
|
||||
|
||||
// connectGeneration returns the current connect generation for a node.
|
||||
func (s *State) connectGeneration(id types.NodeID) uint64 {
|
||||
val, ok := s.connectGen.Load(id)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
counter, ok := val.(*atomic.Uint64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
return counter.Load()
|
||||
}
|
||||
|
||||
// Disconnect marks a node as disconnected and updates its primary routes in the state.
|
||||
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
|
||||
// The gen parameter is the generation returned by Connect(). If a newer Connect() has
|
||||
// been called since the session that is disconnecting, the generation will not match
|
||||
// and this call becomes a no-op, preventing stale disconnects from overwriting the
|
||||
// online status set by a newer session.
|
||||
func (s *State) Disconnect(id types.NodeID, gen uint64) ([]change.Change, error) {
|
||||
// Check if this disconnect is stale. A newer Connect() will have incremented
|
||||
// the generation, so if ours doesn't match, a newer session owns this node.
|
||||
if current := s.connectGeneration(id); current != gen {
|
||||
log.Debug().
|
||||
Uint64("disconnect_gen", gen).
|
||||
Uint64("current_gen", current).
|
||||
Msg("stale disconnect rejected, newer session active")
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
now := time.Now()
|
||||
n.LastSeen = &now
|
||||
|
||||
Reference in New Issue
Block a user