diff --git a/.env.example b/.env.example index 19d7c249..f022fab5 100644 --- a/.env.example +++ b/.env.example @@ -1,17 +1,14 @@ # set timezone to get correct log timestamp TZ=ETC/UTC -# generate secret with `openssl rand -base64 32` -GODOXY_API_JWT_SECRET= - -# the JWT token time-to-live -GODOXY_API_JWT_TOKEN_TTL=1h - -# API/WebUI login credentials -# Important: If using OIDC authentication, the API_USER must match the username -# provided by the OIDC provider. +# API/WebUI user password login credentials (optional) +# These fields are not required for OIDC authentication GODOXY_API_USER=admin GODOXY_API_PASSWORD=password +# generate secret with `openssl rand -base64 32` +GODOXY_API_JWT_SECRET= +# the JWT token time-to-live +GODOXY_API_JWT_TOKEN_TTL=1h # OIDC Configuration (optional) # Uncomment and configure these values to enable OIDC authentication. @@ -22,6 +19,21 @@ GODOXY_API_PASSWORD=password # GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback # Comma-separated list of scopes # GODOXY_OIDC_SCOPES=openid, profile, email +# +# User definitions: Uncomment and configure these values to restrict access to specific users or groups. +# These two fields act as a logical AND operator. For example, given the following membership: +# user1, group1 +# user2, group1 +# user3, group2 +# user1, group2 +# You can allow access to user3 AND all users of group1 by providing: +# # GODOXY_OIDC_ALLOWED_USERS=user3 +# # GODOXY_OIDC_ALLOWED_GROUPS=group1 +# +# Comma-separated list of allowed users. +# GODOXY_OIDC_ALLOWED_USERS=user1,user2 +# Optional: Comma-separated list of allowed groups. +# GODOXY_OIDC_ALLOWED_GROUPS=group1,group2 # Proxy listening address GODOXY_HTTP_ADDR=:80 diff --git a/Makefile b/Makefile index 83b4e26f..4a7942cd 100755 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ profile: run: build sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy - bin/godoxy + [ -f .env ] && godotenv -f .env bin/godoxy || bin/godoxy mtrace: bin/godoxy debug-ls-mtrace > mtrace.json diff --git a/cmd/main.go b/cmd/main.go index e2417472..b38fc0ee 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -117,18 +117,13 @@ func main() { return } + if err := auth.Initialize(); err != nil { + logging.Fatal().Err(err).Msg("failed to initialize authentication") + } + cfg.Start() config.WatchChanges() - if !auth.IsEnabled() { - logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") - } else { - // Initialize authentication providers - if err := auth.Initialize(); err != nil { - logging.Fatal().Err(err).Msg("Failed to initialize authentication providers") - } - } - sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGTERM) diff --git a/internal/api/handler.go b/internal/api/handler.go index 5e121616..eff102f4 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -1,45 +1,46 @@ package api import ( - "net" "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/favicon" - . "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ServeMux struct{ *http.ServeMux } -func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { - mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler)) +func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) { + for _, m := range strutils.CommaSeperatedList(methods) { + mux.ServeMux.HandleFunc(m+" "+endpoint, handler) + } } func NewHandler(cfg config.ConfigInstance) http.Handler { mux := ServeMux{http.NewServeMux()} mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) - mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler) - mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler) - mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) - mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) - mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) - mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) - mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) + mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile) mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) mux.HandleFunc("GET", "/v1/health/ws", useCfg(cfg, v1.HealthWS)) mux.HandleFunc("GET", "/v1/logs/ws", useCfg(cfg, v1.LogsWS())) mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon)) + + defaultAuth := auth.GetDefaultAuth() + if defaultAuth != nil { + mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage) + mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler) + mux.HandleFunc("GET,POST", "/v1/auth/logout", auth.LogoutCallbackHandler(defaultAuth)) + } return mux } @@ -48,20 +49,3 @@ func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w handler(cfg, w, r) } } - -// allow only requests to API server with localhost. -func checkHost(f http.HandlerFunc) http.HandlerFunc { - if common.IsDebug { - return f - } - return func(w http.ResponseWriter, r *http.Request) { - host, _, _ := net.SplitHostPort(r.RemoteAddr) - if host != "127.0.0.1" && host != "localhost" && host != "[::1]" { - LogWarn(r).Msgf("blocked API request from %s", host) - http.Error(w, "forbidden", http.StatusForbidden) - return - } - LogDebug(r).Interface("headers", r.Header).Msg("API request") - f(w, r) - } -} diff --git a/internal/api/v1/auth/auth.go b/internal/api/v1/auth/auth.go index 532d339b..ed0060e2 100644 --- a/internal/api/v1/auth/auth.go +++ b/internal/api/v1/auth/auth.go @@ -1,139 +1,54 @@ package auth import ( - "fmt" "net/http" - "time" - "github.com/golang-jwt/jwt/v5" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" - "github.com/yusing/go-proxy/internal/utils/strutils" + "github.com/yusing/go-proxy/internal/logging" ) -type ( - Credentials struct { - Username string `json:"username"` - Password string `json:"password"` - } - Claims struct { - Username string `json:"username"` - jwt.RegisteredClaims - } -) +var defaultAuth Provider // Initialize sets up authentication providers. func Initialize() error { + if !IsEnabled() { + logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") + return nil + } + + var err error // Initialize OIDC if configured. if common.OIDCIssuerURL != "" { - return InitOIDC( - common.OIDCIssuerURL, - common.OIDCClientID, - common.OIDCClientSecret, - common.OIDCRedirectURL, - ) + defaultAuth, err = NewOIDCProviderFromEnv() + } else { + defaultAuth, err = NewUserPassAuthFromEnv() } - return nil + + return err +} + +func GetDefaultAuth() Provider { + return defaultAuth } func IsEnabled() bool { - return common.APIJWTSecret != nil || common.OIDCIssuerURL != "" + return common.APIJWTSecret != nil || IsOIDCEnabled() } -// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration. -func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) { - switch { - case oauthConfig != nil: - RedirectOIDC(w, r) - return - case common.APIJWTSecret != nil: - http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) - return - default: - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) - } -} - -func setAuthenticatedCookie(w http.ResponseWriter, username string) error { - expiresAt := time.Now().Add(common.APIJWTTokenTTL) - claim := &Claims{ - Username: username, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expiresAt), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim) - tokenStr, err := token.SignedString(common.APIJWTSecret) - if err != nil { - return err - } - http.SetCookie(w, &http.Cookie{ - Name: CookieToken, - Value: tokenStr, - Expires: expiresAt, - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - Path: "/", - }) - return nil -} - -// LogoutHandler clear authentication cookie and redirect to login page. -func LogoutHandler(w http.ResponseWriter, r *http.Request) { - http.SetCookie(w, &http.Cookie{ - Name: CookieToken, - Value: "", - Expires: time.Unix(0, 0), - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - Path: "/", - }) - AuthRedirectHandler(w, r) +func IsOIDCEnabled() bool { + return common.OIDCIssuerURL != "" } func RequireAuth(next http.HandlerFunc) http.HandlerFunc { if IsEnabled() { return func(w http.ResponseWriter, r *http.Request) { - if checkToken(w, r) { + if err := defaultAuth.CheckToken(r); err != nil { + U.RespondError(w, err, http.StatusUnauthorized) + } else { next(w, r) } } } return next } - -func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { - tokenCookie, err := r.Cookie(CookieToken) - if err != nil { - U.RespondError(w, E.New("missing token"), http.StatusUnauthorized) - return false - } - var claims Claims - token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return common.APIJWTSecret, nil - }) - - switch { - case err != nil: - break - case !token.Valid: - err = E.New("invalid token") - case claims.Username != common.APIUser: - err = E.New("username mismatch").Subject(claims.Username) - case claims.ExpiresAt.Before(time.Now()): - err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) - } - - if err != nil { - U.RespondError(w, err, http.StatusForbidden) - return false - } - - return true -} diff --git a/internal/api/v1/auth/cookies.go b/internal/api/v1/auth/cookies.go deleted file mode 100644 index c6d73862..00000000 --- a/internal/api/v1/auth/cookies.go +++ /dev/null @@ -1,6 +0,0 @@ -package auth - -const ( - CookieToken = "godoxy_token" - CookieOauthState = "godoxy_oauth_state" -) diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index 3b1886ae..0c96559e 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -2,87 +2,186 @@ package auth import ( "context" + "crypto/rand" + "encoding/base64" + "errors" "fmt" "net/http" + "slices" + "time" "github.com/coreos/go-oidc/v3/oidc" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + CE "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" "golang.org/x/oauth2" ) -var ( - oauthConfig *oauth2.Config - oidcProvider *oidc.Provider - oidcVerifier *oidc.IDTokenVerifier +type OIDCProvider struct { + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + allowedUsers []string + allowedGroups []string + isMiddleware bool +} + +const CookieOauthState = "godoxy_oidc_state" + +const ( + OIDCMiddlewareCallbackPath = "/auth/callback" + OIDCLogoutPath = "/auth/logout" ) -// InitOIDC initializes the OIDC provider. -func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error { - if issuerURL == "" { - return nil // OIDC not configured +func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { + if len(allowedUsers)+len(allowedGroups) == 0 { + return nil, errors.New("OIDC users, groups, or both must not be empty") } provider, err := oidc.NewProvider(context.Background(), issuerURL) if err != nil { - return fmt.Errorf("failed to initialize OIDC provider: %w", err) + return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) } - oidcProvider = provider - oidcVerifier = provider.Verifier(&oidc.Config{ - ClientID: clientID, - }) + return &OIDCProvider{ + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Endpoint: provider.Endpoint(), + Scopes: strutils.CommaSeperatedList(common.OIDCScopes), + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ + ClientID: clientID, + }), + allowedUsers: allowedUsers, + allowedGroups: allowedGroups, + }, nil +} - oauthConfig = &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: redirectURL, - Endpoint: provider.Endpoint(), - Scopes: strutils.CommaSeperatedList(common.OIDCScopes), +// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables. +func NewOIDCProviderFromEnv() (*OIDCProvider, error) { + return NewOIDCProvider( + common.OIDCIssuerURL, + common.OIDCClientID, + common.OIDCClientSecret, + common.OIDCRedirectURL, + common.OIDCAllowedUsers, + common.OIDCAllowedGroups, + ) +} + +func (auth *OIDCProvider) TokenCookieName() string { + return "godoxy_oidc_token" +} + +func (auth *OIDCProvider) SetIsMiddleware(enabled bool) { + auth.isMiddleware = enabled +} + +func (auth *OIDCProvider) SetAllowedUsers(users []string) { + auth.allowedUsers = users +} + +func (auth *OIDCProvider) SetAllowedGroups(groups []string) { + auth.allowedGroups = groups +} + +func (auth *OIDCProvider) CheckToken(r *http.Request) error { + token, err := r.Cookie(auth.TokenCookieName()) + if err != nil { + return ErrMissingToken } + // checks for Expiry, Audience == ClientID, Issuer, etc. + idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value) + if err != nil { + return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err) + } + + if len(idToken.Audience) == 0 { + return ErrInvalidToken + } + + var claims struct { + Email string `json:"email"` + Username string `json:"preferred_username"` + Groups []string `json:"groups"` + } + if err := idToken.Claims(&claims); err != nil { + return fmt.Errorf("failed to parse claims: %w", err) + } + + // Logical AND between allowed users and groups. + allowedUser := slices.Contains(auth.allowedUsers, claims.Username) + allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0 + if !allowedUser && !allowedGroup { + return ErrUserNotAllowed.Subject(claims.Username) + } return nil } +// generateState generates a random string for OIDC state. +const oidcStateLength = 32 + +func generateState() (string, error) { + b := make([]byte, oidcStateLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b)[:oidcStateLength], nil +} + // RedirectOIDC initiates the OIDC login flow. -func RedirectOIDC(w http.ResponseWriter, r *http.Request) { - if oauthConfig == nil { - U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) +func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { + state, err := generateState() + if err != nil { + U.HandleErr(w, r, err, http.StatusInternalServerError) return } - - state := common.GenerateRandomString(32) http.SetCookie(w, &http.Cookie{ Name: CookieOauthState, Value: state, MaxAge: 300, HttpOnly: true, - SameSite: http.SameSiteNoneMode, + SameSite: http.SameSiteLaxMode, Secure: true, Path: "/", }) - url := oauthConfig.AuthCodeURL(state) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) + redirURL := auth.oauthConfig.AuthCodeURL(state) + if auth.isMiddleware { + u, err := r.URL.Parse(redirURL) + if err != nil { + U.HandleErr(w, r, err, http.StatusInternalServerError) + return + } + q := u.Query() + q.Set("redirect_uri", "https://"+r.Host+OIDCMiddlewareCallbackPath+q.Get("redirect_uri")) + u.RawQuery = q.Encode() + redirURL = u.String() + } + http.Redirect(w, r, redirURL, http.StatusTemporaryRedirect) +} + +func (auth *OIDCProvider) exchange(r *http.Request) (*oauth2.Token, error) { + if auth.isMiddleware { + cfg := *auth.oauthConfig + cfg.RedirectURL = "https://" + r.Host + OIDCMiddlewareCallbackPath + return cfg.Exchange(r.Context(), r.URL.Query().Get("code")) + } + return auth.oauthConfig.Exchange(r.Context(), r.URL.Query().Get("code")) } // OIDCCallbackHandler handles the OIDC callback. -func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { - if oauthConfig == nil { - U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) - return - } - +func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) { // For testing purposes, skip provider verification if common.IsTest { - handleTestCallback(w, r) - return - } - - if oidcProvider == nil { - U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) + auth.handleTestCallback(w, r) return } @@ -92,13 +191,13 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { return } - if r.URL.Query().Get("state") != state.Value { + query := r.URL.Query() + if query.Get("state") != state.Value { U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest) return } - code := r.URL.Query().Get("code") - oauth2Token, err := oauthConfig.Exchange(r.Context(), code) + oauth2Token, err := auth.exchange(r) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError) return @@ -110,32 +209,20 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { return } - idToken, err := oidcVerifier.Verify(r.Context(), rawIDToken) + idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError) return } - var claims struct { - Email string `json:"email"` - Username string `json:"preferred_username"` - } - if err := idToken.Claims(&claims); err != nil { - U.HandleErr(w, r, fmt.Errorf("failed to parse claims: %w", err), http.StatusInternalServerError) - return - } - - if err := setAuthenticatedCookie(w, claims.Username); err != nil { - U.HandleErr(w, r, err, http.StatusInternalServerError) - return - } + setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry)) // Redirect to home page http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } // handleTestCallback handles OIDC callback in test environment. -func handleTestCallback(w http.ResponseWriter, r *http.Request) { +func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) { state, err := r.Cookie(CookieOauthState) if err != nil { U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) @@ -148,10 +235,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) { } // Create test JWT token - if err := setAuthenticatedCookie(w, "test-user"); err != nil { - U.HandleErr(w, r, err, http.StatusInternalServerError) - return - } + setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour) http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index 1e1f986d..d14715ea 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -1,77 +1,165 @@ package auth import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" "net/http" "net/http/httptest" "testing" + "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v5" "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" "golang.org/x/oauth2" + + . "github.com/yusing/go-proxy/internal/utils/testing" ) // setupMockOIDC configures mock OIDC provider for testing. func setupMockOIDC(t *testing.T) { t.Helper() - oauthConfig = &oauth2.Config{ - ClientID: "test-client", - ClientSecret: "test-secret", - RedirectURL: "http://localhost/callback", - Endpoint: oauth2.Endpoint{ - AuthURL: "http://mock-provider/auth", - TokenURL: "http://mock-provider/token", + provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO()) + defaultAuth = &OIDCProvider{ + oauthConfig: &oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost/callback", + Endpoint: oauth2.Endpoint{ + AuthURL: "http://mock-provider/auth", + TokenURL: "http://mock-provider/token", + }, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, }, - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ + ClientID: "test-client", + }), + allowedUsers: []string{"test-user"}, + allowedGroups: []string{"test-group1", "test-group2"}, + } +} + +// discoveryDocument returns a mock OIDC discovery document. +func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any { + t.Helper() + + discovery := map[string]any{ + "issuer": server.URL, + "authorization_endpoint": server.URL + "/auth", + "token_endpoint": server.URL + "/token", + } + + return discovery +} + +const ( + keyID = "test-key-id" + clientID = "test-client-id" +) + +type provider struct { + ts *httptest.Server + key *rsa.PrivateKey + verifier *oidc.IDTokenVerifier +} + +func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string { + t.Helper() + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = keyID + signed, err := token.SignedString(j.key) + ExpectNoError(t, err) + return signed +} + +func setupProvider(t *testing.T) *provider { + t.Helper() + + // Generate an RSA key pair for the test. + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + ExpectNoError(t, err) + + // Build the matching public JWK that will be served by the endpoint. + jwk := buildRSAJWK(t, &privKey.PublicKey, keyID) + + // Start a test server that serves the JWKS endpoint. + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/jwks.json": + _ = json.NewEncoder(w).Encode(map[string]any{ + "keys": []any{jwk}, + }) + default: + http.NotFound(w, r) + } + })) + t.Cleanup(ts.Close) + + // Create a test OIDCProvider. + providerCtx := oidc.ClientContext(context.Background(), ts.Client()) + keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json") + + return &provider{ + ts: ts, + key: privKey, + verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{ + ClientID: clientID, // matches audience in the token + }), + } +} + +// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint +func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any { + t.Helper() + + nBytes := pub.N.Bytes() + eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537 + + return map[string]any{ + "kty": "RSA", + "alg": "RS256", + "use": "sig", + "kid": kid, + "n": base64.RawURLEncoding.EncodeToString(nBytes), + "e": base64.RawURLEncoding.EncodeToString(eBytes), } } func cleanup() { - oauthConfig = nil - oidcProvider = nil - oidcVerifier = nil + defaultAuth = nil } func TestOIDCLoginHandler(t *testing.T) { // Setup common.APIJWTSecret = []byte("test-secret") - common.IsTest = true - t.Cleanup(func() { - cleanup() - common.IsTest = false - }) + t.Cleanup(cleanup) setupMockOIDC(t) tests := []struct { - name string - configureOAuth bool - wantStatus int - wantRedirect bool + name string + wantStatus int + wantRedirect bool }{ { - name: "Success - Redirects to provider", - configureOAuth: true, - wantStatus: http.StatusTemporaryRedirect, - wantRedirect: true, - }, - { - name: "Failure - OIDC not configured", - configureOAuth: false, - wantStatus: http.StatusNotImplemented, - wantRedirect: false, + name: "Success - Redirects to provider", + wantStatus: http.StatusTemporaryRedirect, + wantRedirect: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if !tt.configureOAuth { - oauthConfig = nil - } - req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) w := httptest.NewRecorder() - RedirectOIDC(w, req) + defaultAuth.RedirectLoginPage(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) @@ -94,110 +182,267 @@ func TestOIDCLoginHandler(t *testing.T) { func TestOIDCCallbackHandler(t *testing.T) { // Setup common.APIJWTSecret = []byte("test-secret") - common.IsTest = true - t.Cleanup(func() { - cleanup() - common.IsTest = false - }) + t.Cleanup(cleanup) tests := []struct { - name string - configureOAuth bool - state string - code string - setupMocks func() - wantStatus int + name string + state string + code string + setupMocks bool + wantStatus int }{ { - name: "Success - Valid callback", - configureOAuth: true, - state: "valid-state", - code: "valid-code", - setupMocks: func() { - setupMockOIDC(t) - }, + name: "Success - Valid callback", + state: "valid-state", + code: "valid-code", + setupMocks: true, wantStatus: http.StatusTemporaryRedirect, }, { - name: "Failure - OIDC not configured", - configureOAuth: false, - wantStatus: http.StatusNotImplemented, - }, - { - name: "Failure - Missing state", - configureOAuth: true, - code: "valid-code", - setupMocks: func() { - setupMockOIDC(t) - }, + name: "Failure - Missing state", + code: "valid-code", + setupMocks: true, wantStatus: http.StatusBadRequest, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.setupMocks != nil { - tt.setupMocks() - } - - if !tt.configureOAuth { - oauthConfig = nil + if tt.setupMocks { + setupMockOIDC(t) } req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil) if tt.state != "" { req.AddCookie(&http.Cookie{ - Name: "oauth_state", + Name: CookieOauthState, Value: tt.state, }) } w := httptest.NewRecorder() - OIDCCallbackHandler(w, req) + defaultAuth.LoginCallbackHandler(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus) } if tt.wantStatus == http.StatusTemporaryRedirect { - cookie := w.Header().Get("Set-Cookie") - if cookie == "" { - t.Error("OIDCCallbackHandler() missing token cookie") - } + setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) + ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName()) + ExpectTrue(t, setCookie.Value != "") + ExpectEqual(t, setCookie.Path, "/") + ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) + ExpectEqual(t, setCookie.HttpOnly, true) } }) } } func TestInitOIDC(t *testing.T) { - common.IsTest = true - t.Cleanup(func() { - common.IsTest = false + setupMockOIDC(t) + // Create a test server that serves the discovery document + var server *httptest.Server + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server))) }) + server = httptest.NewServer(mux) + t.Cleanup(server.Close) + t.Cleanup(cleanup) + tests := []struct { - name string - issuerURL string - clientID string - clientSecret string - redirectURL string - wantErr bool + name string + issuerURL string + clientID string + clientSecret string + redirectURL string + allowedUsers []string + allowedGroups []string + wantErr bool }{ { - name: "Success - Empty configuration", + name: "Fail - Empty configuration", issuerURL: "", clientID: "", clientSecret: "", redirectURL: "", + allowedUsers: nil, + wantErr: true, + }, + { + name: "Success - Valid configuration with users", + issuerURL: server.URL, + clientID: "client_id", + clientSecret: "client_secret", + redirectURL: "https://example.com/callback", + allowedUsers: []string{"user1", "user2"}, wantErr: false, }, + { + name: "Success - Valid configuration with groups", + issuerURL: server.URL, + clientID: "client_id", + clientSecret: "client_secret", + redirectURL: "https://example.com/callback", + allowedGroups: []string{"group1", "group2"}, + wantErr: false, + }, + { + name: "Fail - No allowed users or allowed groups", + issuerURL: "https://example.com", + clientID: "client_id", + clientSecret: "client_secret", + redirectURL: "https://example.com/callback", + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Cleanup(cleanup) - err := InitOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL) + _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups) if (err != nil) != tt.wantErr { t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) } }) } } + +func TestCheckToken(t *testing.T) { + provider := setupProvider(t) + + tests := []struct { + name string + allowedUsers []string + allowedGroups []string + claims jwt.Claims + wantErr error + }{ + { + name: "Success - Valid token with allowed user", + allowedUsers: []string{"user1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + }, + { + name: "Success - Valid token with allowed group", + allowedGroups: []string{"group1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + }, + { + name: "Success - Server omits groups, but user is allowed", + allowedUsers: []string{"user1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + }, + }, + { + name: "Success - Server omits preferred_username, but group is allowed", + allowedGroups: []string{"group1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "groups": []string{"group1"}, + }, + }, + { + name: "Success - Valid token with allowed user and group", + allowedUsers: []string{"user1"}, + allowedGroups: []string{"group1"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + }, + { + name: "Error - User not allowed", + allowedUsers: []string{"user2", "user3"}, + allowedGroups: []string{"group2", "group3"}, + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrUserNotAllowed, + }, + { + name: "Error - Server returns incorrect issuer", + claims: jwt.MapClaims{ + "iss": "https://example.com", + "aud": clientID, + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrInvalidToken, + }, + { + name: "Error - Server returns incorrect audience", + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": "some-other-audience", + "exp": time.Now().Add(time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrInvalidToken, + }, + { + name: "Error - Server returns expired token", + claims: jwt.MapClaims{ + "iss": provider.ts.URL, + "aud": clientID, + "exp": time.Now().Add(-time.Hour).Unix(), + "preferred_username": "user1", + "groups": []string{"group1"}, + }, + wantErr: ErrInvalidToken, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create the Auth Provider. + auth := &OIDCProvider{ + oidcVerifier: provider.verifier, + allowedUsers: tc.allowedUsers, + allowedGroups: tc.allowedGroups, + } + // Sign the claims to create a token. + signedToken := provider.SignClaims(t, tc.claims) + // Craft a test HTTP request that includes the token as a cookie. + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{ + Name: auth.TokenCookieName(), + Value: signedToken, + }) + + // Call CheckToken and verify the result. + err := auth.CheckToken(req) + if tc.wantErr == nil { + ExpectNoError(t, err) + } else { + ExpectError(t, tc.wantErr, err) + } + }) + } +} diff --git a/internal/api/v1/auth/provider.go b/internal/api/v1/auth/provider.go new file mode 100644 index 00000000..8ea4d320 --- /dev/null +++ b/internal/api/v1/auth/provider.go @@ -0,0 +1,12 @@ +package auth + +import ( + "net/http" +) + +type Provider interface { + TokenCookieName() string + CheckToken(r *http.Request) error + RedirectLoginPage(w http.ResponseWriter, r *http.Request) + LoginCallbackHandler(w http.ResponseWriter, r *http.Request) +} diff --git a/internal/api/v1/auth/userpass.go b/internal/api/v1/auth/userpass.go index 6d00e6a7..ae80c1c7 100644 --- a/internal/api/v1/auth/userpass.go +++ b/internal/api/v1/auth/userpass.go @@ -1,13 +1,17 @@ package auth import ( - "bytes" "encoding/json" + "fmt" "net/http" + "time" + "github.com/golang-jwt/jwt/v5" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" + "golang.org/x/crypto/bcrypt" ) var ( @@ -15,31 +19,120 @@ var ( ErrInvalidPassword = E.New("invalid password") ) -func validatePassword(cred *Credentials) error { - if cred.Username != common.APIUser { - return ErrInvalidUsername.Subject(cred.Username) +type ( + UserPassAuth struct { + username string + pwdHash []byte + secret []byte + tokenTTL time.Duration } - if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { - return ErrInvalidPassword.Subject(cred.Password) + UserPassClaims struct { + Username string `json:"username"` + jwt.RegisteredClaims } +) + +func NewUserPassAuth(username, password string, secret []byte, tokenTTL time.Duration) (*UserPassAuth, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + return &UserPassAuth{ + username: username, + pwdHash: hash, + secret: secret, + tokenTTL: tokenTTL, + }, nil +} + +func NewUserPassAuthFromEnv() (*UserPassAuth, error) { + return NewUserPassAuth( + common.APIUser, + common.APIPassword, + common.APIJWTSecret, + common.APIJWTTokenTTL, + ) +} + +func (auth *UserPassAuth) TokenCookieName() string { + return "godoxy_token" +} + +func (auth *UserPassAuth) NewToken() (token string, err error) { + claim := &UserPassClaims{ + Username: auth.username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(auth.tokenTTL)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS512, claim) + token, err = tok.SignedString(auth.secret) + if err != nil { + return "", err + } + return token, nil +} + +func (auth *UserPassAuth) CheckToken(r *http.Request) error { + jwtCookie, err := r.Cookie(auth.TokenCookieName()) + if err != nil { + return ErrMissingToken + } + var claims UserPassClaims + token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return auth.secret, nil + }) + if err != nil { + return err + } + switch { + case !token.Valid: + return ErrInvalidToken + case claims.Username != auth.username: + return ErrUserNotAllowed.Subject(claims.Username) + case claims.ExpiresAt.Before(time.Now()): + return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) + } + return nil } -// UserPassLoginHandler handles user login. -func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) { - var creds Credentials +func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) +} + +func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) { + var creds struct { + User string `json:"username"` + Pass string `json:"password"` + } err := json.NewDecoder(r.Body).Decode(&creds) if err != nil { U.HandleErr(w, r, err, http.StatusBadRequest) return } - if err := validatePassword(&creds); err != nil { + if err := auth.validatePassword(creds.User, creds.Pass); err != nil { U.HandleErr(w, r, err, http.StatusUnauthorized) return } - if err := setAuthenticatedCookie(w, creds.Username); err != nil { + token, err := auth.NewToken() + if err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return } + setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL) w.WriteHeader(http.StatusOK) } + +func (auth *UserPassAuth) validatePassword(user, pass string) error { + if user != auth.username { + return ErrInvalidUsername.Subject(user) + } + if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil { + return ErrInvalidPassword.With(err).Subject(pass) + } + return nil +} diff --git a/internal/api/v1/auth/userpass_test.go b/internal/api/v1/auth/userpass_test.go new file mode 100644 index 00000000..c43360e7 --- /dev/null +++ b/internal/api/v1/auth/userpass_test.go @@ -0,0 +1,116 @@ +package auth + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/utils/testing" + "golang.org/x/crypto/bcrypt" +) + +func newMockUserPassAuth() *UserPassAuth { + return &UserPassAuth{ + username: "username", + pwdHash: E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)), + secret: []byte("abcdefghijklmnopqrstuvwxyz"), + tokenTTL: time.Hour, + } +} + +func TestUserPassValidateCredentials(t *testing.T) { + auth := newMockUserPassAuth() + err := auth.validatePassword("username", "password") + ExpectNoError(t, err) + err = auth.validatePassword("username", "wrong-password") + ExpectError(t, ErrInvalidPassword, err) + err = auth.validatePassword("wrong-username", "password") + ExpectError(t, ErrInvalidUsername, err) +} + +func TestUserPassCheckToken(t *testing.T) { + auth := newMockUserPassAuth() + token, err := auth.NewToken() + ExpectNoError(t, err) + tests := []struct { + token string + wantErr bool + }{ + { + token: token, + wantErr: false, + }, + { + token: "invalid-token", + wantErr: true, + }, + { + token: "", + wantErr: true, + }, + } + for _, tt := range tests { + req := &http.Request{Header: http.Header{}} + if tt.token != "" { + req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token) + } + err = auth.CheckToken(req) + if tt.wantErr { + ExpectTrue(t, err != nil) + } else { + ExpectNoError(t, err) + } + } +} + +func TestUserPassLoginCallbackHandler(t *testing.T) { + type cred struct { + User string `json:"username"` + Pass string `json:"password"` + } + auth := newMockUserPassAuth() + tests := []struct { + creds cred + wantErr bool + }{ + { + creds: cred{ + User: "username", + Pass: "password", + }, + wantErr: false, + }, + { + creds: cred{ + User: "username", + Pass: "wrong-password", + }, + wantErr: true, + }, + } + for _, tt := range tests { + w := httptest.NewRecorder() + req := &http.Request{ + Host: "app.example.com", + Body: io.NopCloser(bytes.NewReader(E.Must(json.Marshal(tt.creds)))), + } + auth.LoginCallbackHandler(w, req) + if tt.wantErr { + ExpectEqual(t, w.Code, http.StatusUnauthorized) + } else { + setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) + ExpectTrue(t, setCookie.Name == auth.TokenCookieName()) + ExpectTrue(t, setCookie.Value != "") + ExpectEqual(t, setCookie.Domain, "example.com") + ExpectEqual(t, setCookie.Path, "/") + ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) + ExpectEqual(t, setCookie.HttpOnly, true) + ExpectEqual(t, w.Code, http.StatusOK) + } + } +} diff --git a/internal/api/v1/auth/utils.go b/internal/api/v1/auth/utils.go new file mode 100644 index 00000000..1d57de13 --- /dev/null +++ b/internal/api/v1/auth/utils.go @@ -0,0 +1,70 @@ +package auth + +import ( + "net" + "net/http" + "time" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +var ( + ErrMissingToken = E.New("missing token") + ErrInvalidToken = E.New("invalid token") + ErrUserNotAllowed = E.New("user not allowed") +) + +// cookieFQDN returns the fully qualified domain name of the request host +// with subdomain stripped. +// +// If the request host does not have a subdomain, +// an empty string is returned +// +// "abc.example.com" -> "example.com" +// "example.com" -> "" +func cookieFQDN(r *http.Request) string { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + parts := strutils.SplitRune(host, '.') + if len(parts) < 2 { + return "" + } + parts[0] = "" + return strutils.JoinRune(parts, '.') +} + +func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: value, + MaxAge: int(ttl.Seconds()), + Domain: cookieFQDN(r), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + Path: "/", + }) +} + +func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + MaxAge: -1, + Domain: cookieFQDN(r), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + Path: "/", + }) +} + +func LogoutCallbackHandler(auth Provider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + clearTokenCookie(w, r, auth.TokenCookieName()) + auth.RedirectLoginPage(w, r) + } +} diff --git a/internal/api/v1/utils/logging.go b/internal/api/v1/utils/logging.go index ac795b8c..194735f5 100644 --- a/internal/api/v1/utils/logging.go +++ b/internal/api/v1/utils/logging.go @@ -11,6 +11,7 @@ func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { return logging.WithLevel(level). Str("module", "api"). Str("remote", r.RemoteAddr). + Str("host", r.Host). Str("uri", r.Method+" "+r.RequestURI) } diff --git a/internal/common/crypto.go b/internal/common/crypto.go index 5afd0bfe..6214a572 100644 --- a/internal/common/crypto.go +++ b/internal/common/crypto.go @@ -1,18 +1,11 @@ package common import ( - "crypto/sha512" "encoding/base64" "github.com/rs/zerolog/log" ) -func HashPassword(pwd string) []byte { - h := sha512.New() - h.Write([]byte(pwd)) - return h.Sum(nil) -} - func decodeJWTKey(key string) []byte { if key == "" { return nil diff --git a/internal/common/env.go b/internal/common/env.go index 0cc17553..4d13afe2 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/utils/strutils" ) var ( @@ -43,17 +44,19 @@ var ( MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http") PrometheusEnabled = MetricsHTTPURL != "" - APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) - APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) - APIUser = GetEnvString("API_USER", "admin") - APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password")) + APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) + APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) + APIUser = GetEnvString("API_USER", "admin") + APIPassword = GetEnvString("API_PASSWORD", "password") // OIDC Configuration. - OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") - OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") - OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") - OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") - OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") + OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") + OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") + OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") + OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") + OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") + OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "") + OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "") ) func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T { @@ -105,3 +108,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str func GetDurationEnv(key string, defaultValue time.Duration) time.Duration { return GetEnv(key, defaultValue, time.ParseDuration) } + +func GetCommaSepEnv(key string, defaultValue string) []string { + return strutils.CommaSeperatedList(GetEnvString(key, defaultValue)) +} diff --git a/internal/common/random.go b/internal/common/random.go deleted file mode 100644 index ea4586f1..00000000 --- a/internal/common/random.go +++ /dev/null @@ -1,13 +0,0 @@ -package common - -import ( - "crypto/rand" - "encoding/base64" -) - -// GenerateRandomString generates a random string of specified length. -func GenerateRandomString(length int) string { - b := make([]byte, length) - rand.Read(b) - return base64.URLEncoding.EncodeToString(b)[:length] -} diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index a2f5c0a6..faed4898 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -89,7 +89,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Then scraper / scanners will know the subdomain is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. if served := middleware.ServeStaticErrorPageFile(w, r); !served { - logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") + logger.Err(err). + Str("method", r.Method). + Str("url", r.URL.String()). + Str("remote", r.RemoteAddr). + Msg("request") errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) if ok { w.WriteHeader(http.StatusNotFound) diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 78410d36..a2719722 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "reflect" + "sort" "strings" E "github.com/yusing/go-proxy/internal/error" @@ -26,28 +27,50 @@ type ( name string construct ImplNewFunc impl any + // priority is only applied for ReverseProxy. + // + // Middleware compose follows the order of the slice + // + // Default is 10, 0 is the highest + priority int } + ByPriority []*Middleware RequestModifier interface { before(w http.ResponseWriter, r *http.Request) (proceed bool) } - ResponseModifier interface{ modifyResponse(r *http.Response) error } - MiddlewareWithSetup interface{ setup() } - MiddlewareFinalizer interface{ finalize() } + ResponseModifier interface{ modifyResponse(r *http.Response) error } + MiddlewareWithSetup interface{ setup() } + MiddlewareFinalizer interface{ finalize() } + MiddlewareFinalizerWithError interface { + finalize() error + } MiddlewareWithTracer interface { enableTrace() getTracer() *Tracer } ) +const DefaultPriority = 10 + +func (m ByPriority) Len() int { return len(m) } +func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] } +func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority } + func NewMiddleware[ImplType any]() *Middleware { // type check - switch any(new(ImplType)).(type) { + t := any(new(ImplType)) + switch t.(type) { case RequestModifier: case ResponseModifier: default: panic("must implement RequestModifier or ResponseModifier") } + _, hasFinializer := t.(MiddlewareFinalizer) + _, hasFinializerWithError := t.(MiddlewareFinalizerWithError) + if hasFinializer && hasFinializerWithError { + panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive") + } return &Middleware{ name: strings.ToLower(reflect.TypeFor[ImplType]().Name()), construct: func() any { return new(ImplType) }, @@ -84,13 +107,29 @@ func (m *Middleware) apply(optsRaw OptionsRaw) E.Error { if len(optsRaw) == 0 { return nil } + priority, ok := optsRaw["priority"].(int) + if ok { + m.priority = priority + // remove priority for deserialization, restore later + delete(optsRaw, "priority") + defer func() { + optsRaw["priority"] = priority + }() + } else { + m.priority = DefaultPriority + } return utils.Deserialize(optsRaw, m.impl) } -func (m *Middleware) finalize() { +func (m *Middleware) finalize() error { if finalizer, ok := m.impl.(MiddlewareFinalizer); ok { finalizer.finalize() + return nil } + if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok { + return finalizer.finalize() + } + return nil } func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { @@ -105,7 +144,9 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { if err := mid.apply(optsRaw); err != nil { return nil, err } - mid.finalize() + if err := mid.finalize(); err != nil { + return nil, E.From(err) + } return mid, nil } @@ -119,8 +160,9 @@ func (m *Middleware) String() string { func (m *Middleware) MarshalJSON() ([]byte, error) { return json.MarshalIndent(map[string]any{ - "name": m.name, - "options": m.impl, + "name": m.name, + "options": m.impl, + "priority": m.priority, }, "", " ") } @@ -193,6 +235,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) ( } func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { + sort.Sort(ByPriority(middlewares)) middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...) mid := NewMiddlewareChain(rp.TargetName, middlewares) diff --git a/internal/net/http/middleware/middleware_test.go b/internal/net/http/middleware/middleware_test.go new file mode 100644 index 00000000..5b6e5218 --- /dev/null +++ b/internal/net/http/middleware/middleware_test.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "net/http" + "strconv" + "strings" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +type testPriority struct { + Value int `json:"value"` +} + +var test = NewMiddleware[testPriority]() + +func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool { + w.Header().Add("Test-Value", strconv.Itoa(t.Value)) + return true +} + +func TestMiddlewarePriority(t *testing.T) { + priorities := []int{4, 7, 9, 0} + chain := make([]*Middleware, len(priorities)) + for i, p := range priorities { + mid, err := test.New(OptionsRaw{ + "priority": p, + "value": i, + }) + ExpectNoError(t, err) + chain[i] = mid + } + res, err := newMiddlewaresTest(chain, nil) + ExpectNoError(t, err) + ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2") +} diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index cfea2c35..1785083c 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -14,6 +14,8 @@ import ( var allMiddlewares = map[string]*Middleware{ "redirecthttp": RedirectHTTP, + "oidc": OIDC, + "request": ModifyRequest, "modifyrequest": ModifyRequest, "response": ModifyResponse, diff --git a/internal/net/http/middleware/oidc.go b/internal/net/http/middleware/oidc.go new file mode 100644 index 00000000..8e1b3e67 --- /dev/null +++ b/internal/net/http/middleware/oidc.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "net/http" + + "github.com/yusing/go-proxy/internal/api/v1/auth" + E "github.com/yusing/go-proxy/internal/error" +) + +type oidcMiddleware struct { + AllowedUsers []string `json:"allowed_users"` + AllowedGroups []string `json:"allowed_groups"` + + auth auth.Provider + authMux *http.ServeMux + logoutHandler http.HandlerFunc +} + +var OIDC = NewMiddleware[oidcMiddleware]() + +func (amw *oidcMiddleware) finalize() error { + if !auth.IsOIDCEnabled() { + return E.New("OIDC not enabled but ODIC middleware is used") + } + authProvider, err := auth.NewOIDCProviderFromEnv() + if err != nil { + return err + } + + authProvider.SetIsMiddleware(true) + if len(amw.AllowedUsers) > 0 { + authProvider.SetAllowedUsers(amw.AllowedUsers) + } + if len(amw.AllowedGroups) > 0 { + authProvider.SetAllowedGroups(amw.AllowedGroups) + } + + amw.authMux = http.NewServeMux() + amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) + amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + }) + amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage) + amw.logoutHandler = auth.LogoutCallbackHandler(authProvider) + amw.auth = authProvider + return nil +} + +func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + if err := amw.auth.CheckToken(r); err != nil { + amw.authMux.ServeHTTP(w, r) + return false + } + if r.URL.Path == auth.OIDCLogoutPath { + amw.logoutHandler(w, r) + return false + } + return true +} diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index dceeb39a..9c8ce3a4 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -127,6 +127,20 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E } args.setDefaults() + mid, setOptErr := middleware.New(args.middlewareOpt) + if setOptErr != nil { + return nil, setOptErr + } + + return newMiddlewaresTest([]*Middleware{mid}, args) +} + +func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) { + if args == nil { + args = new(testArgs) + } + args.setDefaults() + req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader()) for k, v := range args.headers { req.Header[k] = v @@ -139,14 +153,8 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E rr.parent = http.DefaultTransport } - rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr) - - mid, setOptErr := middleware.New(args.middlewareOpt) - if setOptErr != nil { - return nil, setOptErr - } - - patchReverseProxy(rp, []*Middleware{mid}) + rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr) + patchReverseProxy(rp, middlewares) rp.ServeHTTP(w, req) resp := w.Result() diff --git a/internal/utils/slices.go b/internal/utils/slices.go new file mode 100644 index 00000000..afe2914c --- /dev/null +++ b/internal/utils/slices.go @@ -0,0 +1,20 @@ +package utils + +// Intersect returns a new slice containing the elements that are present in both input slices. +// This provides a more efficient solution than using two nested loops. +func Intersect[T comparable, Slice ~[]T](slice1 Slice, slice2 Slice) Slice { + var result Slice + seen := map[T]struct{}{} + + for i := range slice1 { + seen[slice1[i]] = struct{}{} + } + + for i := range slice2 { + if _, ok := seen[slice2[i]]; ok { + result = append(result, slice2[i]) + } + } + + return result +} diff --git a/internal/utils/slices_test.go b/internal/utils/slices_test.go new file mode 100644 index 00000000..8d2a1f1a --- /dev/null +++ b/internal/utils/slices_test.go @@ -0,0 +1,96 @@ +package utils + +import ( + "slices" + "strings" + "testing" + + utils "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestIntersect(t *testing.T) { + t.Run("strings", func(t *testing.T) { + t.Run("no intersection", func(t *testing.T) { + var ( + slice1 = []string{"a", "b", "c"} + slice2 = []string{"d", "e", "f"} + want []string + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + t.Run("intersection", func(t *testing.T) { + var ( + slice1 = []string{"a", "b", "c"} + slice2 = []string{"b", "c", "d"} + want = []string{"b", "c"} + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + }) + t.Run("ints", func(t *testing.T) { + t.Run("no intersection", func(t *testing.T) { + var ( + slice1 = []int{1, 2, 3} + slice2 = []int{4, 5, 6} + want []int + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + t.Run("intersection", func(t *testing.T) { + var ( + slice1 = []int{1, 2, 3} + slice2 = []int{2, 3, 4} + want = []int{2, 3} + ) + result := Intersect(slice1, slice2) + slices.Sort(result) + slices.Sort(want) + utils.ExpectDeepEqual(t, result, want) + }) + }) + t.Run("complex", func(t *testing.T) { + type T struct { + A string + B int + } + t.Run("no intersection", func(t *testing.T) { + var ( + slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}} + slice2 = []T{{"d", 4}, {"e", 5}, {"f", 6}} + want []T + ) + result := Intersect(slice1, slice2) + slices.SortFunc(result, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + slices.SortFunc(want, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + utils.ExpectDeepEqual(t, result, want) + }) + t.Run("intersection", func(t *testing.T) { + var ( + slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}} + slice2 = []T{{"b", 2}, {"c", 3}, {"d", 4}} + want = []T{{"b", 2}, {"c", 3}} + ) + result := Intersect(slice1, slice2) + slices.SortFunc(result, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + slices.SortFunc(want, func(i T, j T) int { + return strings.Compare(i.A, j.A) + }) + utils.ExpectDeepEqual(t, result, want) + }) + }) +} diff --git a/internal/utils/strutils/string.go b/internal/utils/strutils/string.go index 4664c2b1..18f78c61 100644 --- a/internal/utils/strutils/string.go +++ b/internal/utils/strutils/string.go @@ -10,6 +10,9 @@ import ( // CommaSeperatedList returns a list of strings split by commas, // then trim spaces from each element. func CommaSeperatedList(s string) []string { + if s == "" { + return []string{} + } res := SplitComma(s) for i, part := range res { res[i] = strings.TrimSpace(part) diff --git a/next-release.md b/next-release.md index ae1989a5..e398b33e 100644 --- a/next-release.md +++ b/next-release.md @@ -73,6 +73,26 @@ GoDoxy v0.8.2 expected changes * Connection #0 to host localhost left intact ``` +- **Thanks [polds](https://github.com/polds)** + Support WebUI authentication via OIDC by setting these environment variables: + - `GODOXY_OIDC_ISSUER_URL` e.g.: + - Pocket ID: `https://pocker-id.yourdomain.com` + - Authentik: `https://authentik.yourdomain.com/application/o//` **The ending slash is required** + - `GODOXY_OIDC_CLIENT_ID` + - `GODOXY_OIDC_CLIENT_SECRET` + - `GODOXY_OIDC_REDIRECT_URL` + - `GODOXY_OIDC_SCOPES` _(optional)_ + - `GODOXY_OIDC_ALLOWED_USERS` + +- Use OpenID Connect to authenticate GoDoxy's WebUI and all your services (SSO) + ```yaml + # default + proxy.app.middlewares.oidc: + + # override allowed users + proxy.app.middlewares.oidc.allowed_users: user1, user2 + ``` + - Caddyfile like rules ```yaml