From 41ce784a7fbd3bf4d228cc72b088ff92c02bccc0 Mon Sep 17 00:00:00 2001 From: DeAndre Harris <8107071+deandre@users.noreply.github.com> Date: Mon, 8 Sep 2025 02:16:30 +0200 Subject: [PATCH] feat: Add per-route OIDC client ID and secret support (#145) --- internal/auth/oauth_refresh.go | 2 +- internal/auth/oidc.go | 76 +++++++++++++++++---- internal/auth/oidc_test.go | 5 +- internal/net/gphttp/middleware/oidc.go | 26 +++++++ internal/net/gphttp/middleware/oidc_test.go | 35 ++++++++++ 5 files changed, 130 insertions(+), 14 deletions(-) create mode 100644 internal/net/gphttp/middleware/oidc_test.go diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go index 0fc03351..e4a26dd3 100644 --- a/internal/auth/oauth_refresh.go +++ b/internal/auth/oauth_refresh.go @@ -130,7 +130,7 @@ func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.R log.Err(err).Msg("failed to sign session token") return } - SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL) + SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthSessionToken), signed, common.APIJWTTokenTTL) } func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) { diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 3617f2ff..6ef021cd 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -3,6 +3,7 @@ package auth import ( "context" "crypto/rand" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -39,12 +40,27 @@ type ( var _ Provider = (*OIDCProvider)(nil) +// Cookie names for OIDC authentication const ( CookieOauthState = "godoxy_oidc_state" CookieOauthToken = "godoxy_oauth_token" //nolint:gosec CookieOauthSessionToken = "godoxy_session_token" //nolint:gosec ) +// getAppScopedCookieName returns a cookie name scoped to the specific application +// to prevent conflicts between different OIDC clients +func (auth *OIDCProvider) getAppScopedCookieName(baseName string) string { + // Use the client ID to scope the cookie name + // This prevents conflicts when multiple apps use different client IDs + if auth.oauthConfig.ClientID != "" { + // Create a hash of the client ID to keep cookie names short + hash := sha256.Sum256([]byte(auth.oauthConfig.ClientID)) + clientHash := base64.URLEncoding.EncodeToString(hash[:])[:8] + return fmt.Sprintf("%s_%s", baseName, clientHash) + } + return baseName +} + const ( OIDCAuthInitPath = "/" OIDCPostAuthPath = "/auth/callback" @@ -117,6 +133,37 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) { ) } +// NewOIDCProviderWithCustomClient creates a new OIDCProvider with custom client credentials +// based on an existing provider (for issuer discovery) +func NewOIDCProviderWithCustomClient(baseProvider *OIDCProvider, clientID, clientSecret string) (*OIDCProvider, error) { + if clientID == "" || clientSecret == "" { + return nil, errors.New("client ID and client secret are required") + } + + // Create a new OIDC verifier with the custom client ID + oidcVerifier := baseProvider.oidcProvider.Verifier(&oidc.Config{ + ClientID: clientID, + }) + + // Create new OAuth config with custom credentials + oauthConfig := &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: "", + Endpoint: baseProvider.oauthConfig.Endpoint, + Scopes: baseProvider.oauthConfig.Scopes, + } + + return &OIDCProvider{ + oauthConfig: oauthConfig, + oidcProvider: baseProvider.oidcProvider, + oidcVerifier: oidcVerifier, + endSessionURL: baseProvider.endSessionURL, + allowedUsers: baseProvider.allowedUsers, + allowedGroups: baseProvider.allowedGroups, + }, nil +} + func (auth *OIDCProvider) SetAllowedUsers(users []string) { auth.allowedUsers = users } @@ -125,6 +172,10 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) { auth.allowedGroups = groups } +func (auth *OIDCProvider) SetScopes(scopes []string) { + auth.oauthConfig.Scopes = scopes +} + // optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri" // parameter of the authorization URL to the post auth path of the current // request host. @@ -169,7 +220,7 @@ var rateLimit = rate.NewLimiter(rate.Every(time.Second), 1) func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { // check for session token - sessionToken, err := r.Cookie(CookieOauthSessionToken) + sessionToken, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken)) if err == nil { // session token exists result, err := auth.TryRefreshToken(r.Context(), sessionToken.Value) // redirect back to where they requested @@ -193,7 +244,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { } state := generateState() - SetTokenCookie(w, r, CookieOauthState, state, 300*time.Second) + SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthState), state, 300*time.Second) // redirect user to Idp url := auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)) if IsFrontend(r) { @@ -209,7 +260,8 @@ func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) { if err := idToken.Claims(&claim); err != nil { return nil, fmt.Errorf("failed to parse claims: %w", err) } - if claim.Username == "" { + // Username is optional if groups are present + if claim.Username == "" && len(claim.Groups) == 0 { return nil, errors.New("missing username in ID token") } return &claim, nil @@ -228,7 +280,7 @@ func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool { } func (auth *OIDCProvider) CheckToken(r *http.Request) error { - tokenCookie, err := r.Cookie(CookieOauthToken) + tokenCookie, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthToken)) if err != nil { return ErrMissingOAuthToken } @@ -257,7 +309,7 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http } // verify state - state, err := r.Cookie(CookieOauthState) + state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState)) if err != nil { http.Error(w, "missing state cookie", http.StatusBadRequest) return @@ -297,8 +349,8 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http } func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) { - oauthToken, _ := r.Cookie(CookieOauthToken) - sessionToken, _ := r.Cookie(CookieOauthSessionToken) + oauthToken, _ := r.Cookie(auth.getAppScopedCookieName(CookieOauthToken)) + sessionToken, _ := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken)) auth.clearCookie(w, r) if sessionToken != nil { @@ -325,17 +377,17 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) } func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) { - SetTokenCookie(w, r, CookieOauthToken, jwt, ttl) + SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthToken), jwt, ttl) } func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) { - ClearTokenCookie(w, r, CookieOauthToken) - ClearTokenCookie(w, r, CookieOauthSessionToken) + ClearTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthToken)) + ClearTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthSessionToken)) } // handleTestCallback handles OIDC callback in test environment. func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) { - state, err := r.Cookie(CookieOauthState) + state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState)) if err != nil { http.Error(w, "missing state cookie", http.StatusBadRequest) return @@ -347,7 +399,7 @@ func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Requ } // Create test JWT token - SetTokenCookie(w, r, CookieOauthToken, "test", time.Hour) + SetTokenCookie(w, r, auth.getAppScopedCookieName(CookieOauthToken), "test", time.Hour) http.Redirect(w, r, "/", http.StatusFound) } diff --git a/internal/auth/oidc_test.go b/internal/auth/oidc_test.go index c9be35fe..b30b1c64 100644 --- a/internal/auth/oidc_test.go +++ b/internal/auth/oidc_test.go @@ -426,6 +426,9 @@ func TestCheckToken(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Create the Auth Provider. auth := &OIDCProvider{ + oauthConfig: &oauth2.Config{ + ClientID: clientID, + }, oidcVerifier: provider.verifier, allowedUsers: tc.allowedUsers, allowedGroups: tc.allowedGroups, @@ -435,7 +438,7 @@ func TestCheckToken(t *testing.T) { // Craft a test HTTP request that includes the token as a cookie. req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ - Name: CookieOauthToken, + Name: auth.getAppScopedCookieName(CookieOauthToken), Value: signedToken, }) diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 6e6a3ba2..5deb4eab 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -3,6 +3,7 @@ package middleware import ( "errors" "net/http" + "strings" "sync" "sync/atomic" @@ -13,6 +14,9 @@ import ( type oidcMiddleware struct { AllowedUsers []string `json:"allowed_users"` AllowedGroups []string `json:"allowed_groups"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Scopes string `json:"scopes"` auth *auth.OIDCProvider @@ -49,11 +53,28 @@ func (amw *oidcMiddleware) initSlow() error { amw.initMu.Unlock() }() + // Always start with the global OIDC provider (for issuer discovery) authProvider, err := auth.NewOIDCProviderFromEnv() if err != nil { return err } + // Check if custom client credentials are provided + if amw.ClientID != "" && amw.ClientSecret != "" { + // Use custom client credentials + customProvider, err := auth.NewOIDCProviderWithCustomClient( + authProvider, + amw.ClientID, + amw.ClientSecret, + ) + if err != nil { + return err + } + authProvider = customProvider + } + // If no custom credentials, authProvider remains the global one + + // Apply per-route user/group restrictions (these always override global) if len(amw.AllowedUsers) > 0 { authProvider.SetAllowedUsers(amw.AllowedUsers) } @@ -61,6 +82,11 @@ func (amw *oidcMiddleware) initSlow() error { authProvider.SetAllowedGroups(amw.AllowedGroups) } + // Apply custom scopes if provided + if amw.Scopes != "" { + authProvider.SetScopes(strings.Split(amw.Scopes, ",")) + } + amw.auth = authProvider return nil } diff --git a/internal/net/gphttp/middleware/oidc_test.go b/internal/net/gphttp/middleware/oidc_test.go new file mode 100644 index 00000000..c3fe0f85 --- /dev/null +++ b/internal/net/gphttp/middleware/oidc_test.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestOIDCMiddlewarePerRouteConfig(t *testing.T) { + t.Run("middleware struct has correct fields", func(t *testing.T) { + middleware := &oidcMiddleware{ + AllowedUsers: []string{"custom-user"}, + AllowedGroups: []string{"custom-group"}, + ClientID: "custom-client-id", + ClientSecret: "custom-client-secret", + Scopes: "openid,profile,email,groups", + } + + ExpectEqual(t, middleware.AllowedUsers, []string{"custom-user"}) + ExpectEqual(t, middleware.AllowedGroups, []string{"custom-group"}) + ExpectEqual(t, middleware.ClientID, "custom-client-id") + ExpectEqual(t, middleware.ClientSecret, "custom-client-secret") + ExpectEqual(t, middleware.Scopes, "openid,profile,email,groups") + }) + + t.Run("middleware struct handles empty values", func(t *testing.T) { + middleware := &oidcMiddleware{} + + ExpectEqual(t, middleware.AllowedUsers, nil) + ExpectEqual(t, middleware.AllowedGroups, nil) + ExpectEqual(t, middleware.ClientID, "") + ExpectEqual(t, middleware.ClientSecret, "") + ExpectEqual(t, middleware.Scopes, "") + }) +}