mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-28 11:51:53 +01:00
feat(oidc): support token refreshing via offline_access scope
- refactored code - moved api/v1/auth to auth/ - security enhancement - env example update - default jwt ttl changed to 24 hours
This commit is contained in:
@@ -6,10 +6,10 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/certapi"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/dockerapi"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
||||
"github.com/yusing/go-proxy/internal/auth"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
)
|
||||
|
||||
var defaultAuth Provider
|
||||
|
||||
// Initialize sets up authentication providers.
|
||||
func Initialize() error {
|
||||
if !IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
// Initialize OIDC if configured.
|
||||
if common.OIDCIssuerURL != "" {
|
||||
defaultAuth, err = NewOIDCProviderFromEnv()
|
||||
} else {
|
||||
defaultAuth, err = NewUserPassAuthFromEnv()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func GetDefaultAuth() Provider {
|
||||
return defaultAuth
|
||||
}
|
||||
|
||||
func IsEnabled() bool {
|
||||
return !common.DebugDisableAuth && (common.APIJWTSecret != nil || IsOIDCEnabled())
|
||||
}
|
||||
|
||||
func IsOIDCEnabled() bool {
|
||||
return common.OIDCIssuerURL != ""
|
||||
}
|
||||
|
||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
if IsEnabled() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := defaultAuth.CheckToken(r); err != nil {
|
||||
gphttp.ClientError(w, err, http.StatusUnauthorized)
|
||||
} else {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return next
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed block_page.html
|
||||
var blockPageHTML string
|
||||
|
||||
var blockPageTemplate = template.Must(template.New("block_page").Parse(blockPageHTML))
|
||||
|
||||
func WriteBlockPage(w http.ResponseWriter, status int, error string, logoutURL string) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
blockPageTemplate.Execute(w, map[string]string{
|
||||
"StatusText": http.StatusText(status),
|
||||
"Error": error,
|
||||
"LogoutURL": logoutURL,
|
||||
})
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
|
||||
<title>Access Denied</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>{{.StatusText}}</h1>
|
||||
<p>{{.Error}}</p>
|
||||
<a href="{{.LogoutURL}}">Logout</a>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,272 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type (
|
||||
OIDCProvider struct {
|
||||
oauthConfig *oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
oidcEndSessionURL *url.URL
|
||||
allowedUsers []string
|
||||
allowedGroups []string
|
||||
}
|
||||
|
||||
providerJSON struct {
|
||||
oidc.ProviderConfig
|
||||
EndSessionURL string `json:"end_session_endpoint"`
|
||||
}
|
||||
)
|
||||
|
||||
const CookieOauthState = "godoxy_oidc_state"
|
||||
|
||||
const (
|
||||
OIDCAuthCallbackPath = "/auth/callback"
|
||||
OIDCPostAuthPath = "/auth/postauth"
|
||||
OIDCLogoutPath = "/auth/logout"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingState = errors.New("missing state cookie")
|
||||
ErrInvalidState = errors.New("invalid oauth state")
|
||||
)
|
||||
|
||||
func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
|
||||
if len(allowedUsers)+len(allowedGroups) == 0 {
|
||||
return nil, errors.New("OIDC users, groups, or both must not be empty")
|
||||
}
|
||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
||||
}
|
||||
|
||||
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
|
||||
if err != nil && provider.EndSessionEndpoint() != "" {
|
||||
// non critical, just warn
|
||||
logging.Warn().
|
||||
Str("issuer", issuerURL).
|
||||
Err(err).
|
||||
Msg("failed to parse end session URL")
|
||||
}
|
||||
|
||||
return &OIDCProvider{
|
||||
oauthConfig: &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: "",
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
||||
},
|
||||
oidcProvider: provider,
|
||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: clientID,
|
||||
}),
|
||||
oidcEndSessionURL: endSessionURL,
|
||||
allowedUsers: allowedUsers,
|
||||
allowedGroups: allowedGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables.
|
||||
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
||||
return NewOIDCProvider(
|
||||
common.OIDCIssuerURL,
|
||||
common.OIDCClientID,
|
||||
common.OIDCClientSecret,
|
||||
common.OIDCAllowedUsers,
|
||||
common.OIDCAllowedGroups,
|
||||
)
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) TokenCookieName() string {
|
||||
return "godoxy_oidc_token"
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||
auth.allowedUsers = users
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
||||
auth.allowedGroups = groups
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) getVerifyStateCookie(r *http.Request) (string, error) {
|
||||
state, err := r.Cookie(CookieOauthState)
|
||||
if err != nil {
|
||||
return "", ErrMissingState
|
||||
}
|
||||
if r.URL.Query().Get("state") != state.Value {
|
||||
return "", ErrInvalidState
|
||||
}
|
||||
return state.Value, nil
|
||||
}
|
||||
|
||||
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
||||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath)
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodHead:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
case http.MethodGet:
|
||||
break
|
||||
default:
|
||||
gphttp.Forbidden(w, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case OIDCAuthCallbackPath:
|
||||
state := generateState()
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CookieOauthState,
|
||||
Value: state,
|
||||
MaxAge: 300,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: common.APIJWTSecure,
|
||||
Path: "/",
|
||||
})
|
||||
// redirect user to Idp
|
||||
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusTemporaryRedirect)
|
||||
case OIDCPostAuthPath:
|
||||
auth.PostAuthCallbackHandler(w, r)
|
||||
default:
|
||||
auth.LogoutHandler(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||
token, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
}
|
||||
|
||||
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
|
||||
}
|
||||
|
||||
if len(idToken.Audience) == 0 {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username"`
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return fmt.Errorf("failed to parse claims: %w", err)
|
||||
}
|
||||
|
||||
// Logical AND between allowed users and groups.
|
||||
allowedUser := slices.Contains(auth.allowedUsers, claims.Username)
|
||||
allowedGroup := len(utils.Intersect(claims.Groups, auth.allowedGroups)) > 0
|
||||
if !allowedUser && !allowedGroup {
|
||||
return ErrUserNotAllowed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// For testing purposes, skip provider verification
|
||||
if common.IsTest {
|
||||
auth.handleTestCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
_, err := auth.getVerifyStateCookie(r)
|
||||
if err != nil {
|
||||
gphttp.BadRequest(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r))
|
||||
if err != nil {
|
||||
gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
gphttp.BadRequest(w, "missing id_token")
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
gphttp.ServerError(w, r, fmt.Errorf("failed to verify ID token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry))
|
||||
|
||||
// Redirect to home page
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if auth.oidcEndSessionURL == nil {
|
||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := r.Cookie(auth.TokenCookieName())
|
||||
if err == nil {
|
||||
query := auth.oidcEndSessionURL.Query()
|
||||
query.Add("id_token_hint", token.Value)
|
||||
|
||||
logoutURL := *auth.oidcEndSessionURL
|
||||
logoutURL.RawQuery = query.Encode()
|
||||
|
||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||
http.Redirect(w, r, logoutURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// handleTestCallback handles OIDC callback in test environment.
|
||||
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := r.Cookie(CookieOauthState)
|
||||
if err != nil {
|
||||
gphttp.BadRequest(w, "missing state cookie")
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("state") != state.Value {
|
||||
gphttp.BadRequest(w, "invalid oauth state")
|
||||
return
|
||||
}
|
||||
|
||||
// Create test JWT token
|
||||
setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour)
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
@@ -1,450 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
// setupMockOIDC configures mock OIDC provider for testing.
|
||||
func setupMockOIDC(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
|
||||
defaultAuth = &OIDCProvider{
|
||||
oauthConfig: &oauth2.Config{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
RedirectURL: "http://localhost/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "http://mock-provider/auth",
|
||||
TokenURL: "http://mock-provider/token",
|
||||
},
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
},
|
||||
oidcProvider: provider,
|
||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: "test-client",
|
||||
}),
|
||||
allowedUsers: []string{"test-user"},
|
||||
allowedGroups: []string{"test-group1", "test-group2"},
|
||||
}
|
||||
}
|
||||
|
||||
// discoveryDocument returns a mock OIDC discovery document.
|
||||
func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
discovery := map[string]any{
|
||||
"issuer": server.URL,
|
||||
"authorization_endpoint": server.URL + "/auth",
|
||||
"token_endpoint": server.URL + "/token",
|
||||
}
|
||||
|
||||
return discovery
|
||||
}
|
||||
|
||||
const (
|
||||
keyID = "test-key-id"
|
||||
clientID = "test-client-id"
|
||||
)
|
||||
|
||||
type provider struct {
|
||||
ts *httptest.Server
|
||||
key *rsa.PrivateKey
|
||||
verifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
|
||||
t.Helper()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = keyID
|
||||
signed, err := token.SignedString(j.key)
|
||||
ExpectNoError(t, err)
|
||||
return signed
|
||||
}
|
||||
|
||||
func setupProvider(t *testing.T) *provider {
|
||||
t.Helper()
|
||||
|
||||
// Generate an RSA key pair for the test.
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
ExpectNoError(t, err)
|
||||
|
||||
// Build the matching public JWK that will be served by the endpoint.
|
||||
jwk := buildRSAJWK(t, &privKey.PublicKey, keyID)
|
||||
|
||||
// Start a test server that serves the JWKS endpoint.
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/jwks.json":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"keys": []any{jwk},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
// Create a test OIDCProvider.
|
||||
providerCtx := oidc.ClientContext(context.Background(), ts.Client())
|
||||
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
|
||||
|
||||
return &provider{
|
||||
ts: ts,
|
||||
key: privKey,
|
||||
verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{
|
||||
ClientID: clientID, // matches audience in the token
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint.
|
||||
func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
nBytes := pub.N.Bytes()
|
||||
eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537
|
||||
|
||||
return map[string]any{
|
||||
"kty": "RSA",
|
||||
"alg": "RS256",
|
||||
"use": "sig",
|
||||
"kid": kid,
|
||||
"n": base64.RawURLEncoding.EncodeToString(nBytes),
|
||||
"e": base64.RawURLEncoding.EncodeToString(eBytes),
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
defaultAuth = nil
|
||||
}
|
||||
|
||||
func TestOIDCLoginHandler(t *testing.T) {
|
||||
// Setup
|
||||
common.APIJWTSecret = []byte("test-secret")
|
||||
t.Cleanup(cleanup)
|
||||
setupMockOIDC(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
wantStatus int
|
||||
wantRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "Success - Redirects to provider",
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
wantRedirect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
defaultAuth.(*OIDCProvider).HandleAuth(w, req)
|
||||
|
||||
if got := w.Code; got != tt.wantStatus {
|
||||
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
||||
}
|
||||
|
||||
if tt.wantRedirect {
|
||||
if loc := w.Header().Get("Location"); loc == "" {
|
||||
t.Error("OIDCLoginHandler() missing redirect location")
|
||||
}
|
||||
|
||||
cookie := w.Header().Get("Set-Cookie")
|
||||
if cookie == "" {
|
||||
t.Error("OIDCLoginHandler() missing state cookie")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCCallbackHandler(t *testing.T) {
|
||||
// Setup
|
||||
common.APIJWTSecret = []byte("test-secret")
|
||||
t.Cleanup(cleanup)
|
||||
tests := []struct {
|
||||
name string
|
||||
state string
|
||||
code string
|
||||
setupMocks bool
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "Success - Valid callback",
|
||||
state: "valid-state",
|
||||
code: "valid-code",
|
||||
setupMocks: true,
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
},
|
||||
{
|
||||
name: "Failure - Missing state",
|
||||
code: "valid-code",
|
||||
setupMocks: true,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setupMocks {
|
||||
setupMockOIDC(t)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
|
||||
if tt.state != "" {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: CookieOauthState,
|
||||
Value: tt.state,
|
||||
})
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
defaultAuth.(*OIDCProvider).PostAuthCallbackHandler(w, req)
|
||||
|
||||
if got := w.Code; got != tt.wantStatus {
|
||||
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
|
||||
}
|
||||
|
||||
if tt.wantStatus == http.StatusTemporaryRedirect {
|
||||
setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
|
||||
ExpectTrue(t, setCookie.Value != "")
|
||||
ExpectEqual(t, setCookie.Path, "/")
|
||||
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitOIDC(t *testing.T) {
|
||||
setupMockOIDC(t)
|
||||
// Create a test server that serves the discovery document
|
||||
var server *httptest.Server
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
|
||||
})
|
||||
server = httptest.NewServer(mux)
|
||||
t.Cleanup(server.Close)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURL string
|
||||
logoutURL string
|
||||
allowedUsers []string
|
||||
allowedGroups []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Fail - Empty configuration",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Success - Valid configuration with users",
|
||||
issuerURL: server.URL,
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
allowedUsers: []string{"user1", "user2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Success - Valid configuration with groups",
|
||||
issuerURL: server.URL,
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
allowedGroups: []string{"group1", "group2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Success - Valid configuration with users, groups and logout URL",
|
||||
issuerURL: server.URL,
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
logoutURL: "https://example.com/logout",
|
||||
allowedUsers: []string{"user1", "user2"},
|
||||
allowedGroups: []string{"group1", "group2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Fail - No allowed users or allowed groups",
|
||||
issuerURL: "https://example.com",
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.allowedUsers, tt.allowedGroups)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckToken(t *testing.T) {
|
||||
provider := setupProvider(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedUsers []string
|
||||
allowedGroups []string
|
||||
claims jwt.Claims
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "Success - Valid token with allowed user",
|
||||
allowedUsers: []string{"user1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Valid token with allowed group",
|
||||
allowedGroups: []string{"group1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Server omits groups, but user is allowed",
|
||||
allowedUsers: []string{"user1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Server omits preferred_username, but group is allowed",
|
||||
allowedGroups: []string{"group1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Success - Valid token with allowed user and group",
|
||||
allowedUsers: []string{"user1"},
|
||||
allowedGroups: []string{"group1"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error - User not allowed",
|
||||
allowedUsers: []string{"user2", "user3"},
|
||||
allowedGroups: []string{"group2", "group3"},
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrUserNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns incorrect issuer",
|
||||
claims: jwt.MapClaims{
|
||||
"iss": "https://example.com",
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns incorrect audience",
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": "some-other-audience",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns expired token",
|
||||
claims: jwt.MapClaims{
|
||||
"iss": provider.ts.URL,
|
||||
"aud": clientID,
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create the Auth Provider.
|
||||
auth := &OIDCProvider{
|
||||
oidcVerifier: provider.verifier,
|
||||
allowedUsers: tc.allowedUsers,
|
||||
allowedGroups: tc.allowedGroups,
|
||||
}
|
||||
// Sign the claims to create a token.
|
||||
signedToken := provider.SignClaims(t, tc.claims)
|
||||
// Craft a test HTTP request that includes the token as a cookie.
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: auth.TokenCookieName(),
|
||||
Value: signedToken,
|
||||
})
|
||||
|
||||
// Call CheckToken and verify the result.
|
||||
err := auth.CheckToken(req)
|
||||
if tc.wantErr == nil {
|
||||
ExpectNoError(t, err)
|
||||
} else {
|
||||
ExpectError(t, tc.wantErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
TokenCookieName() string
|
||||
CheckToken(r *http.Request) error
|
||||
}
|
||||
@@ -1,143 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUsername = gperr.New("invalid username")
|
||||
ErrInvalidPassword = gperr.New("invalid password")
|
||||
)
|
||||
|
||||
type (
|
||||
UserPassAuth struct {
|
||||
username string
|
||||
pwdHash []byte
|
||||
secret []byte
|
||||
tokenTTL time.Duration
|
||||
}
|
||||
UserPassClaims struct {
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
)
|
||||
|
||||
func NewUserPassAuth(username, password string, secret []byte, tokenTTL time.Duration) (*UserPassAuth, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &UserPassAuth{
|
||||
username: username,
|
||||
pwdHash: hash,
|
||||
secret: secret,
|
||||
tokenTTL: tokenTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewUserPassAuthFromEnv() (*UserPassAuth, error) {
|
||||
return NewUserPassAuth(
|
||||
common.APIUser,
|
||||
common.APIPassword,
|
||||
common.APIJWTSecret,
|
||||
common.APIJWTTokenTTL,
|
||||
)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) TokenCookieName() string {
|
||||
return "godoxy_token"
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) NewToken() (token string, err error) {
|
||||
claim := &UserPassClaims{
|
||||
Username: auth.username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(auth.tokenTTL)),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
|
||||
token, err = tok.SignedString(auth.secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
}
|
||||
var claims UserPassClaims
|
||||
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return auth.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case !token.Valid:
|
||||
return ErrInvalidToken
|
||||
case claims.Username != auth.username:
|
||||
return ErrUserNotAllowed.Subject(claims.Username)
|
||||
case claims.ExpiresAt.Before(time.Now()):
|
||||
return gperr.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds struct {
|
||||
User string `json:"username"`
|
||||
Pass string `json:"password"`
|
||||
}
|
||||
err := json.NewDecoder(r.Body).Decode(&creds)
|
||||
if err != nil {
|
||||
gphttp.Unauthorized(w, "invalid credentials")
|
||||
return
|
||||
}
|
||||
if err := auth.validatePassword(creds.User, creds.Pass); err != nil {
|
||||
gphttp.Unauthorized(w, "invalid credentials")
|
||||
return
|
||||
}
|
||||
token, err := auth.NewToken()
|
||||
if err != nil {
|
||||
gphttp.ServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||
auth.RedirectLoginPage(w, r)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
||||
if user != auth.username {
|
||||
return ErrInvalidUsername.Subject(user)
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
|
||||
return ErrInvalidPassword.With(err).Subject(pass)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func newMockUserPassAuth() *UserPassAuth {
|
||||
return &UserPassAuth{
|
||||
username: "username",
|
||||
pwdHash: Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
|
||||
secret: []byte("abcdefghijklmnopqrstuvwxyz"),
|
||||
tokenTTL: time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPassValidateCredentials(t *testing.T) {
|
||||
auth := newMockUserPassAuth()
|
||||
err := auth.validatePassword("username", "password")
|
||||
ExpectNoError(t, err)
|
||||
err = auth.validatePassword("username", "wrong-password")
|
||||
ExpectError(t, ErrInvalidPassword, err)
|
||||
err = auth.validatePassword("wrong-username", "password")
|
||||
ExpectError(t, ErrInvalidUsername, err)
|
||||
}
|
||||
|
||||
func TestUserPassCheckToken(t *testing.T) {
|
||||
auth := newMockUserPassAuth()
|
||||
token, err := auth.NewToken()
|
||||
ExpectNoError(t, err)
|
||||
tests := []struct {
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
token: token,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
token: "invalid-token",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
token: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
req := &http.Request{Header: http.Header{}}
|
||||
if tt.token != "" {
|
||||
req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token)
|
||||
}
|
||||
err = auth.CheckToken(req)
|
||||
if tt.wantErr {
|
||||
ExpectTrue(t, err != nil)
|
||||
} else {
|
||||
ExpectNoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPassLoginCallbackHandler(t *testing.T) {
|
||||
type cred struct {
|
||||
User string `json:"username"`
|
||||
Pass string `json:"password"`
|
||||
}
|
||||
auth := newMockUserPassAuth()
|
||||
tests := []struct {
|
||||
creds cred
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
creds: cred{
|
||||
User: "username",
|
||||
Pass: "password",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
creds: cred{
|
||||
User: "username",
|
||||
Pass: "wrong-password",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
w := httptest.NewRecorder()
|
||||
req := &http.Request{
|
||||
Host: "app.example.com",
|
||||
Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))),
|
||||
}
|
||||
auth.LoginCallbackHandler(w, req)
|
||||
if tt.wantErr {
|
||||
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
||||
} else {
|
||||
setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||
ExpectTrue(t, setCookie.Name == auth.TokenCookieName())
|
||||
ExpectTrue(t, setCookie.Value != "")
|
||||
ExpectEqual(t, setCookie.Domain, "example.com")
|
||||
ExpectEqual(t, setCookie.Path, "/")
|
||||
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||
ExpectEqual(t, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingToken = gperr.New("missing token")
|
||||
ErrInvalidToken = gperr.New("invalid token")
|
||||
ErrUserNotAllowed = gperr.New("user not allowed")
|
||||
)
|
||||
|
||||
// cookieFQDN returns the fully qualified domain name of the request host
|
||||
// with subdomain stripped.
|
||||
//
|
||||
// If the request host does not have a subdomain,
|
||||
// an empty string is returned
|
||||
//
|
||||
// "abc.example.com" -> "example.com"
|
||||
// "example.com" -> ""
|
||||
func cookieFQDN(r *http.Request) string {
|
||||
var host string
|
||||
// check if it's from backend
|
||||
switch r.Host {
|
||||
case common.APIHTTPAddr:
|
||||
// use XFH
|
||||
host = r.Header.Get("X-Forwarded-Host")
|
||||
default:
|
||||
var err error
|
||||
host, _, err = net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
}
|
||||
|
||||
parts := strutils.SplitRune(host, '.')
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
parts[0] = ""
|
||||
return strutils.JoinRune(parts, '.')
|
||||
}
|
||||
|
||||
func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
MaxAge: int(ttl.Seconds()),
|
||||
Domain: cookieFQDN(r),
|
||||
HttpOnly: true,
|
||||
Secure: common.APIJWTSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
Domain: cookieFQDN(r),
|
||||
HttpOnly: true,
|
||||
Secure: common.APIJWTSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
// generateState generates a random string for OIDC state.
|
||||
const oidcStateLength = 32
|
||||
|
||||
func generateState() string {
|
||||
b := make([]byte, oidcStateLength)
|
||||
_, _ = rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)[:oidcStateLength]
|
||||
}
|
||||
Reference in New Issue
Block a user