mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-23 00:58:43 +02:00
auth: generalise auth flow and introduce AuthVerdict
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
This commit is contained in:
@@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/templates"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@@ -26,8 +25,8 @@ import (
|
||||
const (
|
||||
randomByteSize = 16
|
||||
defaultOAuthOptionsCount = 3
|
||||
registerCacheExpiration = time.Minute * 15
|
||||
registerCacheCleanup = time.Minute * 20
|
||||
authCacheExpiration = time.Minute * 15
|
||||
authCacheCleanup = time.Minute * 20
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -44,17 +43,21 @@ var (
|
||||
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
|
||||
)
|
||||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
RegistrationID types.RegistrationID
|
||||
Verifier *string
|
||||
// AuthInfo contains both auth ID and verifier information for OIDC validation.
|
||||
type AuthInfo struct {
|
||||
AuthID types.AuthID
|
||||
Verifier *string
|
||||
Registration bool
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
|
||||
// authCache holds auth information between
|
||||
// the auth and the callback steps.
|
||||
authCache *zcache.Cache[string, AuthInfo]
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
@@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
|
||||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
registrationCache := zcache.New[string, RegistrationInfo](
|
||||
registerCacheExpiration,
|
||||
registerCacheCleanup,
|
||||
authCache := zcache.New[string, AuthInfo](
|
||||
authCacheExpiration,
|
||||
authCacheCleanup,
|
||||
)
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
registrationCache: registrationCache,
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
authCache: authCache,
|
||||
|
||||
oidcProvider: oidcProvider,
|
||||
oauth2Config: oauth2Config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
|
||||
func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/auth/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
authID.String())
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) AuthHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
a.authHandler(writer, req, false)
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/register/%s",
|
||||
strings.TrimSuffix(a.serverURL, "/"),
|
||||
registrationID.String())
|
||||
authID.String())
|
||||
}
|
||||
|
||||
// RegisterHandler registers the OIDC callback handler with the given router.
|
||||
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
|
||||
// Listens in /register/:registration_id.
|
||||
// Listens in /register/:auth_id.
|
||||
func (a *AuthProviderOIDC) RegisterHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
registrationIdStr := vars["registration_id"]
|
||||
a.authHandler(writer, req, true)
|
||||
}
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
|
||||
// authHandler takes an incoming request that needs to be authenticated and
|
||||
// validates and prepares it for the OIDC flow.
|
||||
func (a *AuthProviderOIDC) authHandler(
|
||||
writer http.ResponseWriter,
|
||||
req *http.Request,
|
||||
registration bool,
|
||||
) {
|
||||
authID, err := authIDFromRequest(req)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
|
||||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize registration info with machine key
|
||||
registrationInfo := RegistrationInfo{
|
||||
RegistrationID: registrationId,
|
||||
registrationInfo := AuthInfo{
|
||||
AuthID: authID,
|
||||
Registration: registration,
|
||||
}
|
||||
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
|
||||
@@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
a.registrationCache.Set(state, registrationInfo)
|
||||
a.authCache.Set(state, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||
@@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// If the node exists, then the node should be reauthenticated,
|
||||
// if the node does not exist, and the machine key exists, then
|
||||
// this is a new node that should be registered.
|
||||
registrationId := a.getRegistrationIDFromState(state)
|
||||
authInfo := a.getAuthInfoFromState(state)
|
||||
if authInfo == nil {
|
||||
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
|
||||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// If this is a registration flow, then we need to register the node.
|
||||
if authInfo.Registration {
|
||||
verb := "Reauthenticated"
|
||||
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed")
|
||||
log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
|
||||
|
||||
return
|
||||
@@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
return
|
||||
}
|
||||
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
// TODO(kradalby): handle login flow (without registration) if needed.
|
||||
// We need to send an update here to whatever might be waiting for this auth flow.
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
|
||||
@@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
||||
var exchangeOpts []oauth2.AuthCodeOption
|
||||
|
||||
if a.cfg.PKCE.Enabled {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
regInfo, ok := a.authCache.Get(state)
|
||||
if !ok {
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
@@ -507,14 +533,14 @@ func doOIDCAuthorization(
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRegistrationIDFromState retrieves the registration ID from the state.
|
||||
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
|
||||
regInfo, ok := a.registrationCache.Get(state)
|
||||
// getAuthInfoFromState retrieves the registration ID from the state.
|
||||
func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
|
||||
authInfo, ok := a.authCache.Get(state)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ®Info.RegistrationID
|
||||
return &authInfo
|
||||
}
|
||||
|
||||
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
@@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
|
||||
func (a *AuthProviderOIDC) handleRegistration(
|
||||
user *types.User,
|
||||
registrationID types.RegistrationID,
|
||||
registrationID types.AuthID,
|
||||
expiry time.Time,
|
||||
) (bool, error) {
|
||||
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
||||
|
||||
Reference in New Issue
Block a user