diff --git a/hscontrol/app.go b/hscontrol/app.go index c86277c5..ed0da82a 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -485,6 +485,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux { if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { r.Get("/oidc/callback", provider.OIDCCallbackHandler) + r.Post("/register/confirm/{auth_id}", provider.RegisterConfirmHandler) } r.Get("/apple", h.AppleConfigMessage) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index b84a965d..7834c8b5 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -346,30 +346,13 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - // If this is a registration flow, then we need to register the node. + // If this is a registration flow, render the confirmation + // interstitial instead of finalising the registration immediately. + // Without an explicit user click, a single GET to + // /register/{auth_id} could silently complete a registration when + // the IdP allows silent SSO. if authInfo.Registration { - newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry) - if err != nil { - if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { - log.Debug().Caller().Str("auth_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed") - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) - - return - } - - httpError(writer, err) - - return - } - - content := renderRegistrationSuccessTemplate(user, newNode) - - writer.Header().Set("Content-Type", "text/html; charset=utf-8") - writer.WriteHeader(http.StatusOK) - - if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr - util.LogErr(err, "Failed to write HTTP response") - } + a.renderRegistrationConfirmInterstitial(writer, req, authInfo.AuthID, user, nodeExpiry) return } @@ -676,6 +659,202 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return user, c, nil } +// registerConfirmCSRFCookie is the cookie name used to bind the +// /register/confirm POST handler's CSRF token to the OIDC callback that +// rendered the interstitial. It includes a per-session prefix derived +// from the auth ID so cookies for unrelated registrations on the same +// browser do not collide. +const registerConfirmCSRFCookie = "headscale_register_confirm" + +// renderRegistrationConfirmInterstitial captures the resolved OIDC +// identity and node expiry into the cached AuthRequest, sets the CSRF +// cookie, and renders the confirmation page that the user must +// explicitly submit before the registration is finalised. +func (a *AuthProviderOIDC) renderRegistrationConfirmInterstitial( + writer http.ResponseWriter, + req *http.Request, + authID types.AuthID, + user *types.User, + nodeExpiry *time.Time, +) { + authReq, ok := a.h.state.GetAuthCacheEntry(authID) + if !ok { + log.Debug().Caller().Str("auth_id", authID.String()).Msg("registration session expired before authorization completed") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + + return + } + + if !authReq.IsRegistration() { + log.Warn().Caller(). + Str("auth_id", authID.String()). + Msg("OIDC callback hit registration path with auth request that is not a node registration") + httpError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for node registration", nil)) + + return + } + + csrf, err := util.GenerateRandomStringURLSafe(32) + if err != nil { + httpError(writer, fmt.Errorf("generating csrf token: %w", err)) + + return + } + + authReq.SetPendingConfirmation(&types.PendingRegistrationConfirmation{ + UserID: user.ID, + NodeExpiry: nodeExpiry, + CSRF: csrf, + }) + + http.SetCookie(writer, &http.Cookie{ + Name: registerConfirmCSRFCookie, + Value: csrf, + Path: "/register/confirm/" + authID.String(), + MaxAge: int(authCacheExpiration.Seconds()), + Secure: req.TLS != nil, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + + regData := authReq.RegistrationData() + + info := templates.RegisterConfirmInfo{ + FormAction: "/register/confirm/" + authID.String(), + CSRFTokenName: registerConfirmCSRFCookie, + CSRFToken: csrf, + User: user.Display(), + Hostname: regData.Hostname, + MachineKey: regData.MachineKey.ShortString(), + } + if regData.Hostinfo != nil { + info.OS = regData.Hostinfo.OS + } + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + if _, err := writer.Write([]byte(templates.RegisterConfirm(info).Render())); err != nil { //nolint:noinlineerr + util.LogErr(err, "Failed to write HTTP response") + } +} + +// RegisterConfirmHandler is the POST endpoint behind the OIDC +// registration confirmation interstitial. It validates the CSRF cookie +// against the form-submitted token, finalises the registration via +// handleRegistration, and renders the success page. +func (a *AuthProviderOIDC) RegisterConfirmHandler( + writer http.ResponseWriter, + req *http.Request, +) { + if req.Method != http.MethodPost { + httpError(writer, errMethodNotAllowed) + + return + } + + authID, err := authIDFromRequest(req) + if err != nil { + httpError(writer, err) + + return + } + + // Cap the form body. The confirmation form is a single CSRF token, + // so 4 KiB is generous and prevents an unauthenticated client from + // submitting an arbitrarily large body to ParseForm. + req.Body = http.MaxBytesReader(writer, req.Body, 4*1024) + + if err := req.ParseForm(); err != nil { //nolint:noinlineerr,gosec // body is bounded above + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid form", err)) + + return + } + + formCSRF := req.PostFormValue(registerConfirmCSRFCookie) //nolint:gosec // body is bounded above + if formCSRF == "" { + httpError(writer, NewHTTPError(http.StatusBadRequest, "missing csrf token", nil)) + + return + } + + cookie, err := req.Cookie(registerConfirmCSRFCookie) + if err != nil { + httpError(writer, NewHTTPError(http.StatusForbidden, "missing csrf cookie", err)) + + return + } + + if cookie.Value != formCSRF { + httpError(writer, NewHTTPError(http.StatusForbidden, "csrf token mismatch", nil)) + + return + } + + authReq, ok := a.h.state.GetAuthCacheEntry(authID) + if !ok { + httpError(writer, NewHTTPError(http.StatusGone, "registration session expired", nil)) + + return + } + + pending := authReq.PendingConfirmation() + if pending == nil { + httpError(writer, NewHTTPError(http.StatusForbidden, "registration not OIDC-authorized", nil)) + + return + } + + if pending.CSRF != cookie.Value { + httpError(writer, NewHTTPError(http.StatusForbidden, "csrf token does not match cached registration", nil)) + + return + } + + user, err := a.h.state.GetUserByID(types.UserID(pending.UserID)) + if err != nil { + httpError(writer, fmt.Errorf("looking up user: %w", err)) + + return + } + + newNode, err := a.handleRegistration(user, authID, pending.NodeExpiry) + if err != nil { + if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { + httpError(writer, NewHTTPError(http.StatusGone, "registration session expired", err)) + + return + } + + httpError(writer, err) + + return + } + + // Clear the CSRF cookie now that the registration is final. + http.SetCookie(writer, &http.Cookie{ + Name: registerConfirmCSRFCookie, + Value: "", + Path: "/register/confirm/" + authID.String(), + MaxAge: -1, + Secure: req.TLS != nil, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }) + + content := renderRegistrationSuccessTemplate(user, newNode) + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + // renderRegistrationSuccessTemplate's output only embeds + // HTML-escaped values from a server-side template, so the gosec + // XSS warning is a false positive here. + if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr,gosec + util.LogErr(err, "Failed to write HTTP response") + } +} + func (a *AuthProviderOIDC) handleRegistration( user *types.User, registrationID types.AuthID, diff --git a/hscontrol/oidc_confirm_test.go b/hscontrol/oidc_confirm_test.go new file mode 100644 index 00000000..74af90ec --- /dev/null +++ b/hscontrol/oidc_confirm_test.go @@ -0,0 +1,102 @@ +package hscontrol + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newConfirmRequest(t *testing.T, authID types.AuthID, formCSRF, cookieCSRF string) *http.Request { + t.Helper() + + form := strings.NewReader(registerConfirmCSRFCookie + "=" + formCSRF) + req := httptest.NewRequestWithContext( + context.Background(), + http.MethodPost, + "/register/confirm/"+authID.String(), + form, + ) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{ + Name: registerConfirmCSRFCookie, + Value: cookieCSRF, + }) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("auth_id", authID.String()) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + return req +} + +// TestRegisterConfirmHandler_RejectsCSRFMismatch verifies that the +// /register/confirm POST handler refuses to finalise a pending +// registration when the form CSRF token does not match the cookie. +func TestRegisterConfirmHandler_RejectsCSRFMismatch(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + provider := &AuthProviderOIDC{h: app} + + // Mint a pending registration with a stashed pending-confirmation, + // as the OIDC callback would have done after resolving the user + // identity but before the user clicked the interstitial form. + authID := types.MustAuthID() + regReq := types.NewRegisterAuthRequest(&types.RegistrationData{ + Hostname: "phish-target", + }) + regReq.SetPendingConfirmation(&types.PendingRegistrationConfirmation{ + UserID: 1, + CSRF: "expected-csrf", + }) + app.state.SetAuthCacheEntry(authID, regReq) + + rec := httptest.NewRecorder() + provider.RegisterConfirmHandler(rec, + newConfirmRequest(t, authID, "wrong-csrf", "expected-csrf"), + ) + + assert.Equal(t, http.StatusForbidden, rec.Code, + "CSRF cookie/form mismatch must be rejected with 403") + + // And the registration must still be pending — the rejected POST + // must not have called handleRegistration. + cached, ok := app.state.GetAuthCacheEntry(authID) + require.True(t, ok, "rejected POST must not evict the cached registration") + require.NotNil(t, cached.PendingConfirmation(), + "rejected POST must not clear the pending confirmation") +} + +// TestRegisterConfirmHandler_RejectsWithoutPending verifies that +// /register/confirm refuses to finalise a registration that did not +// first complete the OIDC interstitial. Without this check an attacker +// who knew an auth_id could POST directly to the confirm endpoint and +// claim the device. +func TestRegisterConfirmHandler_RejectsWithoutPending(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + provider := &AuthProviderOIDC{h: app} + + authID := types.MustAuthID() + // Cached registration with NO pending confirmation set — i.e. the + // OIDC callback has not run yet. + app.state.SetAuthCacheEntry(authID, types.NewRegisterAuthRequest( + &types.RegistrationData{Hostname: "no-oidc-yet"}, + )) + + rec := httptest.NewRecorder() + provider.RegisterConfirmHandler(rec, + newConfirmRequest(t, authID, "fake", "fake"), + ) + + assert.Equal(t, http.StatusForbidden, rec.Code, + "confirm without prior OIDC pending state must be rejected with 403") +} diff --git a/hscontrol/templates/register_confirm.go b/hscontrol/templates/register_confirm.go new file mode 100644 index 00000000..3d483ffd --- /dev/null +++ b/hscontrol/templates/register_confirm.go @@ -0,0 +1,101 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" +) + +// RegisterConfirmInfo carries the human-readable information shown on +// the registration confirmation interstitial that an OIDC-authenticated +// user must explicitly accept before a pending node is registered to +// their identity. The fields here intentionally include enough device +// detail (hostname, OS, machine-key fingerprint) for the user to +// recognise whether the device they are about to claim is in fact +// theirs. +type RegisterConfirmInfo struct { + // FormAction is the absolute or relative URL the confirm form + // POSTs to. Typically /register/confirm/{auth_id}. + FormAction string + + // CSRFTokenName is the name of the hidden form field carrying the + // CSRF token. The corresponding cookie shares this name. + CSRFTokenName string + + // CSRFToken is the per-session token that must match the value of + // the cookie set by the OIDC callback before the POST is honoured. + CSRFToken string + + // User is the OIDC-authenticated identity the device will be + // registered to if the user confirms. + User string + + // Hostname is the hostname the registering tailscaled instance + // reported in its RegisterRequest. + Hostname string + + // OS is the operating system the registering tailscaled reported. + // May be the empty string when the client did not send Hostinfo. + OS string + + // MachineKey is the short fingerprint of the registering machine + // key. The full key is intentionally not shown. + MachineKey string +} + +// RegisterConfirm renders an interstitial page that asks the +// OIDC-authenticated user to explicitly confirm that they want to +// register the named device under their account. Without this +// confirmation step a single GET to /register/{auth_id} could +// silently complete a phishing-style registration when the victim's +// IdP allows silent SSO. +func RegisterConfirm(info RegisterConfirmInfo) *elem.Element { + deviceList := elem.Ul(nil, + elem.Li(nil, elem.Strong(nil, elem.Text("Hostname: ")), elem.Text(info.Hostname)), + elem.Li(nil, elem.Strong(nil, elem.Text("OS: ")), elem.Text(displayOrUnknown(info.OS))), + elem.Li(nil, elem.Strong(nil, elem.Text("Machine key: ")), Code(elem.Text(info.MachineKey))), + elem.Li(nil, elem.Strong(nil, elem.Text("Will be registered to: ")), elem.Text(info.User)), + ) + + form := elem.Form( + attrs.Props{ + attrs.Method: "POST", + attrs.Action: info.FormAction, + }, + elem.Input(attrs.Props{ + attrs.Type: "hidden", + attrs.Name: info.CSRFTokenName, + attrs.Value: info.CSRFToken, + }), + elem.Button( + attrs.Props{attrs.Type: "submit"}, + elem.Text("Confirm registration"), + ), + ) + + return HtmlStructure( + elem.Title(nil, elem.Text("Headscale - Confirm node registration")), + mdTypesetBody( + headscaleLogo(), + H2(elem.Text("Confirm node registration")), + P(elem.Text( + "A device is asking to be added to your tailnet. "+ + "Please review the details below and confirm that this device is yours.", + )), + deviceList, + form, + P(elem.Text( + "If you do not recognise this device, close this window. "+ + "The registration request will expire automatically.", + )), + pageFooter(), + ), + ) +} + +func displayOrUnknown(s string) string { + if s == "" { + return "(unknown)" + } + + return s +} diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 0e17c5a8..b0558d07 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -230,6 +230,22 @@ type SSHCheckBinding struct { DstNodeID NodeID } +// PendingRegistrationConfirmation captures the server-side state needed +// to finalise a node registration after the user has confirmed it on +// the OIDC interstitial. The OIDC callback resolves the user identity +// and node expiry, stores them on the cached AuthRequest, and renders +// a confirmation page; only when the user POSTs the confirmation form +// does the actual node registration run. +// +// CSRF is a one-shot per-session token that the OIDC callback set +// both as a cookie and as a hidden form field. The confirm POST +// handler refuses to proceed unless the cookie and form values match. +type PendingRegistrationConfirmation struct { + UserID uint + NodeExpiry *time.Time + CSRF string +} + // AuthRequest represents a pending authentication request from a user or a // node. It carries the minimum data needed to either complete a node // registration (regData populated) or an SSH check-mode auth (sshBinding @@ -257,6 +273,13 @@ type AuthRequest struct { // safely. sshBinding *SSHCheckBinding + // pendingConfirmation is populated by the OIDC callback for the + // node-registration flow once the user identity has been resolved + // but before the user has explicitly confirmed the registration on + // the interstitial. The /register/confirm POST handler reads it to + // finalise the registration without re-running the OIDC flow. + pendingConfirmation *PendingRegistrationConfirmation + finished chan AuthVerdict closed *atomic.Bool } @@ -331,6 +354,22 @@ func (rn *AuthRequest) IsSSHCheck() bool { return rn.sshBinding != nil } +// SetPendingConfirmation marks this AuthRequest as having an +// OIDC-resolved user that is waiting to confirm the registration on +// the interstitial. The OIDC callback should call this and then render +// the confirmation page; the /register/confirm POST handler reads the +// stored UserID/NodeExpiry to finish the registration. +func (rn *AuthRequest) SetPendingConfirmation(p *PendingRegistrationConfirmation) { + rn.pendingConfirmation = p +} + +// PendingConfirmation returns the pending OIDC-resolved registration +// state captured by SetPendingConfirmation, or nil if no OIDC callback +// has yet resolved an identity for this AuthRequest. +func (rn *AuthRequest) PendingConfirmation() *PendingRegistrationConfirmation { + return rn.pendingConfirmation +} + func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { if rn.closed.Swap(true) { return diff --git a/hscontrol/types/common_test.go b/hscontrol/types/common_test.go index 9ccb0145..54028493 100644 --- a/hscontrol/types/common_test.go +++ b/hscontrol/types/common_test.go @@ -57,6 +57,27 @@ func TestNewAuthRequestEmptyPayload(t *testing.T) { assert.Panics(t, func() { _ = req.SSHCheckBinding() }) } +// TestPendingRegistrationConfirmation verifies that the OIDC callback +// can stash a pending confirmation onto an AuthRequest and that the +// /register/confirm POST handler can read it back unchanged. +func TestPendingRegistrationConfirmation(t *testing.T) { + req := NewRegisterAuthRequest(&RegistrationData{Hostname: "phish-test"}) + + require.Nil(t, req.PendingConfirmation(), + "new AuthRequest must have no pending confirmation") + + pending := &PendingRegistrationConfirmation{ + UserID: 42, + CSRF: "csrf-marker", + } + req.SetPendingConfirmation(pending) + + got := req.PendingConfirmation() + require.NotNil(t, got, "PendingConfirmation must return the stored value") + assert.Equal(t, uint(42), got.UserID) + assert.Equal(t, "csrf-marker", got.CSRF) +} + func TestDefaultBatcherWorkersFor(t *testing.T) { tests := []struct { cpuCount int diff --git a/integration/scenario.go b/integration/scenario.go index d503d174..779a9bff 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -1189,9 +1189,118 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f } } + // The OIDC registration flow now renders a confirmation interstitial + // (POST form) instead of completing immediately. Detect the form and + // auto-submit it so integration tests behave like a real browser. + if followRedirects && strings.Contains(body, `action="/register/confirm/`) { + confirmBody, confirmURL, confirmErr := submitConfirmForm(hostname, body, resp, hc) + if confirmErr != nil { + return body, redirectURL, confirmErr + } + + return confirmBody, confirmURL, nil + } + return body, redirectURL, nil } +// submitConfirmForm parses the OIDC registration confirmation +// interstitial HTML, extracts the form action and CSRF token, and +// POSTs the form using the same HTTP client (which carries the CSRF +// cookie set by the callback). +func submitConfirmForm( + hostname string, + htmlBody string, + prevResp *http.Response, + hc *http.Client, +) (string, *url.URL, error) { + // Extract form action URL. + actionIdx := strings.Index(htmlBody, `action="`) + if actionIdx == -1 { + return "", nil, fmt.Errorf("%s confirm form: no action attribute", hostname) //nolint:err113 + } + + actionStart := actionIdx + len(`action="`) + + actionEnd := strings.Index(htmlBody[actionStart:], `"`) + if actionEnd == -1 { + return "", nil, fmt.Errorf("%s confirm form: unterminated action attribute", hostname) //nolint:err113 + } + + formAction := htmlBody[actionStart : actionStart+actionEnd] + + // Extract hidden CSRF input value. The rendered has + // attributes in name-type-value order so we grab the whole tag. + before, _, ok := strings.Cut(htmlBody, `name="headscale_register_confirm"`) + if !ok { + return "", nil, fmt.Errorf("%s confirm form: no CSRF input", hostname) //nolint:err113 + } + + tagStart := strings.LastIndex(before, "") + if tagEnd == -1 { + return "", nil, fmt.Errorf("%s confirm form: unterminated input tag", hostname) //nolint:err113 + } + + inputTag := htmlBody[tagStart : tagStart+tagEnd+1] + + valIdx := strings.Index(inputTag, `value="`) + if valIdx == -1 { + return "", nil, fmt.Errorf("%s confirm form: no value in CSRF input", hostname) //nolint:err113 + } + + valStart := valIdx + len(`value="`) + valEnd := strings.Index(inputTag[valStart:], `"`) + csrfToken := inputTag[valStart : valStart+valEnd] + + // Build the absolute POST URL from the response's request URL. + base := prevResp.Request.URL + confirmURL := &url.URL{ + Scheme: base.Scheme, + Host: base.Host, + Path: formAction, + } + + log.Printf("%s auto-submitting confirm form: %s", hostname, confirmURL) + + formData := url.Values{ + "headscale_register_confirm": {csrfToken}, + } + + ctx := context.Background() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, confirmURL.String(), strings.NewReader(formData.Encode())) + if err != nil { + return "", nil, fmt.Errorf("%s creating confirm request: %w", hostname, err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + confirmResp, err := hc.Do(req) + if err != nil { + return "", nil, fmt.Errorf("%s sending confirm request: %w", hostname, err) + } + defer confirmResp.Body.Close() + + confirmBytes, err := io.ReadAll(confirmResp.Body) + if err != nil { + return "", nil, fmt.Errorf("%s reading confirm response: %w", hostname, err) + } + + if confirmResp.StatusCode != http.StatusOK { + return string(confirmBytes), nil, fmt.Errorf( //nolint:err113 + "%s confirm returned status %d: %s", + hostname, confirmResp.StatusCode, string(confirmBytes), + ) + } + + return string(confirmBytes), nil, nil +} + var errParseAuthPage = errors.New("parsing auth page") func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {