diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 358b8abd..88107903 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -44,6 +44,54 @@ func httpError(w http.ResponseWriter, err error) { } } +// httpUserError logs an error and sends a styled HTML error page. +// Use this for browser-facing error paths (OIDC, registration confirm) +// where the user should see a branded page instead of plain text. +// Technical details go to the server log; the HTML page only shows +// an actionable message derived from the HTTP status code. +func httpUserError(w http.ResponseWriter, err error) { + code := http.StatusInternalServerError + + if herr, ok := errors.AsType[HTTPError](err); ok { + if herr.Code != 0 { + code = herr.Code + } + + log.Error().Err(herr.Err).Int("code", code).Msgf("user msg: %s", herr.Msg) + } else { + log.Error().Err(err).Int("code", code).Msg("http internal server error") + } + + userMsg := userMessageForStatusCode(code) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(code) + + page := templates.AuthError(templates.AuthErrorResult{ + Title: "Headscale - Error", + Heading: http.StatusText(code), + Message: userMsg, + }) + + _, werr := w.Write([]byte(page.Render())) + if werr != nil { + log.Error().Err(werr).Msg("failed to write HTML error response") + } +} + +func userMessageForStatusCode(code int) string { + switch { + case code == http.StatusUnauthorized || code == http.StatusForbidden: + return "You are not authorized. Please contact your administrator." + case code == http.StatusGone: + return "Your session has expired. Please try again." + case code >= 400 && code < 500: + return "The request could not be processed. Please try again." + default: + return "Something went wrong. Please try again later." + } +} + // HTTPError represents an error that is surfaced to the user via web. type HTTPError struct { Code int // HTTP response code to send to client; 0 means 500 diff --git a/hscontrol/handlers_test.go b/hscontrol/handlers_test.go index 1058681e..12ab7f96 100644 --- a/hscontrol/handlers_test.go +++ b/hscontrol/handlers_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/assert" ) +var errTestUnexpected = errors.New("unexpected failure") + // TestHandleVerifyRequest_OversizedBodyRejected verifies that the // /verify handler refuses POST bodies larger than verifyBodyLimit. // The MaxBytesReader is applied in VerifyHandler, so we simulate @@ -55,3 +57,73 @@ func errorAsHTTPError(err error) (HTTPError, bool) { return HTTPError{}, false } + +func TestHttpUserError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantCode int + wantContains string + wantNotContain string + }{ + { + name: "forbidden_renders_authorization_message", + err: NewHTTPError(http.StatusForbidden, "csrf token mismatch", nil), + wantCode: http.StatusForbidden, + wantContains: "You are not authorized. Please contact your administrator.", + wantNotContain: "csrf token mismatch", + }, + { + name: "unauthorized_renders_authorization_message", + err: NewHTTPError(http.StatusUnauthorized, "unauthorised domain", nil), + wantCode: http.StatusUnauthorized, + wantContains: "You are not authorized. Please contact your administrator.", + wantNotContain: "unauthorised domain", + }, + { + name: "gone_renders_session_expired", + err: NewHTTPError(http.StatusGone, "login session expired, try again", nil), + wantCode: http.StatusGone, + wantContains: "Your session has expired. Please try again.", + wantNotContain: "login session expired", + }, + { + name: "bad_request_renders_generic_retry", + err: NewHTTPError(http.StatusBadRequest, "state not found", nil), + wantCode: http.StatusBadRequest, + wantContains: "The request could not be processed. Please try again.", + wantNotContain: "state not found", + }, + { + name: "plain_error_renders_500", + err: errTestUnexpected, + wantCode: http.StatusInternalServerError, + wantContains: "Something went wrong. Please try again later.", + }, + { + name: "html_structure_present", + err: NewHTTPError(http.StatusGone, "session expired", nil), + wantCode: http.StatusGone, + wantContains: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + rec := httptest.NewRecorder() + httpUserError(rec, tt.err) + + assert.Equal(t, tt.wantCode, rec.Code) + assert.Contains(t, rec.Header().Get("Content-Type"), "text/html") + assert.Contains(t, rec.Body.String(), tt.wantContains) + + if tt.wantNotContain != "" { + assert.NotContains(t, rec.Body.String(), tt.wantNotContain) + } + }) + } +} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 7834c8b5..5b595333 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -155,21 +155,21 @@ func (a *AuthProviderOIDC) authHandler( ) { authID, err := authIDFromRequest(req) if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } // Set the state and nonce cookies to protect against CSRF attacks state, err := setCSRFCookie(writer, req, "state") if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } // Set the state and nonce cookies to protect against CSRF attacks nonce, err := setCSRFCookie(writer, req, "nonce") if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } @@ -222,7 +222,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( ) { code, state, err := extractCodeAndStateParamFromRequest(req) if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } @@ -230,29 +230,29 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( cookieState, err := req.Cookie(stateCookieName) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) return } if state != cookieState.Value { - httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil)) return } oauth2Token, err := a.getOauth2Token(req.Context(), code, state) if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } idToken, err := a.extractIDToken(req.Context(), oauth2Token) if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } if idToken.Nonce == "" { - httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) return } @@ -260,12 +260,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( nonce, err := req.Cookie(nonceCookieName) if err != nil { - httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } if idToken.Nonce != nonce.Value { - httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return } @@ -273,7 +273,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( var claims types.OIDCClaims if err := idToken.Claims(&claims); err != nil { //nolint:noinlineerr - httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) + httpUserError(writer, fmt.Errorf("decoding ID token claims: %w", err)) return } @@ -310,26 +310,17 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // against allowed emails, email domains, and groups. err = doOIDCAuthorization(a.cfg, &claims) if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } user, _, err := a.createOrUpdateUserFromClaim(&claims) if err != nil { - log.Error(). - Err(err). - Caller(). - Msgf("could not create or update user") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - - _, werr := writer.Write([]byte("Could not create or update user")) - if werr != nil { - log.Error(). - Caller(). - Err(werr). - Msg("Failed to write HTTP response") - } + httpUserError(writer, NewHTTPError( + http.StatusInternalServerError, + "could not create or update user", + err, + )) return } @@ -341,7 +332,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( authInfo := a.getAuthInfoFromState(state) if authInfo == nil { log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired") - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + httpUserError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) return } @@ -367,7 +358,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( authReq, ok := a.h.state.GetAuthCacheEntry(authInfo.AuthID) if !ok { log.Debug().Caller().Str("auth_id", authInfo.AuthID.String()).Msg("auth session expired before authorization completed") - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + httpUserError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) return } @@ -376,7 +367,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( log.Warn().Caller(). Str("auth_id", authInfo.AuthID.String()). Msg("OIDC callback hit non-registration path with auth request that is not an SSH check binding") - httpError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for SSH check", nil)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for SSH check", nil)) return } @@ -389,7 +380,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Str("auth_id", authInfo.AuthID.String()). Uint64("src_node_id", binding.SrcNodeID.Uint64()). Msg("SSH check src node no longer exists") - httpError(writer, NewHTTPError(http.StatusGone, "src node no longer exists", nil)) + httpUserError(writer, NewHTTPError(http.StatusGone, "src node no longer exists", nil)) return } @@ -404,7 +395,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Bool("src_is_tagged", srcNode.IsTagged()). Str("oidc_user", user.Username()). Msg("SSH check rejected: src node has no user owner") - httpError(writer, NewHTTPError(http.StatusForbidden, "src node has no user owner", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "src node has no user owner", nil)) return } @@ -417,7 +408,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Uint("oidc_user_id", user.ID). Str("oidc_user", user.Username()). Msg("SSH check rejected: OIDC user is not the owner of src node") - httpError(writer, NewHTTPError(http.StatusForbidden, "OIDC user is not the owner of the SSH source node", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "OIDC user is not the owner of the SSH source node", nil)) return } @@ -680,7 +671,7 @@ func (a *AuthProviderOIDC) renderRegistrationConfirmInterstitial( authReq, ok := a.h.state.GetAuthCacheEntry(authID) if !ok { log.Debug().Caller().Str("auth_id", authID.String()).Msg("registration session expired before authorization completed") - httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) + httpUserError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) return } @@ -689,14 +680,14 @@ func (a *AuthProviderOIDC) renderRegistrationConfirmInterstitial( log.Warn().Caller(). Str("auth_id", authID.String()). Msg("OIDC callback hit registration path with auth request that is not a node registration") - httpError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for node registration", nil)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "auth session is not for node registration", nil)) return } csrf, err := util.GenerateRandomStringURLSafe(32) if err != nil { - httpError(writer, fmt.Errorf("generating csrf token: %w", err)) + httpUserError(writer, fmt.Errorf("generating csrf token: %w", err)) return } @@ -748,14 +739,14 @@ func (a *AuthProviderOIDC) RegisterConfirmHandler( req *http.Request, ) { if req.Method != http.MethodPost { - httpError(writer, errMethodNotAllowed) + httpUserError(writer, errMethodNotAllowed) return } authID, err := authIDFromRequest(req) if err != nil { - httpError(writer, err) + httpUserError(writer, err) return } @@ -766,54 +757,54 @@ func (a *AuthProviderOIDC) RegisterConfirmHandler( req.Body = http.MaxBytesReader(writer, req.Body, 4*1024) if err := req.ParseForm(); err != nil { //nolint:noinlineerr,gosec // body is bounded above - httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid form", err)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "invalid form", err)) return } formCSRF := req.PostFormValue(registerConfirmCSRFCookie) //nolint:gosec // body is bounded above if formCSRF == "" { - httpError(writer, NewHTTPError(http.StatusBadRequest, "missing csrf token", nil)) + httpUserError(writer, NewHTTPError(http.StatusBadRequest, "missing csrf token", nil)) return } cookie, err := req.Cookie(registerConfirmCSRFCookie) if err != nil { - httpError(writer, NewHTTPError(http.StatusForbidden, "missing csrf cookie", err)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "missing csrf cookie", err)) return } if cookie.Value != formCSRF { - httpError(writer, NewHTTPError(http.StatusForbidden, "csrf token mismatch", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "csrf token mismatch", nil)) return } authReq, ok := a.h.state.GetAuthCacheEntry(authID) if !ok { - httpError(writer, NewHTTPError(http.StatusGone, "registration session expired", nil)) + httpUserError(writer, NewHTTPError(http.StatusGone, "registration session expired", nil)) return } pending := authReq.PendingConfirmation() if pending == nil { - httpError(writer, NewHTTPError(http.StatusForbidden, "registration not OIDC-authorized", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "registration not OIDC-authorized", nil)) return } if pending.CSRF != cookie.Value { - httpError(writer, NewHTTPError(http.StatusForbidden, "csrf token does not match cached registration", nil)) + httpUserError(writer, NewHTTPError(http.StatusForbidden, "csrf token does not match cached registration", nil)) return } user, err := a.h.state.GetUserByID(types.UserID(pending.UserID)) if err != nil { - httpError(writer, fmt.Errorf("looking up user: %w", err)) + httpUserError(writer, fmt.Errorf("looking up user: %w", err)) return } @@ -821,12 +812,12 @@ func (a *AuthProviderOIDC) RegisterConfirmHandler( newNode, err := a.handleRegistration(user, authID, pending.NodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { - httpError(writer, NewHTTPError(http.StatusGone, "registration session expired", err)) + httpUserError(writer, NewHTTPError(http.StatusGone, "registration session expired", err)) return } - httpError(writer, err) + httpUserError(writer, err) return } diff --git a/hscontrol/oidc_template_test.go b/hscontrol/oidc_template_test.go index 24dfc0b0..05cf2286 100644 --- a/hscontrol/oidc_template_test.go +++ b/hscontrol/oidc_template_test.go @@ -7,6 +7,71 @@ import ( "github.com/stretchr/testify/assert" ) +func TestAuthErrorTemplate(t *testing.T) { + tests := []struct { + name string + result templates.AuthErrorResult + }{ + { + name: "bad_request", + result: templates.AuthErrorResult{ + Title: "Headscale - Error", + Heading: "Bad Request", + Message: "The request could not be processed. Please try again.", + }, + }, + { + name: "forbidden", + result: templates.AuthErrorResult{ + Title: "Headscale - Error", + Heading: "Forbidden", + Message: "You are not authorized. Please contact your administrator.", + }, + }, + { + name: "gone_expired", + result: templates.AuthErrorResult{ + Title: "Headscale - Error", + Heading: "Gone", + Message: "Your session has expired. Please try again.", + }, + }, + { + name: "internal_server_error", + result: templates.AuthErrorResult{ + Title: "Headscale - Error", + Heading: "Internal Server Error", + Message: "Something went wrong. Please try again later.", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + html := templates.AuthError(tt.result).Render() + + // Verify the HTML contains expected structural elements + assert.Contains(t, html, "") + assert.Contains(t, html, ""+tt.result.Title+"") + assert.Contains(t, html, tt.result.Heading) + assert.Contains(t, html, tt.result.Message) + + // Verify Material for MkDocs design system CSS is present + assert.Contains(t, html, "Material for MkDocs") + assert.Contains(t, html, "Roboto") + assert.Contains(t, html, ".md-typeset") + + // Verify SVG elements are present + assert.Contains(t, html, "