diff --git a/go.mod b/go.mod index e6823e39..2d8870fc 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( gopkg.in/yaml.v3 v3.0.1 // yaml parsing for different config files ) -replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.1 +replace github.com/coreos/go-oidc/v3 => github.com/godoxy-app/go-oidc/v3 v3.14.2 require ( github.com/docker/cli v28.1.1+incompatible diff --git a/go.sum b/go.sum index f3836772..d5305790 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,8 @@ github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJA github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/godoxy-app/go-oidc/v3 v3.14.1 h1:QWes6QTEeyZQ40fluO8HLqF/D9W0Nj/Zm7UYhTTUJ+k= -github.com/godoxy-app/go-oidc/v3 v3.14.1/go.mod h1:ZRZLrEz7MmMe1kRzRsYqYmWKN2EHlPVGn71GMbrLLt4= +github.com/godoxy-app/go-oidc/v3 v3.14.2 h1:y1sosR6N7IpMiREM8I8w68zrUhh5P0Hg+6wERmuhFAc= +github.com/godoxy-app/go-oidc/v3 v3.14.2/go.mod h1:ZRZLrEz7MmMe1kRzRsYqYmWKN2EHlPVGn71GMbrLLt4= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= diff --git a/internal/api/handler.go b/internal/api/handler.go index a16e19bf..178fb3d5 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -98,21 +98,21 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { logging.Info().Msg("prometheus metrics enabled") } - defaultAuth := auth.GetDefaultAuth() - if defaultAuth != nil { - mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage) - mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) { - if err := defaultAuth.CheckToken(r); err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - }) - mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler) - mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutCallbackHandler) - } else { - mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) - } + // defaultAuth := auth.GetDefaultAuth() + // if defaultAuth != nil { + // mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage) + // mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) { + // if err := defaultAuth.CheckToken(r); err != nil { + // http.Error(w, err.Error(), http.StatusUnauthorized) + // return + // } + // }) + // mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler) + // mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutCallbackHandler) + // } else { + // mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) { + // w.WriteHeader(http.StatusOK) + // }) + // } return mux } diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index 0fda97fa..10369429 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -2,8 +2,6 @@ package auth import ( "context" - "crypto/rand" - "encoding/base64" "errors" "fmt" "net/http" @@ -28,7 +26,6 @@ type ( oidcEndSessionURL *url.URL allowedUsers []string allowedGroups []string - isMiddleware bool } providerJSON struct { @@ -40,11 +37,17 @@ type ( const CookieOauthState = "godoxy_oidc_state" const ( - OIDCMiddlewareCallbackPath = "/auth/callback" - OIDCLogoutPath = "/auth/logout" + OIDCAuthCallbackPath = "/auth/callback" + OIDCPostAuthPath = "/auth/postauth" + OIDCLogoutPath = "/auth/logout" ) -func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { +var ( + ErrMissingState = errors.New("missing state cookie") + ErrInvalidState = errors.New("invalid oauth state") +) + +func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { if len(allowedUsers)+len(allowedGroups) == 0 { return nil, errors.New("OIDC users, groups, or both must not be empty") } @@ -62,11 +65,15 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allo Msg("failed to parse end session URL") } + logging.Debug(). + Str("issuer", issuerURL). + Str("end_session_endpoint", provider.EndSessionEndpoint()). + Msg("end session URL") return &OIDCProvider{ oauthConfig: &oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - RedirectURL: redirectURL, + RedirectURL: "", Endpoint: provider.Endpoint(), Scopes: strutils.CommaSeperatedList(common.OIDCScopes), }, @@ -86,7 +93,6 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) { common.OIDCIssuerURL, common.OIDCClientID, common.OIDCClientSecret, - common.OIDCRedirectURL, common.OIDCAllowedUsers, common.OIDCAllowedGroups, ) @@ -96,11 +102,6 @@ func (auth *OIDCProvider) TokenCookieName() string { return "godoxy_oidc_token" } -func (auth *OIDCProvider) SetIsMiddleware(enabled bool) { - auth.isMiddleware = enabled - auth.oauthConfig.RedirectURL = "" -} - func (auth *OIDCProvider) SetAllowedUsers(users []string) { auth.allowedUsers = users } @@ -109,6 +110,56 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) { auth.allowedGroups = groups } +func (auth *OIDCProvider) getVerifyStateCookie(r *http.Request) (string, error) { + state, err := r.Cookie(CookieOauthState) + if err != nil { + return "", ErrMissingState + } + if r.URL.Query().Get("state") != state.Value { + return "", ErrInvalidState + } + return state.Value, nil +} + +func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath) +} + +func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { + logging.Debug().Str("method", r.Method).Str("path", r.URL.Path).Msg("handle auth") + + switch r.Method { + case http.MethodHead: + w.WriteHeader(http.StatusOK) + return + case http.MethodGet: + break + default: + gphttp.Forbidden(w, "method not allowed") + return + } + + switch r.URL.Path { + case OIDCAuthCallbackPath: + state := generateState() + http.SetCookie(w, &http.Cookie{ + Name: CookieOauthState, + Value: state, + MaxAge: 300, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: common.APIJWTSecure, + Path: "/", + }) + // redirect user to Idp + http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusTemporaryRedirect) + case OIDCPostAuthPath: + auth.PostAuthCallbackHandler(w, r) + default: + auth.LogoutHandler(w, r) + } +} + func (auth *OIDCProvider) CheckToken(r *http.Request) error { token, err := r.Cookie(auth.TokenCookieName()) if err != nil { @@ -143,80 +194,25 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error { return nil } -// generateState generates a random string for OIDC state. -const oidcStateLength = 32 - -func generateState() (string, error) { - b := make([]byte, oidcStateLength) - _, err := rand.Read(b) - if err != nil { - return "", err - } - return base64.URLEncoding.EncodeToString(b)[:oidcStateLength], nil -} - -// RedirectOIDC initiates the OIDC login flow. func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { - state, err := generateState() - if err != nil { - gphttp.ServerError(w, r, err) - return - } - http.SetCookie(w, &http.Cookie{ - Name: CookieOauthState, - Value: state, - MaxAge: 300, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - Secure: common.APIJWTSecure, - Path: "/", - }) - - var redirURL string - if auth.isMiddleware { - optOverrideRedirectURL := oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCMiddlewareCallbackPath) - redirURL = auth.oauthConfig.AuthCodeURL(state, optOverrideRedirectURL) - } else { - redirURL = auth.oauthConfig.AuthCodeURL(state) - } - http.Redirect(w, r, redirURL, http.StatusTemporaryRedirect) + http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } -func (auth *OIDCProvider) cloneConfig() *oauth2.Config { - cfg := *auth.oauthConfig - return &cfg -} - -func (auth *OIDCProvider) exchange(r *http.Request) (*oauth2.Token, error) { - var cfg *oauth2.Config - if auth.isMiddleware { - cfg = auth.cloneConfig() - cfg.RedirectURL = "https://" + r.Host + OIDCMiddlewareCallbackPath - } - return cfg.Exchange(r.Context(), r.URL.Query().Get("code")) -} - -// OIDCCallbackHandler handles the OIDC callback. -func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) { +func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) { // For testing purposes, skip provider verification if common.IsTest { auth.handleTestCallback(w, r) return } - state, err := r.Cookie(CookieOauthState) + _, err := auth.getVerifyStateCookie(r) if err != nil { - gphttp.BadRequest(w, "missing state cookie") + gphttp.BadRequest(w, err.Error()) return } - query := r.URL.Query() - if query.Get("state") != state.Value { - gphttp.BadRequest(w, "invalid oauth state") - return - } - - oauth2Token, err := auth.exchange(r) + code := r.URL.Query().Get("code") + oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r)) if err != nil { gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err)) return @@ -240,23 +236,26 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } -func (auth *OIDCProvider) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { +func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) { if auth.oidcEndSessionURL == nil { - DefaultLogoutCallbackHandler(auth, w, r) + clearTokenCookie(w, r, auth.TokenCookieName()) + http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect) return } token, err := r.Cookie(auth.TokenCookieName()) - if err != nil { - gphttp.BadRequest(w, "missing token cookie") - return + if err == nil { + query := auth.oidcEndSessionURL.Query() + query.Add("id_token_hint", token.Value) + + logoutURL := *auth.oidcEndSessionURL + logoutURL.RawQuery = query.Encode() + + clearTokenCookie(w, r, auth.TokenCookieName()) + http.Redirect(w, r, logoutURL.String(), http.StatusFound) } - clearTokenCookie(w, r, auth.TokenCookieName()) - logoutURL := *auth.oidcEndSessionURL - logoutURL.Query().Add("id_token_hint", token.Value) - - http.Redirect(w, r, logoutURL.String(), http.StatusFound) + http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect) } // handleTestCallback handles OIDC callback in test environment. diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index 0ed759fa..33a7293b 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -155,10 +155,10 @@ func TestOIDCLoginHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) w := httptest.NewRecorder() - defaultAuth.RedirectLoginPage(w, req) + defaultAuth.(*OIDCProvider).HandleAuth(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) @@ -219,7 +219,7 @@ func TestOIDCCallbackHandler(t *testing.T) { } w := httptest.NewRecorder() - defaultAuth.LoginCallbackHandler(w, req) + defaultAuth.(*OIDCProvider).PostAuthCallbackHandler(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus) @@ -270,7 +270,6 @@ func TestInitOIDC(t *testing.T) { issuerURL: server.URL, clientID: "client_id", clientSecret: "client_secret", - redirectURL: "https://example.com/callback", allowedUsers: []string{"user1", "user2"}, wantErr: false, }, @@ -279,7 +278,6 @@ func TestInitOIDC(t *testing.T) { issuerURL: server.URL, clientID: "client_id", clientSecret: "client_secret", - redirectURL: "https://example.com/callback", allowedGroups: []string{"group1", "group2"}, wantErr: false, }, @@ -288,7 +286,6 @@ func TestInitOIDC(t *testing.T) { issuerURL: server.URL, clientID: "client_id", clientSecret: "client_secret", - redirectURL: "https://example.com/callback", logoutURL: "https://example.com/logout", allowedUsers: []string{"user1", "user2"}, allowedGroups: []string{"group1", "group2"}, @@ -299,14 +296,13 @@ func TestInitOIDC(t *testing.T) { issuerURL: "https://example.com", clientID: "client_id", clientSecret: "client_secret", - redirectURL: "https://example.com/callback", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups) + _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.allowedUsers, tt.allowedGroups) if (err != nil) != tt.wantErr { t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/internal/api/v1/auth/provider.go b/internal/api/v1/auth/provider.go index 86e53169..d3770278 100644 --- a/internal/api/v1/auth/provider.go +++ b/internal/api/v1/auth/provider.go @@ -7,7 +7,4 @@ import ( type Provider interface { TokenCookieName() string CheckToken(r *http.Request) error - RedirectLoginPage(w http.ResponseWriter, r *http.Request) - LoginCallbackHandler(w http.ResponseWriter, r *http.Request) - LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) } diff --git a/internal/api/v1/auth/userpass.go b/internal/api/v1/auth/userpass.go index 239a4cc0..34e1e812 100644 --- a/internal/api/v1/auth/userpass.go +++ b/internal/api/v1/auth/userpass.go @@ -128,7 +128,8 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re } func (auth *UserPassAuth) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { - DefaultLogoutCallbackHandler(auth, w, r) + clearTokenCookie(w, r, auth.TokenCookieName()) + auth.RedirectLoginPage(w, r) } func (auth *UserPassAuth) validatePassword(user, pass string) error { diff --git a/internal/api/v1/auth/utils.go b/internal/api/v1/auth/utils.go index 00ed7b04..2ddd3d12 100644 --- a/internal/api/v1/auth/utils.go +++ b/internal/api/v1/auth/utils.go @@ -1,6 +1,8 @@ package auth import ( + "crypto/rand" + "encoding/base64" "net" "net/http" "time" @@ -73,8 +75,11 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) { }) } -// DefaultLogoutCallbackHandler clears the token cookie and redirects to the login page.. -func DefaultLogoutCallbackHandler(auth Provider, w http.ResponseWriter, r *http.Request) { - clearTokenCookie(w, r, auth.TokenCookieName()) - auth.RedirectLoginPage(w, r) +// generateState generates a random string for OIDC state. +const oidcStateLength = 32 + +func generateState() string { + b := make([]byte, oidcStateLength) + _, _ = rand.Read(b) + return base64.URLEncoding.EncodeToString(b)[:oidcStateLength] } diff --git a/internal/common/env.go b/internal/common/env.go index c73028b0..12bdcb72 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -48,7 +48,6 @@ var ( OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") - OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "") OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "") diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 231422b9..73b9c252 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -14,8 +14,7 @@ type oidcMiddleware struct { AllowedUsers []string `json:"allowed_users"` AllowedGroups []string `json:"allowed_groups"` - auth auth.Provider - authMux *http.ServeMux + auth *auth.OIDCProvider isInitialized int32 initMu sync.Mutex @@ -55,7 +54,6 @@ func (amw *oidcMiddleware) initSlow() error { return err } - authProvider.SetIsMiddleware(true) if len(amw.AllowedUsers) > 0 { authProvider.SetAllowedUsers(amw.AllowedUsers) } @@ -63,27 +61,24 @@ func (amw *oidcMiddleware) initSlow() error { authProvider.SetAllowedGroups(amw.AllowedGroups) } - amw.authMux = http.NewServeMux() - amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) - amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage) amw.auth = authProvider return nil } func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { if err := amw.init(); err != nil { - // no need to log here, main OIDC may already failed and logged + // no need to log here, main OIDC should've already failed and logged http.Error(w, err.Error(), http.StatusInternalServerError) return false } if r.URL.Path == auth.OIDCLogoutPath { - amw.auth.LogoutCallbackHandler(w, r) + amw.auth.LogoutHandler(w, r) return false } if err := amw.auth.CheckToken(r); err != nil { if errors.Is(err, auth.ErrMissingToken) { - amw.authMux.ServeHTTP(w, r) + amw.auth.HandleAuth(w, r) } else { auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) }