mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-20 07:21:26 +02:00
feat: Add per-route OIDC client ID and secret support (#145)
This commit is contained in:
@@ -130,7 +130,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R
|
|||||||
log.Err(err).Msg("failed to sign session token")
|
log.Err(err).Msg("failed to sign session token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
|
SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthSessionToken), signed, common.APIJWTTokenTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) {
|
func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -39,12 +40,27 @@ type (
|
|||||||
|
|
||||||
var _ Provider = (*OIDCProvider)(nil)
|
var _ Provider = (*OIDCProvider)(nil)
|
||||||
|
|
||||||
|
// Cookie names for OIDC authentication
|
||||||
const (
|
const (
|
||||||
CookieOauthState = "godoxy_oidc_state"
|
CookieOauthState = "godoxy_oidc_state"
|
||||||
CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
|
CookieOauthToken = "godoxy_oauth_token" //nolint:gosec
|
||||||
CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
|
CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getAppScopedCookieName returns a cookie name scoped to the specific application
|
||||||
|
// to prevent conflicts between different OIDC clients
|
||||||
|
func (auth *OIDCProvider) getAppScopedCookieName(baseName string) string {
|
||||||
|
// Use the client ID to scope the cookie name
|
||||||
|
// This prevents conflicts when multiple apps use different client IDs
|
||||||
|
if auth.oauthConfig.ClientID != "" {
|
||||||
|
// Create a hash of the client ID to keep cookie names short
|
||||||
|
hash := sha256.Sum256([]byte(auth.oauthConfig.ClientID))
|
||||||
|
clientHash := base64.URLEncoding.EncodeToString(hash[:])[:8]
|
||||||
|
return fmt.Sprintf("%s_%s", baseName, clientHash)
|
||||||
|
}
|
||||||
|
return baseName
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OIDCAuthInitPath = "/"
|
OIDCAuthInitPath = "/"
|
||||||
OIDCPostAuthPath = "/auth/callback"
|
OIDCPostAuthPath = "/auth/callback"
|
||||||
@@ -117,6 +133,37 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewOIDCProviderWithCustomClient creates a new OIDCProvider with custom client credentials
|
||||||
|
// based on an existing provider (for issuer discovery)
|
||||||
|
func NewOIDCProviderWithCustomClient(baseProvider *OIDCProvider, clientID, clientSecret string) (*OIDCProvider, error) {
|
||||||
|
if clientID == "" || clientSecret == "" {
|
||||||
|
return nil, errors.New("client ID and client secret are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new OIDC verifier with the custom client ID
|
||||||
|
oidcVerifier := baseProvider.oidcProvider.Verifier(&oidc.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create new OAuth config with custom credentials
|
||||||
|
oauthConfig := &oauth2.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: "",
|
||||||
|
Endpoint: baseProvider.oauthConfig.Endpoint,
|
||||||
|
Scopes: baseProvider.oauthConfig.Scopes,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OIDCProvider{
|
||||||
|
oauthConfig: oauthConfig,
|
||||||
|
oidcProvider: baseProvider.oidcProvider,
|
||||||
|
oidcVerifier: oidcVerifier,
|
||||||
|
endSessionURL: baseProvider.endSessionURL,
|
||||||
|
allowedUsers: baseProvider.allowedUsers,
|
||||||
|
allowedGroups: baseProvider.allowedGroups,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||||
auth.allowedUsers = users
|
auth.allowedUsers = users
|
||||||
}
|
}
|
||||||
@@ -125,6 +172,10 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
|||||||
auth.allowedGroups = groups
|
auth.allowedGroups = groups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetScopes(scopes []string) {
|
||||||
|
auth.oauthConfig.Scopes = scopes
|
||||||
|
}
|
||||||
|
|
||||||
// optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri"
|
// optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri"
|
||||||
// parameter of the authorization URL to the post auth path of the current
|
// parameter of the authorization URL to the post auth path of the current
|
||||||
// request host.
|
// request host.
|
||||||
@@ -169,7 +220,7 @@ var rateLimit = rate.NewLimiter(rate.Every(time.Second), 1)
|
|||||||
|
|
||||||
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// check for session token
|
// check for session token
|
||||||
sessionToken, err := r.Cookie(CookieOauthSessionToken)
|
sessionToken, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken))
|
||||||
if err == nil { // session token exists
|
if err == nil { // session token exists
|
||||||
result, err := auth.TryRefreshToken(r.Context(), sessionToken.Value)
|
result, err := auth.TryRefreshToken(r.Context(), sessionToken.Value)
|
||||||
// redirect back to where they requested
|
// redirect back to where they requested
|
||||||
@@ -193,7 +244,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
state := generateState()
|
state := generateState()
|
||||||
SetTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
|
SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthState), state, 300*time.Second)
|
||||||
// redirect user to Idp
|
// redirect user to Idp
|
||||||
url := auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r))
|
url := auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r))
|
||||||
if IsFrontend(r) {
|
if IsFrontend(r) {
|
||||||
@@ -209,7 +260,8 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
|||||||
if err := idToken.Claims(&claim); err != nil {
|
if err := idToken.Claims(&claim); err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse claims: %w", err)
|
return nil, fmt.Errorf("failed to parse claims: %w", err)
|
||||||
}
|
}
|
||||||
if claim.Username == "" {
|
// Username is optional if groups are present
|
||||||
|
if claim.Username == "" && len(claim.Groups) == 0 {
|
||||||
return nil, errors.New("missing username in ID token")
|
return nil, errors.New("missing username in ID token")
|
||||||
}
|
}
|
||||||
return &claim, nil
|
return &claim, nil
|
||||||
@@ -228,7 +280,7 @@ func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
tokenCookie, err := r.Cookie(CookieOauthToken)
|
tokenCookie, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrMissingOAuthToken
|
return ErrMissingOAuthToken
|
||||||
}
|
}
|
||||||
@@ -257,7 +309,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// verify state
|
// verify state
|
||||||
state, err := r.Cookie(CookieOauthState)
|
state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "missing state cookie", http.StatusBadRequest)
|
http.Error(w, "missing state cookie", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
@@ -297,8 +349,8 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
oauthToken, _ := r.Cookie(CookieOauthToken)
|
oauthToken, _ := r.Cookie(auth.getAppScopedCookieName(CookieOauthToken))
|
||||||
sessionToken, _ := r.Cookie(CookieOauthSessionToken)
|
sessionToken, _ := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken))
|
||||||
auth.clearCookie(w, r)
|
auth.clearCookie(w, r)
|
||||||
|
|
||||||
if sessionToken != nil {
|
if sessionToken != nil {
|
||||||
@@ -325,17 +377,17 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) {
|
func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) {
|
||||||
SetTokenCookie(w, r, CookieOauthToken, jwt, ttl)
|
SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthToken), jwt, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) {
|
||||||
ClearTokenCookie(w, r, CookieOauthToken)
|
ClearTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthToken))
|
||||||
ClearTokenCookie(w, r, CookieOauthSessionToken)
|
ClearTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthSessionToken))
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleTestCallback handles OIDC callback in test environment.
|
// handleTestCallback handles OIDC callback in test environment.
|
||||||
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
state, err := r.Cookie(CookieOauthState)
|
state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "missing state cookie", http.StatusBadRequest)
|
http.Error(w, "missing state cookie", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
@@ -347,7 +399,7 @@ func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create test JWT token
|
// Create test JWT token
|
||||||
SetTokenCookie(w, r, CookieOauthToken, "test", time.Hour)
|
SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthToken), "test", time.Hour)
|
||||||
|
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -426,6 +426,9 @@ func TestCheckToken(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
// Create the Auth Provider.
|
// Create the Auth Provider.
|
||||||
auth := &OIDCProvider{
|
auth := &OIDCProvider{
|
||||||
|
oauthConfig: &oauth2.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
},
|
||||||
oidcVerifier: provider.verifier,
|
oidcVerifier: provider.verifier,
|
||||||
allowedUsers: tc.allowedUsers,
|
allowedUsers: tc.allowedUsers,
|
||||||
allowedGroups: tc.allowedGroups,
|
allowedGroups: tc.allowedGroups,
|
||||||
@@ -435,7 +438,7 @@ func TestCheckToken(t *testing.T) {
|
|||||||
// Craft a test HTTP request that includes the token as a cookie.
|
// Craft a test HTTP request that includes the token as a cookie.
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.AddCookie(&http.Cookie{
|
req.AddCookie(&http.Cookie{
|
||||||
Name: CookieOauthToken,
|
Name: auth.getAppScopedCookieName(CookieOauthToken),
|
||||||
Value: signedToken,
|
Value: signedToken,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -13,6 +14,9 @@ import (
|
|||||||
type oidcMiddleware struct {
|
type oidcMiddleware struct {
|
||||||
AllowedUsers []string `json:"allowed_users"`
|
AllowedUsers []string `json:"allowed_users"`
|
||||||
AllowedGroups []string `json:"allowed_groups"`
|
AllowedGroups []string `json:"allowed_groups"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
Scopes string `json:"scopes"`
|
||||||
|
|
||||||
auth *auth.OIDCProvider
|
auth *auth.OIDCProvider
|
||||||
|
|
||||||
@@ -49,11 +53,28 @@ func (amw *oidcMiddleware) initSlow() error {
|
|||||||
amw.initMu.Unlock()
|
amw.initMu.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Always start with the global OIDC provider (for issuer discovery)
|
||||||
authProvider, err := auth.NewOIDCProviderFromEnv()
|
authProvider, err := auth.NewOIDCProviderFromEnv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if custom client credentials are provided
|
||||||
|
if amw.ClientID != "" && amw.ClientSecret != "" {
|
||||||
|
// Use custom client credentials
|
||||||
|
customProvider, err := auth.NewOIDCProviderWithCustomClient(
|
||||||
|
authProvider,
|
||||||
|
amw.ClientID,
|
||||||
|
amw.ClientSecret,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
authProvider = customProvider
|
||||||
|
}
|
||||||
|
// If no custom credentials, authProvider remains the global one
|
||||||
|
|
||||||
|
// Apply per-route user/group restrictions (these always override global)
|
||||||
if len(amw.AllowedUsers) > 0 {
|
if len(amw.AllowedUsers) > 0 {
|
||||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||||
}
|
}
|
||||||
@@ -61,6 +82,11 @@ func (amw *oidcMiddleware) initSlow() error {
|
|||||||
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply custom scopes if provided
|
||||||
|
if amw.Scopes != "" {
|
||||||
|
authProvider.SetScopes(strings.Split(amw.Scopes, ","))
|
||||||
|
}
|
||||||
|
|
||||||
amw.auth = authProvider
|
amw.auth = authProvider
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
35
internal/net/gphttp/middleware/oidc_test.go
Normal file
35
internal/net/gphttp/middleware/oidc_test.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOIDCMiddlewarePerRouteConfig(t *testing.T) {
|
||||||
|
t.Run("middleware struct has correct fields", func(t *testing.T) {
|
||||||
|
middleware := &oidcMiddleware{
|
||||||
|
AllowedUsers: []string{"custom-user"},
|
||||||
|
AllowedGroups: []string{"custom-group"},
|
||||||
|
ClientID: "custom-client-id",
|
||||||
|
ClientSecret: "custom-client-secret",
|
||||||
|
Scopes: "openid,profile,email,groups",
|
||||||
|
}
|
||||||
|
|
||||||
|
ExpectEqual(t, middleware.AllowedUsers, []string{"custom-user"})
|
||||||
|
ExpectEqual(t, middleware.AllowedGroups, []string{"custom-group"})
|
||||||
|
ExpectEqual(t, middleware.ClientID, "custom-client-id")
|
||||||
|
ExpectEqual(t, middleware.ClientSecret, "custom-client-secret")
|
||||||
|
ExpectEqual(t, middleware.Scopes, "openid,profile,email,groups")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("middleware struct handles empty values", func(t *testing.T) {
|
||||||
|
middleware := &oidcMiddleware{}
|
||||||
|
|
||||||
|
ExpectEqual(t, middleware.AllowedUsers, nil)
|
||||||
|
ExpectEqual(t, middleware.AllowedGroups, nil)
|
||||||
|
ExpectEqual(t, middleware.ClientID, "")
|
||||||
|
ExpectEqual(t, middleware.ClientSecret, "")
|
||||||
|
ExpectEqual(t, middleware.Scopes, "")
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user