auth: generalise auth flow and introduce AuthVerdict

Generalise the registration pipeline to a more general auth pipeline
supporting both node registrations and SSH check auth requests.
Rename RegistrationID to AuthID, unexport AuthRequest fields, and
introduce AuthVerdict to unify the auth finish API.

Add the urlParam generic helper for extracting typed URL parameters
from chi routes, used by the new auth request handler.

Updates #1850
This commit is contained in:
Kristoffer Dalby
2026-02-24 18:48:57 +00:00
parent 30338441c1
commit cb3b6949ea
19 changed files with 443 additions and 336 deletions

View File

@@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
// ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique")
// ErrRegistrationExpired is returned when a registration has expired.
var ErrRegistrationExpired = errors.New("registration expired")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@@ -82,8 +85,10 @@ type State struct {
derpMap atomic.Pointer[tailcfg.DERPMap]
// polMan handles policy evaluation and management
polMan policy.PolicyManager
// registrationCache caches node registration data to reduce database load
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
// primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes
}
@@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
}
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
authCache := zcache.New[types.AuthID, types.AuthRequest](
cacheExpiration,
cacheCleanup,
)
registrationCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) {
rn.SendAndClose(nil)
authCache.OnEvicted(
func(id types.AuthID, rn types.AuthRequest) {
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
},
)
db, err := hsdb.NewHeadscaleDatabase(
cfg,
registrationCache,
authCache,
)
if err != nil {
return nil, fmt.Errorf("initializing database: %w", err)
@@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
registrationCache: registrationCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
db: db,
ipAlloc: ipAlloc,
polMan: polMan,
authCache: authCache,
primaryRoutes: routes.New(),
nodeStore: nodeStore,
}, nil
}
@@ -1057,9 +1062,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id)
}
// GetRegistrationCacheEntry retrieves a node registration from cache.
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
entry, found := s.registrationCache.Get(id)
// GetAuthCacheEntry retrieves a node registration from cache.
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
entry, found := s.authCache.Get(id)
if !found {
return nil, false
}
@@ -1067,26 +1072,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
return &entry, true
}
// SetRegistrationCacheEntry stores a node registration in cache.
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
s.registrationCache.Set(id, entry)
// SetAuthCacheEntry stores a node registration in cache.
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
s.authCache.Set(id, entry)
}
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
if hostinfo == nil {
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
if !nv.Hostinfo().Valid() {
log.Warn().
Caller().
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
EmbedObject(nv).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had nil hostinfo, generated default hostname")
} else if hostinfo.Hostname == "" {
} else if nv.Hostinfo().Hostname() == "" {
log.Warn().
Caller().
Str(zf.MachineKey, machineKey).
Str(zf.NodeKey, nodeKey).
EmbedObject(nv).
Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname).
Msg("Registration had empty hostname, generated default")
@@ -1128,7 +1131,7 @@ type authNodeUpdateParams struct {
// Node to update; must be valid and in NodeStore.
ExistingNode types.NodeView
// Client data: keys, hostinfo, endpoints.
RegEntry *types.RegisterNode
RegEntry *types.AuthRequest
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
ValidHostinfo *tailcfg.Hostinfo
// Hostname from hostinfo, or generated from keys if client omits it.
@@ -1147,6 +1150,7 @@ type authNodeUpdateParams struct {
// an existing node. It updates the node in NodeStore, processes RequestTags, and
// persists changes to the database.
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
regNv := params.RegEntry.Node()
// Log the operation type
if params.IsConvertFromTag {
log.Info().
@@ -1155,16 +1159,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
Msg("Converting tagged node to user-owned node")
} else {
log.Info().
EmbedObject(params.ExistingNode).
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
Object("existing", params.ExistingNode).
Object("incoming", regNv).
Msg("Updating existing node registration via reauth")
}
// Process RequestTags during reauth (#2979)
// Due to json:",omitempty", we treat empty/nil as "clear tags"
var requestTags []string
if params.RegEntry.Node.Hostinfo != nil {
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
if regNv.Hostinfo().Valid() {
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
}
oldTags := params.ExistingNode.Tags().AsSlice()
@@ -1182,8 +1186,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
// Update existing node in NodeStore - validation passed, safe to mutate
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
node.NodeKey = params.RegEntry.Node.NodeKey
node.DiscoKey = params.RegEntry.Node.DiscoKey
node.NodeKey = regNv.NodeKey()
node.DiscoKey = regNv.DiscoKey()
node.Hostname = params.Hostname
// Preserve NetInfo from existing node when re-registering
@@ -1194,7 +1198,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
params.ValidHostinfo,
)
node.Endpoints = params.RegEntry.Node.Endpoints
node.Endpoints = regNv.Endpoints().AsSlice()
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
@@ -1203,7 +1207,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.IsConvertFromTag {
node.RegisterMethod = params.RegisterMethod
} else {
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
node.RegisterMethod = regNv.RegisterMethod()
}
// Track tagged status BEFORE processing tags
@@ -1223,7 +1227,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
case !wasTagged && isTagged:
// Personal → Tagged: clear expiry (tagged nodes don't expire)
@@ -1233,14 +1237,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
case !isTagged:
// Personal → Personal: update expiry from client
if params.Expiry != nil {
node.Expiry = params.Expiry
} else {
node.Expiry = params.RegEntry.Node.Expiry
node.Expiry = regNv.Expiry().Clone()
}
}
// Tagged → Tagged: keep existing expiry (nil) - no action needed
@@ -1527,13 +1531,13 @@ func (s *State) processReauthTags(
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
func (s *State) HandleNodeFromAuthPath(
registrationID types.RegistrationID,
authID types.AuthID,
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (types.NodeView, change.Change, error) {
// Get the registration entry from cache
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
regEntry, ok := s.GetAuthCacheEntry(authID)
if !ok {
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
}
@@ -1546,25 +1550,27 @@ func (s *State) HandleNodeFromAuthPath(
// Ensure we have a valid hostname from the registration cache entry
hostname := util.EnsureHostname(
regEntry.Node.Hostinfo,
regEntry.Node.MachineKey.String(),
regEntry.Node.NodeKey.String(),
regEntry.Node().Hostinfo(),
regEntry.Node().MachineKey().String(),
regEntry.Node().NodeKey().String(),
)
// Ensure we have valid hostinfo
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
hostinfo := &tailcfg.Hostinfo{}
if regEntry.Node().Hostinfo().Valid() {
hostinfo = regEntry.Node().Hostinfo().AsStruct()
}
hostinfo.Hostname = hostname
logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
regEntry.Node(),
user.Name,
hostname,
regEntry.Node.Hostinfo,
)
// Lookup existing nodes
machineKey := regEntry.Node.MachineKey
machineKey := regEntry.Node().MachineKey()
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
@@ -1578,7 +1584,7 @@ func (s *State) HandleNodeFromAuthPath(
// Create logger with common fields for all auth operations
logger := log.With().
Str(zf.RegistrationID, registrationID.String()).
Str(zf.RegistrationID, authID.String()).
Str(zf.UserName, user.Name).
Str(zf.MachineKey, machineKey.ShortString()).
Str(zf.Method, registrationMethod).
@@ -1587,7 +1593,7 @@ func (s *State) HandleNodeFromAuthPath(
// Common params for update operations
updateParams := authNodeUpdateParams{
RegEntry: regEntry,
ValidHostinfo: validHostinfo,
ValidHostinfo: hostinfo,
Hostname: hostname,
User: user,
Expiry: expiry,
@@ -1621,7 +1627,7 @@ func (s *State) HandleNodeFromAuthPath(
Msg("Creating new node for different user (same machine key exists for another user)")
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo,
logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, existingNodeAnyUser,
)
if err != nil {
@@ -1629,7 +1635,7 @@ func (s *State) HandleNodeFromAuthPath(
}
} else {
finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo,
logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, types.NodeView{},
)
if err != nil {
@@ -1638,10 +1644,10 @@ func (s *State) HandleNodeFromAuthPath(
}
// Signal to waiting clients
regEntry.SendAndClose(finalNode.AsStruct())
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
// Delete from registration cache
s.registrationCache.Delete(registrationID)
s.authCache.Delete(authID)
// Update policy managers
usersChange, err := s.updatePolicyManagerUsers()
@@ -1670,7 +1676,7 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) createNewNodeFromAuth(
logger zerolog.Logger,
user *types.User,
regEntry *types.RegisterNode,
regEntry *types.AuthRequest,
hostname string,
validHostinfo *tailcfg.Hostinfo,
expiry *time.Time,
@@ -1683,13 +1689,13 @@ func (s *State) createNewNodeFromAuth(
return s.createAndSaveNewNode(newNodeParams{
User: *user,
MachineKey: regEntry.Node.MachineKey,
NodeKey: regEntry.Node.NodeKey,
DiscoKey: regEntry.Node.DiscoKey,
MachineKey: regEntry.Node().MachineKey(),
NodeKey: regEntry.Node().NodeKey(),
DiscoKey: regEntry.Node().DiscoKey(),
Hostname: hostname,
Hostinfo: validHostinfo,
Endpoints: regEntry.Node.Endpoints,
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
Endpoints: regEntry.Node().Endpoints().AsSlice(),
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
RegisterMethod: registrationMethod,
ExistingNodeForNetinfo: existingNodeForNetinfo,
})
@@ -1784,7 +1790,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Ensure we have a valid hostname - handle nil/empty cases
hostname := util.EnsureHostname(
regReq.Hostinfo,
regReq.Hostinfo.View(),
machineKey.String(),
regReq.NodeKey.String(),
)
@@ -1793,14 +1799,6 @@ func (s *State) HandleNodeFromPreAuthKey(
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname
logHostinfoValidation(
machineKey.ShortString(),
regReq.NodeKey.ShortString(),
pakUsername(),
hostname,
regReq.Hostinfo,
)
log.Debug().
Caller().
Str(zf.NodeName, hostname).