diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 7ee3575e..4728f4db 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -32,6 +32,8 @@ type ( allowedUsers []string allowedGroups []string + rateLimit *rate.Limiter + onUnknownPathHandler http.HandlerFunc } @@ -123,6 +125,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all endSessionURL: endSessionURL, allowedUsers: allowedUsers, allowedGroups: allowedGroups, + rateLimit: rate.NewLimiter(rate.Every(common.OIDCRateLimitPeriod), common.OIDCRateLimit), }, nil } @@ -165,6 +168,7 @@ func NewOIDCProviderWithCustomClient(baseProvider *OIDCProvider, clientID, clien endSessionURL: baseProvider.endSessionURL, allowedUsers: baseProvider.allowedUsers, allowedGroups: baseProvider.allowedGroups, + rateLimit: baseProvider.rateLimit, }, nil } @@ -228,8 +232,6 @@ func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { } } -var rateLimit = rate.NewLimiter(rate.Every(time.Second), 1) - func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { if !httputils.GetAccept(r.Header).AcceptHTML() { http.Error(w, "authentication is required", http.StatusForbidden) @@ -255,7 +257,7 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { return } - if !rateLimit.Allow() { + if !auth.rateLimit.Allow() { WriteBlockPage(w, http.StatusTooManyRequests, "auth rate limit exceeded", "Try again", OIDCAuthInitPath) return } diff --git a/internal/common/env.go b/internal/common/env.go index f96f8a31..121aa16a 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -39,12 +39,14 @@ var ( DebugDisableAuth = env.GetEnvBool("DEBUG_DISABLE_AUTH", false) // OIDC Configuration. - OIDCIssuerURL = env.GetEnvString("OIDC_ISSUER_URL", "") - OIDCClientID = env.GetEnvString("OIDC_CLIENT_ID", "") - OIDCClientSecret = env.GetEnvString("OIDC_CLIENT_SECRET", "") - OIDCScopes = env.GetEnvCommaSep("OIDC_SCOPES", "openid, profile, email, groups") - OIDCAllowedUsers = env.GetEnvCommaSep("OIDC_ALLOWED_USERS", "") - OIDCAllowedGroups = env.GetEnvCommaSep("OIDC_ALLOWED_GROUPS", "") + OIDCIssuerURL = env.GetEnvString("OIDC_ISSUER_URL", "") + OIDCClientID = env.GetEnvString("OIDC_CLIENT_ID", "") + OIDCClientSecret = env.GetEnvString("OIDC_CLIENT_SECRET", "") + OIDCScopes = env.GetEnvCommaSep("OIDC_SCOPES", "openid, profile, email, groups") + OIDCAllowedUsers = env.GetEnvCommaSep("OIDC_ALLOWED_USERS", "") + OIDCAllowedGroups = env.GetEnvCommaSep("OIDC_ALLOWED_GROUPS", "") + OIDCRateLimit = env.GetEnvInt("OIDC_RATE_LIMIT", 10) + OIDCRateLimitPeriod = env.GetEnvDuation("OIDC_RATE_LIMIT_PERIOD", time.Second) // metrics configuration MetricsDisableCPU = env.GetEnvBool("METRICS_DISABLE_CPU", false)