mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-24 17:48:49 +02:00
state: replace zcache with bounded LRU for auth cache
Replace zcache with golang-lru/v2/expirable for both the state auth cache and the OIDC state cache. Add tuning.register_cache_max_entries (default 1024) to cap the number of pending registration entries. Introduce types.RegistrationData to replace caching a full *Node; only the fields the registration callback path reads are retained. Remove the dead HSDatabase.regCache field. Drop zgo.at/zcache/v2 from go.mod.
This commit is contained in:
64
hscontrol/state/auth_cache_test.go
Normal file
64
hscontrol/state/auth_cache_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAuthCacheBoundedLRU verifies that the registration auth cache is
|
||||
// bounded by a maximum entry count, that exceeding the maxEntries evicts the
|
||||
// oldest entry, and that the eviction callback resolves the parked
|
||||
// AuthRequest with ErrRegistrationExpired so any waiting goroutine wakes.
|
||||
func TestAuthCacheBoundedLRU(t *testing.T) {
|
||||
const maxEntries = 4
|
||||
|
||||
cache := expirable.NewLRU[types.AuthID, *types.AuthRequest](
|
||||
maxEntries,
|
||||
func(_ types.AuthID, rn *types.AuthRequest) {
|
||||
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||
},
|
||||
time.Hour, // long TTL — we test eviction by size, not by time
|
||||
)
|
||||
|
||||
entries := make([]*types.AuthRequest, 0, maxEntries+1)
|
||||
ids := make([]types.AuthID, 0, maxEntries+1)
|
||||
|
||||
for range maxEntries + 1 {
|
||||
id := types.MustAuthID()
|
||||
entry := types.NewAuthRequest()
|
||||
cache.Add(id, entry)
|
||||
ids = append(ids, id)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
// Cap should be respected.
|
||||
assert.Equal(t, maxEntries, cache.Len(), "cache must not exceed the configured maxEntries")
|
||||
|
||||
// The oldest entry must have been evicted.
|
||||
_, ok := cache.Get(ids[0])
|
||||
assert.False(t, ok, "oldest entry must be evicted when maxEntries is exceeded")
|
||||
|
||||
// The eviction callback must have woken the parked AuthRequest.
|
||||
select {
|
||||
case verdict := <-entries[0].WaitForAuth():
|
||||
require.False(t, verdict.Accept(), "evicted entry must not signal Accept")
|
||||
require.ErrorIs(t,
|
||||
verdict.Err, ErrRegistrationExpired,
|
||||
"evicted entry must surface ErrRegistrationExpired, got: %v",
|
||||
verdict.Err,
|
||||
)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("eviction callback did not wake the parked AuthRequest")
|
||||
}
|
||||
|
||||
// All non-evicted entries must still be retrievable.
|
||||
for i := 1; i <= maxEntries; i++ {
|
||||
_, ok := cache.Get(ids[i])
|
||||
assert.True(t, ok, "non-evicted entry %d should still be in the cache", i)
|
||||
}
|
||||
}
|
||||
@@ -211,15 +211,13 @@ func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy {
|
||||
|
||||
// DebugRegistrationCache returns debug information about the registration cache.
|
||||
func (s *State) DebugRegistrationCache() map[string]any {
|
||||
// The cache doesn't expose internal statistics, so we provide basic info
|
||||
result := map[string]any{
|
||||
"type": "zcache",
|
||||
"expiration": registerCacheExpiration.String(),
|
||||
"cleanup": registerCacheCleanup.String(),
|
||||
"status": "active",
|
||||
return map[string]any{
|
||||
"type": "expirable-lru",
|
||||
"expiration": registerCacheExpiration.String(),
|
||||
"max_entries": defaultRegisterCacheMaxEntries,
|
||||
"current_len": s.authCache.Len(),
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// DebugConfig returns debug information about the current configuration.
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
hsdb "github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
@@ -30,15 +31,18 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
zcache "zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// registerCacheExpiration defines how long node registration entries remain in cache.
|
||||
registerCacheExpiration = time.Minute * 15
|
||||
|
||||
// registerCacheCleanup defines the interval for cleaning up expired cache entries.
|
||||
registerCacheCleanup = time.Minute * 20
|
||||
// defaultRegisterCacheMaxEntries is the default upper bound on the number
|
||||
// of pending registration entries the auth cache will hold. With a 15-minute
|
||||
// TTL and a stripped-down RegistrationData payload (~200 bytes per entry),
|
||||
// 1024 entries cap the worst-case cache footprint at well under 1 MiB even
|
||||
// under sustained unauthenticated cache-fill attempts.
|
||||
defaultRegisterCacheMaxEntries = 1024
|
||||
|
||||
// defaultNodeStoreBatchSize is the default number of write operations to batch
|
||||
// before rebuilding the in-memory node snapshot.
|
||||
@@ -126,8 +130,12 @@ type State struct {
|
||||
// polMan handles policy evaluation and management
|
||||
polMan policy.PolicyManager
|
||||
|
||||
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
|
||||
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||
// authCache holds any pending authentication requests from either auth
|
||||
// type (Web and OIDC). It is a bounded LRU keyed by AuthID; oldest
|
||||
// entries are evicted once the size cap is reached, and entries that
|
||||
// are evicted or time out have their auth verdict resolved with ErrRegistrationExpired
|
||||
// via the eviction callback so any waiting goroutines wake.
|
||||
authCache *expirable.LRU[types.AuthID, *types.AuthRequest]
|
||||
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
@@ -166,26 +174,20 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
|
||||
}
|
||||
|
||||
cacheCleanup := registerCacheCleanup
|
||||
if cfg.Tuning.RegisterCacheCleanup != 0 {
|
||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||
cacheMaxEntries := defaultRegisterCacheMaxEntries
|
||||
if cfg.Tuning.RegisterCacheMaxEntries > 0 {
|
||||
cacheMaxEntries = cfg.Tuning.RegisterCacheMaxEntries
|
||||
}
|
||||
|
||||
authCache := zcache.New[types.AuthID, types.AuthRequest](
|
||||
cacheExpiration,
|
||||
cacheCleanup,
|
||||
)
|
||||
|
||||
authCache.OnEvicted(
|
||||
func(id types.AuthID, rn types.AuthRequest) {
|
||||
authCache := expirable.NewLRU[types.AuthID, *types.AuthRequest](
|
||||
cacheMaxEntries,
|
||||
func(id types.AuthID, rn *types.AuthRequest) {
|
||||
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||
},
|
||||
cacheExpiration,
|
||||
)
|
||||
|
||||
db, err := hsdb.NewHeadscaleDatabase(
|
||||
cfg,
|
||||
authCache,
|
||||
)
|
||||
db, err := hsdb.NewHeadscaleDatabase(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing database: %w", err)
|
||||
}
|
||||
@@ -1252,19 +1254,14 @@ func (s *State) DeletePreAuthKey(id uint64) error {
|
||||
return s.db.DeletePreAuthKey(id)
|
||||
}
|
||||
|
||||
// GetAuthCacheEntry retrieves a node registration from cache.
|
||||
// GetAuthCacheEntry retrieves a pending auth request from the cache.
|
||||
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
||||
entry, found := s.authCache.Get(id)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &entry, true
|
||||
return s.authCache.Get(id)
|
||||
}
|
||||
|
||||
// SetAuthCacheEntry stores a node registration in cache.
|
||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
||||
s.authCache.Set(id, entry)
|
||||
// SetAuthCacheEntry stores a pending auth request in the cache.
|
||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry *types.AuthRequest) {
|
||||
s.authCache.Add(id, entry)
|
||||
}
|
||||
|
||||
// SetLastSSHAuth records a successful SSH check authentication
|
||||
@@ -1296,25 +1293,6 @@ func (s *State) ClearSSHCheckAuth() {
|
||||
s.sshCheckAuth = make(map[sshCheckPair]time.Time)
|
||||
}
|
||||
|
||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
||||
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
||||
if !nv.Hostinfo().Valid() {
|
||||
log.Warn().
|
||||
Caller().
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had nil hostinfo, generated default hostname")
|
||||
} else if nv.Hostinfo().Hostname() == "" {
|
||||
log.Warn().
|
||||
Caller().
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had empty hostname, generated default")
|
||||
}
|
||||
}
|
||||
|
||||
// preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity.
|
||||
// If no existing node is provided, it creates new netinfo from the provided hostinfo.
|
||||
func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo {
|
||||
@@ -1349,15 +1327,15 @@ type newNodeParams struct {
|
||||
type authNodeUpdateParams struct {
|
||||
// Node to update; must be valid and in NodeStore.
|
||||
ExistingNode types.NodeView
|
||||
// Client data: keys, hostinfo, endpoints.
|
||||
RegEntry *types.AuthRequest
|
||||
// Cached registration payload from the originating client request.
|
||||
RegData *types.RegistrationData
|
||||
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
||||
ValidHostinfo *tailcfg.Hostinfo
|
||||
// Hostname from hostinfo, or generated from keys if client omits it.
|
||||
Hostname string
|
||||
// Auth user; may differ from ExistingNode.User() on conversion.
|
||||
User *types.User
|
||||
// Overrides RegEntry.Node.Expiry; ignored for tagged nodes.
|
||||
// Overrides RegData.Expiry; ignored for tagged nodes.
|
||||
Expiry *time.Time
|
||||
// Only used when IsConvertFromTag=true.
|
||||
RegisterMethod string
|
||||
@@ -1369,7 +1347,7 @@ type authNodeUpdateParams struct {
|
||||
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
||||
// persists changes to the database.
|
||||
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
||||
regNv := params.RegEntry.Node()
|
||||
regData := params.RegData
|
||||
// Log the operation type
|
||||
if params.IsConvertFromTag {
|
||||
log.Info().
|
||||
@@ -1379,15 +1357,17 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
} else {
|
||||
log.Info().
|
||||
Object("existing", params.ExistingNode).
|
||||
Object("incoming", regNv).
|
||||
Str("incoming.hostname", regData.Hostname).
|
||||
Str("incoming.machine_key", regData.MachineKey.ShortString()).
|
||||
Msg("Updating existing node registration via reauth")
|
||||
}
|
||||
|
||||
// Process RequestTags during reauth (#2979)
|
||||
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
||||
// Process RequestTags during reauth (#2979).
|
||||
// Due to json:",omitempty", empty/nil from the cached Hostinfo
|
||||
// means "clear tags".
|
||||
var requestTags []string
|
||||
if regNv.Hostinfo().Valid() {
|
||||
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
|
||||
if regData.Hostinfo != nil {
|
||||
requestTags = regData.Hostinfo.RequestTags
|
||||
}
|
||||
|
||||
oldTags := params.ExistingNode.Tags().AsSlice()
|
||||
@@ -1405,8 +1385,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
|
||||
// Update existing node in NodeStore - validation passed, safe to mutate
|
||||
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
||||
node.NodeKey = regNv.NodeKey()
|
||||
node.DiscoKey = regNv.DiscoKey()
|
||||
node.NodeKey = regData.NodeKey
|
||||
node.DiscoKey = regData.DiscoKey
|
||||
node.Hostname = params.Hostname
|
||||
|
||||
// Preserve NetInfo from existing node when re-registering
|
||||
@@ -1417,7 +1397,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
params.ValidHostinfo,
|
||||
)
|
||||
|
||||
node.Endpoints = regNv.Endpoints().AsSlice()
|
||||
node.Endpoints = regData.Endpoints
|
||||
// Do NOT reset IsOnline here. Online status is managed exclusively by
|
||||
// Connect()/Disconnect() in the poll session lifecycle. Resetting it
|
||||
// during re-registration causes a false offline blip: the change
|
||||
@@ -1425,12 +1405,12 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
// to peers, even though Connect() will immediately set it back to true.
|
||||
node.LastSeen = new(time.Now())
|
||||
|
||||
// Set RegisterMethod - for conversion this is the new method,
|
||||
// for reauth we preserve the existing one from regEntry
|
||||
// On conversion (tagged → user) we set the new register method.
|
||||
// On plain reauth we preserve the existing node.RegisterMethod;
|
||||
// the cached RegistrationData no longer carries it because the
|
||||
// producer never populated it.
|
||||
if params.IsConvertFromTag {
|
||||
node.RegisterMethod = params.RegisterMethod
|
||||
} else {
|
||||
node.RegisterMethod = regNv.RegisterMethod()
|
||||
}
|
||||
|
||||
// Track tagged status BEFORE processing tags
|
||||
@@ -1450,7 +1430,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
node.Expiry = regData.Expiry
|
||||
}
|
||||
case !wasTagged && isTagged:
|
||||
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
||||
@@ -1460,14 +1440,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
node.Expiry = regData.Expiry
|
||||
}
|
||||
case !isTagged:
|
||||
// Personal → Personal: update expiry from client
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
node.Expiry = regData.Expiry
|
||||
}
|
||||
}
|
||||
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
||||
@@ -1795,29 +1775,20 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("finding user: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname from the registration cache entry
|
||||
hostname := util.EnsureHostname(
|
||||
regEntry.Node().Hostinfo(),
|
||||
regEntry.Node().MachineKey().String(),
|
||||
regEntry.Node().NodeKey().String(),
|
||||
)
|
||||
regData := regEntry.RegistrationData()
|
||||
|
||||
// Ensure we have valid hostinfo
|
||||
// Hostname was already validated/normalised at producer time. Build
|
||||
// the initial Hostinfo from the cached client-supplied Hostinfo (or
|
||||
// an empty stub if the client did not send one).
|
||||
hostname := regData.Hostname
|
||||
hostinfo := &tailcfg.Hostinfo{}
|
||||
if regEntry.Node().Hostinfo().Valid() {
|
||||
hostinfo = regEntry.Node().Hostinfo().AsStruct()
|
||||
if regData.Hostinfo != nil {
|
||||
hostinfo = regData.Hostinfo.Clone()
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
logHostinfoValidation(
|
||||
regEntry.Node(),
|
||||
user.Name,
|
||||
hostname,
|
||||
)
|
||||
|
||||
// Lookup existing nodes
|
||||
machineKey := regEntry.Node().MachineKey()
|
||||
machineKey := regData.MachineKey
|
||||
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
||||
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
||||
|
||||
@@ -1839,7 +1810,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
|
||||
// Common params for update operations
|
||||
updateParams := authNodeUpdateParams{
|
||||
RegEntry: regEntry,
|
||||
RegData: regData,
|
||||
ValidHostinfo: hostinfo,
|
||||
Hostname: hostname,
|
||||
User: user,
|
||||
@@ -1874,7 +1845,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
Msg("Creating new node for different user (same machine key exists for another user)")
|
||||
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
logger, user, regData, hostname, hostinfo,
|
||||
expiry, registrationMethod, existingNodeAnyUser,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1882,7 +1853,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
} else {
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
logger, user, regData, hostname, hostinfo,
|
||||
expiry, registrationMethod, types.NodeView{},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1893,8 +1864,8 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
// Signal to waiting clients
|
||||
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
|
||||
|
||||
// Delete from registration cache
|
||||
s.authCache.Delete(authID)
|
||||
// Remove from registration cache
|
||||
s.authCache.Remove(authID)
|
||||
|
||||
// Update policy managers
|
||||
usersChange, err := s.updatePolicyManagerUsers()
|
||||
@@ -1923,7 +1894,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) createNewNodeFromAuth(
|
||||
logger zerolog.Logger,
|
||||
user *types.User,
|
||||
regEntry *types.AuthRequest,
|
||||
regData *types.RegistrationData,
|
||||
hostname string,
|
||||
validHostinfo *tailcfg.Hostinfo,
|
||||
expiry *time.Time,
|
||||
@@ -1936,13 +1907,13 @@ func (s *State) createNewNodeFromAuth(
|
||||
|
||||
return s.createAndSaveNewNode(newNodeParams{
|
||||
User: *user,
|
||||
MachineKey: regEntry.Node().MachineKey(),
|
||||
NodeKey: regEntry.Node().NodeKey(),
|
||||
DiscoKey: regEntry.Node().DiscoKey(),
|
||||
MachineKey: regData.MachineKey,
|
||||
NodeKey: regData.NodeKey,
|
||||
DiscoKey: regData.DiscoKey,
|
||||
Hostname: hostname,
|
||||
Hostinfo: validHostinfo,
|
||||
Endpoints: regEntry.Node().Endpoints().AsSlice(),
|
||||
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
|
||||
Endpoints: regData.Endpoints,
|
||||
Expiry: cmp.Or(expiry, regData.Expiry),
|
||||
RegisterMethod: registrationMethod,
|
||||
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user