mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-25 02:09:01 +02:00
feat(auth): add CSRF protection middleware
Implement Signed Double Submit Cookie pattern to prevent CSRF attacks. Adds CSRF token generation, validation, and middleware for API endpoints. Safe methods (GET/HEAD/OPTIONS) automatically receive CSRF cookies, while unsafe methods require X-CSRF-Token header matching the cookie value with valid HMAC signature. Includes same-origin exemption for login/callback endpoints to support browser-based authentication flows.
This commit is contained in:
2
goutils
2
goutils
Submodule goutils updated: c0bbdc984e...635feb302e
109
internal/api/csrf.go
Normal file
109
internal/api/csrf.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/yusing/godoxy/internal/auth"
|
||||||
|
apitypes "github.com/yusing/goutils/apitypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CSRFMiddleware implements the Signed Double Submit Cookie pattern.
|
||||||
|
//
|
||||||
|
// Safe methods (GET/HEAD/OPTIONS): ensure a signed CSRF cookie exists.
|
||||||
|
// Unsafe methods (POST/PUT/DELETE/PATCH): require X-CSRF-Token header
|
||||||
|
// matching the cookie value, with a valid HMAC signature.
|
||||||
|
func CSRFMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
switch c.Request.Method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||||
|
ensureCSRFCookie(c)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if allowSameOriginAuthBootstrap(c.Request) {
|
||||||
|
ensureCSRFCookie(c)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie, err := c.Request.Cookie(auth.CSRFCookieName)
|
||||||
|
if err != nil {
|
||||||
|
// No cookie at all — issue one so the frontend can retry.
|
||||||
|
reissueCSRFCookie(c)
|
||||||
|
c.JSON(http.StatusForbidden, apitypes.Error("missing CSRF token"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieToken := canonicalCSRFToken(cookie.Value)
|
||||||
|
headerToken := canonicalCSRFToken(c.GetHeader(auth.CSRFHeaderName))
|
||||||
|
if headerToken == "" || cookieToken != headerToken || !auth.ValidateCSRFToken(cookieToken) {
|
||||||
|
// Stale or forged token — issue a fresh one so the
|
||||||
|
// frontend can read the new cookie and retry.
|
||||||
|
reissueCSRFCookie(c)
|
||||||
|
c.JSON(http.StatusForbidden, apitypes.Error("invalid CSRF token"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureCSRFCookie(c *gin.Context) {
|
||||||
|
if _, err := c.Request.Cookie(auth.CSRFCookieName); err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reissueCSRFCookie(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reissueCSRFCookie(c *gin.Context) {
|
||||||
|
token, err := auth.GenerateCSRFToken()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
auth.SetCSRFCookie(c.Writer, c.Request, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func allowSameOriginAuthBootstrap(r *http.Request) bool {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/api/v1/auth/login", "/api/v1/auth/callback":
|
||||||
|
return requestSourceMatchesHost(r)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestSourceMatchesHost(r *http.Request) bool {
|
||||||
|
for _, header := range []string{"Origin", "Referer"} {
|
||||||
|
value := r.Header.Get(header)
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
u, err := url.Parse(value)
|
||||||
|
if err != nil || u.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return normalizeHost(u.Hostname()) == normalizeHost(r.Host)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeHost(host string) string {
|
||||||
|
host = strings.ToLower(host)
|
||||||
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
func canonicalCSRFToken(token string) string {
|
||||||
|
return strings.Trim(strings.TrimSpace(token), "\"")
|
||||||
|
}
|
||||||
280
internal/api/csrf_test.go
Normal file
280
internal/api/csrf_test.go
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/godoxy/internal/auth"
|
||||||
|
autocert "github.com/yusing/godoxy/internal/autocert/types"
|
||||||
|
"github.com/yusing/godoxy/internal/common"
|
||||||
|
"github.com/yusing/goutils/task"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthCheckIssuesCSRFCookie(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodHead, "/api/v1/auth/check", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, rec.Code)
|
||||||
|
|
||||||
|
csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName)
|
||||||
|
require.NotNil(t, csrfCookie)
|
||||||
|
assert.NotEmpty(t, csrfCookie.Value)
|
||||||
|
assert.Empty(t, csrfCookie.Domain)
|
||||||
|
assert.Equal(t, "/", csrfCookie.Path)
|
||||||
|
assert.Equal(t, http.SameSiteStrictMode, csrfCookie.SameSite)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassCallbackAllowsSameOriginFormPostWithoutCSRFCookie(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
|
||||||
|
req := newJSONRequest(t, http.MethodPost, "/api/v1/auth/callback", map[string]string{
|
||||||
|
"username": common.APIUser,
|
||||||
|
"password": common.APIPassword,
|
||||||
|
})
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
req.Header.Set("Origin", "https://app.example.com")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
tokenCookie := findCookie(rec.Result().Cookies(), "godoxy_token")
|
||||||
|
require.NotNil(t, tokenCookie)
|
||||||
|
assert.NotEmpty(t, tokenCookie.Value)
|
||||||
|
csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName)
|
||||||
|
require.NotNil(t, csrfCookie)
|
||||||
|
assert.NotEmpty(t, csrfCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassCallbackRejectsCrossOriginPostWithoutCSRFCookie(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
|
||||||
|
req := newJSONRequest(t, http.MethodPost, "/api/v1/auth/callback", map[string]string{
|
||||||
|
"username": common.APIUser,
|
||||||
|
"password": common.APIPassword,
|
||||||
|
})
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
req.Header.Set("Origin", "https://evil.example.com")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName)
|
||||||
|
require.NotNil(t, csrfCookie)
|
||||||
|
assert.NotEmpty(t, csrfCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassCallbackAcceptsValidCSRFCookie(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
csrfCookie := issueCSRFCookie(t, handler)
|
||||||
|
|
||||||
|
req := newJSONRequest(t, http.MethodPost, "/api/v1/auth/callback", map[string]string{
|
||||||
|
"username": common.APIUser,
|
||||||
|
"password": common.APIPassword,
|
||||||
|
})
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
req.AddCookie(csrfCookie)
|
||||||
|
req.Header.Set(auth.CSRFHeaderName, csrfCookie.Value)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
tokenCookie := findCookie(rec.Result().Cookies(), "godoxy_token")
|
||||||
|
require.NotNil(t, tokenCookie)
|
||||||
|
assert.NotEmpty(t, tokenCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnsafeRequestAcceptsQuotedCSRFCookieValue(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
csrfCookie := issueCSRFCookie(t, handler)
|
||||||
|
sessionToken := issueSessionToken(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
req.Header.Set("Cookie", `godoxy_token=`+sessionToken+`; godoxy_csrf="`+csrfCookie.Value+`"`)
|
||||||
|
req.Header.Set(auth.CSRFHeaderName, csrfCookie.Value)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutRequiresCSRFCookie(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginAllowsSameOriginPostWithoutCSRFCookie(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
req.Header.Set("Origin", "https://app.example.com")
|
||||||
|
req.Header.Set("Accept", "text/html")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, rec.Code)
|
||||||
|
csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName)
|
||||||
|
require.NotNil(t, csrfCookie)
|
||||||
|
assert.NotEmpty(t, csrfCookie.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetLogoutRouteStillAvailableForFrontend(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/logout", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertRenewRejectsCrossOriginWebSocketRequest(t *testing.T) {
|
||||||
|
handler := newAuthenticatedHandler(t)
|
||||||
|
provider := &stubAutocertProvider{}
|
||||||
|
sessionToken := issueSessionToken(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/cert/renew", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
|
req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
req.Header.Set("Origin", "https://evil.example.com")
|
||||||
|
req.AddCookie(&http.Cookie{Name: "godoxy_token", Value: sessionToken})
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), autocert.ContextKey{}, provider))
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||||
|
assert.Zero(t, provider.forceExpiryCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthenticatedHandler(t *testing.T) *gin.Engine {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
prevSecret := common.APIJWTSecret
|
||||||
|
prevUser := common.APIUser
|
||||||
|
prevPassword := common.APIPassword
|
||||||
|
prevDisableAuth := common.DebugDisableAuth
|
||||||
|
prevIssuerURL := common.OIDCIssuerURL
|
||||||
|
|
||||||
|
common.APIJWTSecret = []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
common.APIUser = "username"
|
||||||
|
common.APIPassword = "password"
|
||||||
|
common.DebugDisableAuth = false
|
||||||
|
common.OIDCIssuerURL = ""
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
common.APIJWTSecret = prevSecret
|
||||||
|
common.APIUser = prevUser
|
||||||
|
common.APIPassword = prevPassword
|
||||||
|
common.DebugDisableAuth = prevDisableAuth
|
||||||
|
common.OIDCIssuerURL = prevIssuerURL
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, auth.Initialize())
|
||||||
|
return NewHandler(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func issueCSRFCookie(t *testing.T, handler http.Handler) *http.Cookie {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodHead, "/api/v1/auth/check", nil)
|
||||||
|
req.Host = "app.example.com"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
csrfCookie := findCookie(rec.Result().Cookies(), auth.CSRFCookieName)
|
||||||
|
require.NotNil(t, csrfCookie)
|
||||||
|
return csrfCookie
|
||||||
|
}
|
||||||
|
|
||||||
|
func issueSessionToken(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
userpass, ok := auth.GetDefaultAuth().(*auth.UserPassAuth)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
token, err := userpass.NewToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJSONRequest(t *testing.T, method, target string, body any) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
encoded, err := json.Marshal(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(method, target, bytes.NewReader(encoded))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if cookie.Name == name {
|
||||||
|
return cookie
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubAutocertProvider struct {
|
||||||
|
forceExpiryCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubAutocertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubAutocertProvider) GetCertInfos() ([]autocert.CertInfo, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubAutocertProvider) ScheduleRenewalAll(task.Parent) {}
|
||||||
|
|
||||||
|
func (p *stubAutocertProvider) ObtainCertAll() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubAutocertProvider) ForceExpiryAll() bool {
|
||||||
|
p.forceExpiryCalls++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubAutocertProvider) WaitRenewalDone(context.Context) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -56,11 +56,11 @@ func NewHandler(requireAuth bool) *gin.Engine {
|
|||||||
if auth.IsEnabled() && requireAuth {
|
if auth.IsEnabled() && requireAuth {
|
||||||
v1Auth := r.Group("/api/v1/auth")
|
v1Auth := r.Group("/api/v1/auth")
|
||||||
{
|
{
|
||||||
v1Auth.HEAD("/check", authApi.Check)
|
v1Auth.HEAD("/check", CSRFMiddleware(), authApi.Check)
|
||||||
v1Auth.POST("/login", authApi.Login)
|
v1Auth.POST("/login", CSRFMiddleware(), authApi.Login)
|
||||||
v1Auth.GET("/callback", authApi.Callback)
|
v1Auth.GET("/callback", authApi.Callback)
|
||||||
v1Auth.POST("/callback", authApi.Callback)
|
v1Auth.POST("/callback", CSRFMiddleware(), authApi.Callback)
|
||||||
v1Auth.POST("/logout", authApi.Logout)
|
v1Auth.POST("/logout", CSRFMiddleware(), authApi.Logout)
|
||||||
v1Auth.GET("/logout", authApi.Logout)
|
v1Auth.GET("/logout", authApi.Logout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -68,6 +68,7 @@ func NewHandler(requireAuth bool) *gin.Engine {
|
|||||||
v1 := r.Group("/api/v1")
|
v1 := r.Group("/api/v1")
|
||||||
if auth.IsEnabled() && requireAuth {
|
if auth.IsEnabled() && requireAuth {
|
||||||
v1.Use(AuthMiddleware())
|
v1.Use(AuthMiddleware())
|
||||||
|
v1.Use(CSRFMiddleware())
|
||||||
}
|
}
|
||||||
if common.APISkipOriginCheck {
|
if common.APISkipOriginCheck {
|
||||||
v1.Use(SkipOriginCheckMiddleware())
|
v1.Use(SkipOriginCheckMiddleware())
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
// @Tags cert,websocket
|
// @Tags cert,websocket
|
||||||
// @Produce plain
|
// @Produce plain
|
||||||
// @Success 200 {object} apitypes.SuccessResponse
|
// @Success 200 {object} apitypes.SuccessResponse
|
||||||
|
// @Failure 400 {object} apitypes.ErrorResponse
|
||||||
// @Failure 403 {object} apitypes.ErrorResponse
|
// @Failure 403 {object} apitypes.ErrorResponse
|
||||||
// @Failure 500 {object} apitypes.ErrorResponse
|
// @Failure 500 {object} apitypes.ErrorResponse
|
||||||
// @Router /cert/renew [get]
|
// @Router /cert/renew [get]
|
||||||
|
|||||||
84
internal/auth/csrf.go
Normal file
84
internal/auth/csrf.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/yusing/godoxy/internal/common"
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CSRFCookieName = "godoxy_csrf"
|
||||||
|
CSRFHKDFSalt = "godoxy-csrf"
|
||||||
|
CSRFHeaderName = "X-CSRF-Token"
|
||||||
|
csrfTokenLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// csrfSecret is derived from API_JWT_SECRET via HKDF for cryptographic
|
||||||
|
// separation from JWT signing. Falls back to an ephemeral random key
|
||||||
|
// for OIDC-only setups where no JWT secret is configured.
|
||||||
|
var csrfSecret = func() []byte {
|
||||||
|
if common.APIJWTSecret != nil {
|
||||||
|
return hkdf.Extract(sha256.New, common.APIJWTSecret, []byte(CSRFHKDFSalt))
|
||||||
|
}
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
panic("failed to generate CSRF secret: " + err.Error())
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}()
|
||||||
|
|
||||||
|
func GenerateCSRFToken() (string, error) {
|
||||||
|
nonce := make([]byte, csrfTokenLength)
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
nonceHex := hex.EncodeToString(nonce)
|
||||||
|
return nonceHex + "." + csrfSign(nonceHex), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCSRFToken checks the HMAC signature embedded in the token.
|
||||||
|
// This prevents subdomain cookie-injection attacks where an attacker
|
||||||
|
// sets a forged CSRF cookie — they cannot produce a valid signature
|
||||||
|
// without the ephemeral secret.
|
||||||
|
func ValidateCSRFToken(token string) bool {
|
||||||
|
nonce, sig, ok := strings.Cut(token, ".")
|
||||||
|
if !ok || len(nonce) != csrfTokenLength*2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return hmac.Equal([]byte(sig), []byte(csrfSign(nonce)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func csrfSign(nonce string) string {
|
||||||
|
mac := hmac.New(sha256.New, csrfSecret)
|
||||||
|
mac.Write([]byte(nonce))
|
||||||
|
return hex.EncodeToString(mac.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetCSRFCookie(w http.ResponseWriter, r *http.Request, token string) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: CSRFCookieName,
|
||||||
|
Value: token,
|
||||||
|
HttpOnly: false,
|
||||||
|
Secure: common.APIJWTSecure,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClearCSRFCookie(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: CSRFCookieName,
|
||||||
|
Value: "",
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: false,
|
||||||
|
Secure: common.APIJWTSecure,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user