mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-20 15:51:40 +02:00
auth: generalise auth flow and introduce AuthVerdict
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
This commit is contained in:
@@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
|
||||
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
||||
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||
|
||||
// ErrRegistrationExpired is returned when a registration has expired.
|
||||
var ErrRegistrationExpired = errors.New("registration expired")
|
||||
|
||||
// State manages Headscale's core state, coordinating between database, policy management,
|
||||
// IP allocation, and DERP routing. All methods are thread-safe.
|
||||
type State struct {
|
||||
@@ -82,8 +85,10 @@ type State struct {
|
||||
derpMap atomic.Pointer[tailcfg.DERPMap]
|
||||
// polMan handles policy evaluation and management
|
||||
polMan policy.PolicyManager
|
||||
// registrationCache caches node registration data to reduce database load
|
||||
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
||||
|
||||
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
|
||||
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
}
|
||||
@@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
|
||||
authCache := zcache.New[types.AuthID, types.AuthRequest](
|
||||
cacheExpiration,
|
||||
cacheCleanup,
|
||||
)
|
||||
|
||||
registrationCache.OnEvicted(
|
||||
func(id types.RegistrationID, rn types.RegisterNode) {
|
||||
rn.SendAndClose(nil)
|
||||
authCache.OnEvicted(
|
||||
func(id types.AuthID, rn types.AuthRequest) {
|
||||
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||
},
|
||||
)
|
||||
|
||||
db, err := hsdb.NewHeadscaleDatabase(
|
||||
cfg,
|
||||
registrationCache,
|
||||
authCache,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing database: %w", err)
|
||||
@@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
return &State{
|
||||
cfg: cfg,
|
||||
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
polMan: polMan,
|
||||
registrationCache: registrationCache,
|
||||
primaryRoutes: routes.New(),
|
||||
nodeStore: nodeStore,
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
polMan: polMan,
|
||||
authCache: authCache,
|
||||
primaryRoutes: routes.New(),
|
||||
nodeStore: nodeStore,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1057,9 +1062,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
|
||||
return s.db.DeletePreAuthKey(id)
|
||||
}
|
||||
|
||||
// GetRegistrationCacheEntry retrieves a node registration from cache.
|
||||
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) {
|
||||
entry, found := s.registrationCache.Get(id)
|
||||
// GetAuthCacheEntry retrieves a node registration from cache.
|
||||
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
||||
entry, found := s.authCache.Get(id)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
@@ -1067,26 +1072,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
|
||||
return &entry, true
|
||||
}
|
||||
|
||||
// SetRegistrationCacheEntry stores a node registration in cache.
|
||||
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) {
|
||||
s.registrationCache.Set(id, entry)
|
||||
// SetAuthCacheEntry stores a node registration in cache.
|
||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
||||
s.authCache.Set(id, entry)
|
||||
}
|
||||
|
||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
||||
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) {
|
||||
if hostinfo == nil {
|
||||
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
||||
if !nv.Hostinfo().Valid() {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Str(zf.MachineKey, machineKey).
|
||||
Str(zf.NodeKey, nodeKey).
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had nil hostinfo, generated default hostname")
|
||||
} else if hostinfo.Hostname == "" {
|
||||
} else if nv.Hostinfo().Hostname() == "" {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Str(zf.MachineKey, machineKey).
|
||||
Str(zf.NodeKey, nodeKey).
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had empty hostname, generated default")
|
||||
@@ -1128,7 +1131,7 @@ type authNodeUpdateParams struct {
|
||||
// Node to update; must be valid and in NodeStore.
|
||||
ExistingNode types.NodeView
|
||||
// Client data: keys, hostinfo, endpoints.
|
||||
RegEntry *types.RegisterNode
|
||||
RegEntry *types.AuthRequest
|
||||
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
||||
ValidHostinfo *tailcfg.Hostinfo
|
||||
// Hostname from hostinfo, or generated from keys if client omits it.
|
||||
@@ -1147,6 +1150,7 @@ type authNodeUpdateParams struct {
|
||||
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
||||
// persists changes to the database.
|
||||
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
||||
regNv := params.RegEntry.Node()
|
||||
// Log the operation type
|
||||
if params.IsConvertFromTag {
|
||||
log.Info().
|
||||
@@ -1155,16 +1159,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
Msg("Converting tagged node to user-owned node")
|
||||
} else {
|
||||
log.Info().
|
||||
EmbedObject(params.ExistingNode).
|
||||
Interface("hostinfo", params.RegEntry.Node.Hostinfo).
|
||||
Object("existing", params.ExistingNode).
|
||||
Object("incoming", regNv).
|
||||
Msg("Updating existing node registration via reauth")
|
||||
}
|
||||
|
||||
// Process RequestTags during reauth (#2979)
|
||||
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
||||
var requestTags []string
|
||||
if params.RegEntry.Node.Hostinfo != nil {
|
||||
requestTags = params.RegEntry.Node.Hostinfo.RequestTags
|
||||
if regNv.Hostinfo().Valid() {
|
||||
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
|
||||
}
|
||||
|
||||
oldTags := params.ExistingNode.Tags().AsSlice()
|
||||
@@ -1182,8 +1186,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
|
||||
// Update existing node in NodeStore - validation passed, safe to mutate
|
||||
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
||||
node.NodeKey = params.RegEntry.Node.NodeKey
|
||||
node.DiscoKey = params.RegEntry.Node.DiscoKey
|
||||
node.NodeKey = regNv.NodeKey()
|
||||
node.DiscoKey = regNv.DiscoKey()
|
||||
node.Hostname = params.Hostname
|
||||
|
||||
// Preserve NetInfo from existing node when re-registering
|
||||
@@ -1194,7 +1198,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
params.ValidHostinfo,
|
||||
)
|
||||
|
||||
node.Endpoints = params.RegEntry.Node.Endpoints
|
||||
node.Endpoints = regNv.Endpoints().AsSlice()
|
||||
node.IsOnline = new(false)
|
||||
node.LastSeen = new(time.Now())
|
||||
|
||||
@@ -1203,7 +1207,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.IsConvertFromTag {
|
||||
node.RegisterMethod = params.RegisterMethod
|
||||
} else {
|
||||
node.RegisterMethod = params.RegEntry.Node.RegisterMethod
|
||||
node.RegisterMethod = regNv.RegisterMethod()
|
||||
}
|
||||
|
||||
// Track tagged status BEFORE processing tags
|
||||
@@ -1223,7 +1227,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = params.RegEntry.Node.Expiry
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
}
|
||||
case !wasTagged && isTagged:
|
||||
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
||||
@@ -1233,14 +1237,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = params.RegEntry.Node.Expiry
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
}
|
||||
case !isTagged:
|
||||
// Personal → Personal: update expiry from client
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = params.RegEntry.Node.Expiry
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
}
|
||||
}
|
||||
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
||||
@@ -1527,13 +1531,13 @@ func (s *State) processReauthTags(
|
||||
|
||||
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
|
||||
func (s *State) HandleNodeFromAuthPath(
|
||||
registrationID types.RegistrationID,
|
||||
authID types.AuthID,
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (types.NodeView, change.Change, error) {
|
||||
// Get the registration entry from cache
|
||||
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
|
||||
regEntry, ok := s.GetAuthCacheEntry(authID)
|
||||
if !ok {
|
||||
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
|
||||
}
|
||||
@@ -1546,25 +1550,27 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
|
||||
// Ensure we have a valid hostname from the registration cache entry
|
||||
hostname := util.EnsureHostname(
|
||||
regEntry.Node.Hostinfo,
|
||||
regEntry.Node.MachineKey.String(),
|
||||
regEntry.Node.NodeKey.String(),
|
||||
regEntry.Node().Hostinfo(),
|
||||
regEntry.Node().MachineKey().String(),
|
||||
regEntry.Node().NodeKey().String(),
|
||||
)
|
||||
|
||||
// Ensure we have valid hostinfo
|
||||
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{})
|
||||
validHostinfo.Hostname = hostname
|
||||
hostinfo := &tailcfg.Hostinfo{}
|
||||
if regEntry.Node().Hostinfo().Valid() {
|
||||
hostinfo = regEntry.Node().Hostinfo().AsStruct()
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
logHostinfoValidation(
|
||||
regEntry.Node.MachineKey.ShortString(),
|
||||
regEntry.Node.NodeKey.String(),
|
||||
regEntry.Node(),
|
||||
user.Name,
|
||||
hostname,
|
||||
regEntry.Node.Hostinfo,
|
||||
)
|
||||
|
||||
// Lookup existing nodes
|
||||
machineKey := regEntry.Node.MachineKey
|
||||
machineKey := regEntry.Node().MachineKey()
|
||||
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
||||
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
||||
|
||||
@@ -1578,7 +1584,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
|
||||
// Create logger with common fields for all auth operations
|
||||
logger := log.With().
|
||||
Str(zf.RegistrationID, registrationID.String()).
|
||||
Str(zf.RegistrationID, authID.String()).
|
||||
Str(zf.UserName, user.Name).
|
||||
Str(zf.MachineKey, machineKey.ShortString()).
|
||||
Str(zf.Method, registrationMethod).
|
||||
@@ -1587,7 +1593,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
// Common params for update operations
|
||||
updateParams := authNodeUpdateParams{
|
||||
RegEntry: regEntry,
|
||||
ValidHostinfo: validHostinfo,
|
||||
ValidHostinfo: hostinfo,
|
||||
Hostname: hostname,
|
||||
User: user,
|
||||
Expiry: expiry,
|
||||
@@ -1621,7 +1627,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
Msg("Creating new node for different user (same machine key exists for another user)")
|
||||
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, validHostinfo,
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
expiry, registrationMethod, existingNodeAnyUser,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1629,7 +1635,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
} else {
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, validHostinfo,
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
expiry, registrationMethod, types.NodeView{},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1638,10 +1644,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
|
||||
// Signal to waiting clients
|
||||
regEntry.SendAndClose(finalNode.AsStruct())
|
||||
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
|
||||
|
||||
// Delete from registration cache
|
||||
s.registrationCache.Delete(registrationID)
|
||||
s.authCache.Delete(authID)
|
||||
|
||||
// Update policy managers
|
||||
usersChange, err := s.updatePolicyManagerUsers()
|
||||
@@ -1670,7 +1676,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) createNewNodeFromAuth(
|
||||
logger zerolog.Logger,
|
||||
user *types.User,
|
||||
regEntry *types.RegisterNode,
|
||||
regEntry *types.AuthRequest,
|
||||
hostname string,
|
||||
validHostinfo *tailcfg.Hostinfo,
|
||||
expiry *time.Time,
|
||||
@@ -1683,13 +1689,13 @@ func (s *State) createNewNodeFromAuth(
|
||||
|
||||
return s.createAndSaveNewNode(newNodeParams{
|
||||
User: *user,
|
||||
MachineKey: regEntry.Node.MachineKey,
|
||||
NodeKey: regEntry.Node.NodeKey,
|
||||
DiscoKey: regEntry.Node.DiscoKey,
|
||||
MachineKey: regEntry.Node().MachineKey(),
|
||||
NodeKey: regEntry.Node().NodeKey(),
|
||||
DiscoKey: regEntry.Node().DiscoKey(),
|
||||
Hostname: hostname,
|
||||
Hostinfo: validHostinfo,
|
||||
Endpoints: regEntry.Node.Endpoints,
|
||||
Expiry: cmp.Or(expiry, regEntry.Node.Expiry),
|
||||
Endpoints: regEntry.Node().Endpoints().AsSlice(),
|
||||
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
|
||||
RegisterMethod: registrationMethod,
|
||||
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
||||
})
|
||||
@@ -1784,7 +1790,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
|
||||
// Ensure we have a valid hostname - handle nil/empty cases
|
||||
hostname := util.EnsureHostname(
|
||||
regReq.Hostinfo,
|
||||
regReq.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
regReq.NodeKey.String(),
|
||||
)
|
||||
@@ -1793,14 +1799,6 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
|
||||
validHostinfo.Hostname = hostname
|
||||
|
||||
logHostinfoValidation(
|
||||
machineKey.ShortString(),
|
||||
regReq.NodeKey.ShortString(),
|
||||
pakUsername(),
|
||||
hostname,
|
||||
regReq.Hostinfo,
|
||||
)
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str(zf.NodeName, hostname).
|
||||
|
||||
Reference in New Issue
Block a user