From 213e4a5cdb635bf1f5f5985fd04515ad592497d5 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 19 Mar 2026 14:55:47 +0800 Subject: [PATCH] feat(auth): add CSRF protection middleware Implement Signed Double Submit Cookie pattern to prevent CSRF attacks. Adds CSRF token generation, validation, and middleware for API endpoints. Safe methods (GET/HEAD/OPTIONS) automatically receive CSRF cookies, while unsafe methods require X-CSRF-Token header matching the cookie value with valid HMAC signature. Includes same-origin exemption for login/callback endpoints to support browser-based authentication flows. --- goutils | 2 +- internal/api/csrf.go | 109 +++++++++++++ internal/api/csrf_test.go | 280 ++++++++++++++++++++++++++++++++++ internal/api/handler.go | 9 +- internal/api/v1/cert/renew.go | 1 + internal/auth/csrf.go | 84 ++++++++++ 6 files changed, 480 insertions(+), 5 deletions(-) create mode 100644 internal/api/csrf.go create mode 100644 internal/api/csrf_test.go create mode 100644 internal/auth/csrf.go diff --git a/goutils b/goutils index c0bbdc98..635feb30 160000 --- a/goutils +++ b/goutils @@ -1 +1 @@ -Subproject commit c0bbdc984e138bf49c17694660ab7d0c358a2055 +Subproject commit 635feb302e50a29f4705e829a1e087cd95699fc8 diff --git a/internal/api/csrf.go b/internal/api/csrf.go new file mode 100644 index 00000000..e70d607a --- /dev/null +++ b/internal/api/csrf.go @@ -0,0 +1,109 @@ +package api + +import ( + "net" + "net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" + "github.com/yusing/godoxy/internal/auth" + apitypes "github.com/yusing/goutils/apitypes" +) + +// CSRFMiddleware implements the Signed Double Submit Cookie pattern. +// +// Safe methods (GET/HEAD/OPTIONS): ensure a signed CSRF cookie exists. +// Unsafe methods (POST/PUT/DELETE/PATCH): require X-CSRF-Token header +// matching the cookie value, with a valid HMAC signature. +func CSRFMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + switch c.Request.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + ensureCSRFCookie(c) + c.Next() + return + } + if allowSameOriginAuthBootstrap(c.Request) { + ensureCSRFCookie(c) + c.Next() + return + } + + cookie, err := c.Request.Cookie(auth.CSRFCookieName) + if err != nil { + // No cookie at all — issue one so the frontend can retry. + reissueCSRFCookie(c) + c.JSON(http.StatusForbidden, apitypes.Error("missing CSRF token")) + c.Abort() + return + } + + cookieToken := canonicalCSRFToken(cookie.Value) + headerToken := canonicalCSRFToken(c.GetHeader(auth.CSRFHeaderName)) + if headerToken == "" || cookieToken != headerToken || !auth.ValidateCSRFToken(cookieToken) { + // Stale or forged token — issue a fresh one so the + // frontend can read the new cookie and retry. + reissueCSRFCookie(c) + c.JSON(http.StatusForbidden, apitypes.Error("invalid CSRF token")) + c.Abort() + return + } + + c.Next() + } +} + +func ensureCSRFCookie(c *gin.Context) { + if _, err := c.Request.Cookie(auth.CSRFCookieName); err == nil { + return + } + reissueCSRFCookie(c) +} + +func reissueCSRFCookie(c *gin.Context) { + token, err := auth.GenerateCSRFToken() + if err != nil { + return + } + auth.SetCSRFCookie(c.Writer, c.Request, token) +} + +func allowSameOriginAuthBootstrap(r *http.Request) bool { + if r.Method != http.MethodPost { + return false + } + switch r.URL.Path { + case "/api/v1/auth/login", "/api/v1/auth/callback": + return requestSourceMatchesHost(r) + default: + return false + } +} + +func requestSourceMatchesHost(r *http.Request) bool { + for _, header := range []string{"Origin", "Referer"} { + value := r.Header.Get(header) + if value == "" { + continue + } + u, err := url.Parse(value) + if err != nil || u.Host == "" { + return false + } + return normalizeHost(u.Hostname()) == normalizeHost(r.Host) + } + return false +} + +func normalizeHost(host string) string { + host = strings.ToLower(host) + if h, _, err := net.SplitHostPort(host); err == nil { + return h + } + return host +} + +func canonicalCSRFToken(token string) string { + return strings.Trim(strings.TrimSpace(token), "\"") +} diff --git a/internal/api/csrf_test.go b/internal/api/csrf_test.go new file mode 100644 index 00000000..a926eb58 --- /dev/null +++ b/internal/api/csrf_test.go @@ -0,0 +1,280 @@ +package api + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/auth" + autocert "github.com/yusing/godoxy/internal/autocert/types" + "github.com/yusing/godoxy/internal/common" + "github.com/yusing/goutils/task" +) + +func TestAuthCheckIssuesCSRFCookie(t *testing.T) { + handler := newAuthenticatedHandler(t) + + req := httptest.NewRequest(http.MethodHead, "/api/v1/auth/check", nil) + req.Host = "app.example.com" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code) + + csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName) + require.NotNil(t, csrfCookie) + assert.NotEmpty(t, csrfCookie.Value) + assert.Empty(t, csrfCookie.Domain) + assert.Equal(t, "/", csrfCookie.Path) + assert.Equal(t, http.SameSiteStrictMode, csrfCookie.SameSite) +} + +func TestUserPassCallbackAllowsSameOriginFormPostWithoutCSRFCookie(t *testing.T) { + handler := newAuthenticatedHandler(t) + + req := newJSONRequest(t, http.MethodPost, "/api/v1/auth/callback", map[string]string{ + "username": common.APIUser, + "password": common.APIPassword, + }) + req.Host = "app.example.com" + req.Header.Set("Origin", "https://app.example.com") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + tokenCookie := findCookie(rec.Result().Cookies(), "godoxy_token") + require.NotNil(t, tokenCookie) + assert.NotEmpty(t, tokenCookie.Value) + csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName) + require.NotNil(t, csrfCookie) + assert.NotEmpty(t, csrfCookie.Value) +} + +func TestUserPassCallbackRejectsCrossOriginPostWithoutCSRFCookie(t *testing.T) { + handler := newAuthenticatedHandler(t) + + req := newJSONRequest(t, http.MethodPost, "/api/v1/auth/callback", map[string]string{ + "username": common.APIUser, + "password": common.APIPassword, + }) + req.Host = "app.example.com" + req.Header.Set("Origin", "https://evil.example.com") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code) + csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName) + require.NotNil(t, csrfCookie) + assert.NotEmpty(t, csrfCookie.Value) +} + +func TestUserPassCallbackAcceptsValidCSRFCookie(t *testing.T) { + handler := newAuthenticatedHandler(t) + csrfCookie := issueCSRFCookie(t, handler) + + req := newJSONRequest(t, http.MethodPost, "/api/v1/auth/callback", map[string]string{ + "username": common.APIUser, + "password": common.APIPassword, + }) + req.Host = "app.example.com" + req.AddCookie(csrfCookie) + req.Header.Set(auth.CSRFHeaderName, csrfCookie.Value) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + tokenCookie := findCookie(rec.Result().Cookies(), "godoxy_token") + require.NotNil(t, tokenCookie) + assert.NotEmpty(t, tokenCookie.Value) +} + +func TestUnsafeRequestAcceptsQuotedCSRFCookieValue(t *testing.T) { + handler := newAuthenticatedHandler(t) + csrfCookie := issueCSRFCookie(t, handler) + sessionToken := issueSessionToken(t) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.Host = "app.example.com" + req.Header.Set("Cookie", `godoxy_token=`+sessionToken+`; godoxy_csrf="`+csrfCookie.Value+`"`) + req.Header.Set(auth.CSRFHeaderName, csrfCookie.Value) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code) +} + +func TestLogoutRequiresCSRFCookie(t *testing.T) { + handler := newAuthenticatedHandler(t) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil) + req.Host = "app.example.com" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestLoginAllowsSameOriginPostWithoutCSRFCookie(t *testing.T) { + handler := newAuthenticatedHandler(t) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil) + req.Host = "app.example.com" + req.Header.Set("Origin", "https://app.example.com") + req.Header.Set("Accept", "text/html") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code) + csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName) + require.NotNil(t, csrfCookie) + assert.NotEmpty(t, csrfCookie.Value) +} + +func TestGetLogoutRouteStillAvailableForFrontend(t *testing.T) { + handler := newAuthenticatedHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil) + req.Host = "app.example.com" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusFound, rec.Code) +} + +func TestCertRenewRejectsCrossOriginWebSocketRequest(t *testing.T) { + handler := newAuthenticatedHandler(t) + provider := &stubAutocertProvider{} + sessionToken := issueSessionToken(t) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/cert/renew", nil) + req.Host = "app.example.com" + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Origin", "https://evil.example.com") + req.AddCookie(&http.Cookie{Name: "godoxy_token", Value: sessionToken}) + req = req.WithContext(context.WithValue(req.Context(), autocert.ContextKey{}, provider)) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Zero(t, provider.forceExpiryCalls) +} + +func newAuthenticatedHandler(t *testing.T) *gin.Engine { + t.Helper() + + gin.SetMode(gin.TestMode) + + prevSecret := common.APIJWTSecret + prevUser := common.APIUser + prevPassword := common.APIPassword + prevDisableAuth := common.DebugDisableAuth + prevIssuerURL := common.OIDCIssuerURL + + common.APIJWTSecret = []byte("0123456789abcdef0123456789abcdef") + common.APIUser = "username" + common.APIPassword = "password" + common.DebugDisableAuth = false + common.OIDCIssuerURL = "" + + t.Cleanup(func() { + common.APIJWTSecret = prevSecret + common.APIUser = prevUser + common.APIPassword = prevPassword + common.DebugDisableAuth = prevDisableAuth + common.OIDCIssuerURL = prevIssuerURL + }) + + require.NoError(t, auth.Initialize()) + return NewHandler(true) +} + +func issueCSRFCookie(t *testing.T, handler http.Handler) *http.Cookie { + t.Helper() + + req := httptest.NewRequest(http.MethodHead, "/api/v1/auth/check", nil) + req.Host = "app.example.com" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName) + require.NotNil(t, csrfCookie) + return csrfCookie +} + +func issueSessionToken(t *testing.T) string { + t.Helper() + + userpass, ok := auth.GetDefaultAuth().(*auth.UserPassAuth) + require.True(t, ok) + + token, err := userpass.NewToken() + require.NoError(t, err) + return token +} + +func newJSONRequest(t *testing.T, method, target string, body any) *http.Request { + t.Helper() + + encoded, err := json.Marshal(body) + require.NoError(t, err) + + req := httptest.NewRequest(method, target, bytes.NewReader(encoded)) + req.Header.Set("Content-Type", "application/json") + return req +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +type stubAutocertProvider struct { + forceExpiryCalls int +} + +func (p *stubAutocertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, nil +} + +func (p *stubAutocertProvider) GetCertInfos() ([]autocert.CertInfo, error) { + return nil, nil +} + +func (p *stubAutocertProvider) ScheduleRenewalAll(task.Parent) {} + +func (p *stubAutocertProvider) ObtainCertAll() error { + return nil +} + +func (p *stubAutocertProvider) ForceExpiryAll() bool { + p.forceExpiryCalls++ + return true +} + +func (p *stubAutocertProvider) WaitRenewalDone(context.Context) bool { + return true +} diff --git a/internal/api/handler.go b/internal/api/handler.go index ceb20dc6..2588620b 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -56,11 +56,11 @@ func NewHandler(requireAuth bool) *gin.Engine { if auth.IsEnabled() && requireAuth { v1Auth := r.Group("/api/v1/auth") { - v1Auth.HEAD("/check", authApi.Check) - v1Auth.POST("/login", authApi.Login) + v1Auth.HEAD("/check", CSRFMiddleware(), authApi.Check) + v1Auth.POST("/login", CSRFMiddleware(), authApi.Login) v1Auth.GET("/callback", authApi.Callback) - v1Auth.POST("/callback", authApi.Callback) - v1Auth.POST("/logout", authApi.Logout) + v1Auth.POST("/callback", CSRFMiddleware(), authApi.Callback) + v1Auth.POST("/logout", CSRFMiddleware(), authApi.Logout) v1Auth.GET("/logout", authApi.Logout) } } @@ -68,6 +68,7 @@ func NewHandler(requireAuth bool) *gin.Engine { v1 := r.Group("/api/v1") if auth.IsEnabled() && requireAuth { v1.Use(AuthMiddleware()) + v1.Use(CSRFMiddleware()) } if common.APISkipOriginCheck { v1.Use(SkipOriginCheckMiddleware()) diff --git a/internal/api/v1/cert/renew.go b/internal/api/v1/cert/renew.go index 33232aeb..7ff6510b 100644 --- a/internal/api/v1/cert/renew.go +++ b/internal/api/v1/cert/renew.go @@ -19,6 +19,7 @@ import ( // @Tags cert,websocket // @Produce plain // @Success 200 {object} apitypes.SuccessResponse +// @Failure 400 {object} apitypes.ErrorResponse // @Failure 403 {object} apitypes.ErrorResponse // @Failure 500 {object} apitypes.ErrorResponse // @Router /cert/renew [get] diff --git a/internal/auth/csrf.go b/internal/auth/csrf.go new file mode 100644 index 00000000..90fa0ef7 --- /dev/null +++ b/internal/auth/csrf.go @@ -0,0 +1,84 @@ +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "net/http" + "strings" + + "github.com/yusing/godoxy/internal/common" + "golang.org/x/crypto/hkdf" +) + +const ( + CSRFCookieName = "godoxy_csrf" + CSRFHKDFSalt = "godoxy-csrf" + CSRFHeaderName = "X-CSRF-Token" + csrfTokenLength = 32 +) + +// csrfSecret is derived from API_JWT_SECRET via HKDF for cryptographic +// separation from JWT signing. Falls back to an ephemeral random key +// for OIDC-only setups where no JWT secret is configured. +var csrfSecret = func() []byte { + if common.APIJWTSecret != nil { + return hkdf.Extract(sha256.New, common.APIJWTSecret, []byte(CSRFHKDFSalt)) + } + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + panic("failed to generate CSRF secret: " + err.Error()) + } + return b +}() + +func GenerateCSRFToken() (string, error) { + nonce := make([]byte, csrfTokenLength) + if _, err := rand.Read(nonce); err != nil { + return "", err + } + nonceHex := hex.EncodeToString(nonce) + return nonceHex + "." + csrfSign(nonceHex), nil +} + +// ValidateCSRFToken checks the HMAC signature embedded in the token. +// This prevents subdomain cookie-injection attacks where an attacker +// sets a forged CSRF cookie — they cannot produce a valid signature +// without the ephemeral secret. +func ValidateCSRFToken(token string) bool { + nonce, sig, ok := strings.Cut(token, ".") + if !ok || len(nonce) != csrfTokenLength*2 { + return false + } + return hmac.Equal([]byte(sig), []byte(csrfSign(nonce))) +} + +func csrfSign(nonce string) string { + mac := hmac.New(sha256.New, csrfSecret) + mac.Write([]byte(nonce)) + return hex.EncodeToString(mac.Sum(nil)) +} + +func SetCSRFCookie(w http.ResponseWriter, r *http.Request, token string) { + http.SetCookie(w, &http.Cookie{ + Name: CSRFCookieName, + Value: token, + HttpOnly: false, + Secure: common.APIJWTSecure, + SameSite: http.SameSiteStrictMode, + Path: "/", + }) +} + +func ClearCSRFCookie(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{ + Name: CSRFCookieName, + Value: "", + MaxAge: -1, + HttpOnly: false, + Secure: common.APIJWTSecure, + SameSite: http.SameSiteStrictMode, + Path: "/", + }) +}