auth: generalise auth flow and introduce AuthVerdict

Generalise the registration pipeline to a more general auth pipeline
supporting both node registrations and SSH check auth requests.
Rename RegistrationID to AuthID, unexport AuthRequest fields, and
introduce AuthVerdict to unify the auth finish API.

Add the urlParam generic helper for extracting typed URL parameters
from chi routes, used by the new auth request handler.

Updates #1850
This commit is contained in:
Kristoffer Dalby
2026-02-24 18:48:57 +00:00
parent 30338441c1
commit cb3b6949ea
19 changed files with 443 additions and 336 deletions

View File

@@ -12,7 +12,6 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
@@ -26,8 +25,8 @@ import (
const (
randomByteSize = 16
defaultOAuthOptionsCount = 3
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
authCacheExpiration = time.Minute * 15
authCacheCleanup = time.Minute * 20
)
var (
@@ -44,17 +43,21 @@ var (
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
)
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
RegistrationID types.RegistrationID
Verifier *string
// AuthInfo contains both auth ID and verifier information for OIDC validation.
type AuthInfo struct {
AuthID types.AuthID
Verifier *string
Registration bool
}
type AuthProviderOIDC struct {
h *Headscale
serverURL string
cfg *types.OIDCConfig
registrationCache *zcache.Cache[string, RegistrationInfo]
h *Headscale
serverURL string
cfg *types.OIDCConfig
// authCache holds auth information between
// the auth and the callback steps.
authCache *zcache.Cache[string, AuthInfo]
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope,
}
registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration,
registerCacheCleanup,
authCache := zcache.New[string, AuthInfo](
authCacheExpiration,
authCacheCleanup,
)
return &AuthProviderOIDC{
h: h,
serverURL: serverURL,
cfg: cfg,
registrationCache: registrationCache,
h: h,
serverURL: serverURL,
cfg: cfg,
authCache: authCache,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
}, nil
}
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderOIDC) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
a.authHandler(writer, req, false)
}
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
registrationID.String())
authID.String())
}
// RegisterHandler registers the OIDC callback handler with the given router.
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
// Listens in /register/:registration_id.
// Listens in /register/:auth_id.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr := vars["registration_id"]
a.authHandler(writer, req, true)
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
// authHandler takes an incoming request that needs to be authenticated and
// validates and prepares it for the OIDC flow.
func (a *AuthProviderOIDC) authHandler(
writer http.ResponseWriter,
req *http.Request,
registration bool,
) {
authID, err := authIDFromRequest(req)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
httpError(writer, err)
return
}
@@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
return
}
// Initialize registration info with machine key
registrationInfo := RegistrationInfo{
RegistrationID: registrationId,
registrationInfo := AuthInfo{
AuthID: authID,
Registration: registration,
}
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info
a.registrationCache.Set(state, registrationInfo)
a.authCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
@@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then
// this is a new node that should be registered.
registrationId := a.getRegistrationIDFromState(state)
authInfo := a.getAuthInfoFromState(state)
if authInfo == nil {
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// Register the node if it does not exist.
if registrationId != nil {
return
}
// If this is a registration flow, then we need to register the node.
if authInfo.Registration {
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return
@@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// TODO(kradalby): handle login flow (without registration) if needed.
// We need to send an update here to whatever might be waiting for this auth flow.
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
regInfo, ok := a.authCache.Get(state)
if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
@@ -507,14 +533,14 @@ func doOIDCAuthorization(
return nil
}
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
// getAuthInfoFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
authInfo, ok := a.authCache.Get(state)
if !ok {
return nil
}
return &regInfo.RegistrationID
return &authInfo
}
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
@@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.RegistrationID,
registrationID types.AuthID,
expiry time.Time,
) (bool, error) {
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(