diff --git a/internal/auth/block_page.go b/internal/auth/block_page.go index 9d5c0b57..8813a90d 100644 --- a/internal/auth/block_page.go +++ b/internal/auth/block_page.go @@ -12,11 +12,12 @@ var blockPageHTML string var blockPageTemplate = template.Must(template.New("block_page").Parse(blockPageHTML)) -func WriteBlockPage(w http.ResponseWriter, status int, error string, logoutURL string) { +func WriteBlockPage(w http.ResponseWriter, status int, errorMessage, actionText, actionURL string) { w.Header().Set("Content-Type", "text/html; charset=utf-8") blockPageTemplate.Execute(w, map[string]string{ "StatusText": http.StatusText(status), - "Error": error, - "LogoutURL": logoutURL, + "Error": errorMessage, + "ActionURL": actionURL, + "ActionText": actionText, }) } diff --git a/internal/auth/block_page.html b/internal/auth/block_page.html index 195cc13e..02445e09 100644 --- a/internal/auth/block_page.html +++ b/internal/auth/block_page.html @@ -1,14 +1,14 @@ -
- - + + +{{.Error}}
- Logout - + {{.ActionText}} + diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 6352e898..4728f4db 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -32,6 +32,8 @@ type ( allowedUsers []string allowedGroups []string + rateLimit *rate.Limiter + onUnknownPathHandler http.HandlerFunc } @@ -123,6 +125,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all endSessionURL: endSessionURL, allowedUsers: allowedUsers, allowedGroups: allowedGroups, + rateLimit: rate.NewLimiter(rate.Every(common.OIDCRateLimitPeriod), common.OIDCRateLimit), }, nil } @@ -165,6 +168,7 @@ func NewOIDCProviderWithCustomClient(baseProvider *OIDCProvider, clientID, clien endSessionURL: baseProvider.endSessionURL, allowedUsers: baseProvider.allowedUsers, allowedGroups: baseProvider.allowedGroups, + rateLimit: baseProvider.rateLimit, }, nil } @@ -228,9 +232,12 @@ func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { } } -var rateLimit = rate.NewLimiter(rate.Every(time.Second), 1) - func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { + if !httputils.GetAccept(r.Header).AcceptHTML() { + http.Error(w, "authentication is required", http.StatusForbidden) + return + } + // check for session token sessionToken, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken)) if err == nil { // session token exists @@ -250,8 +257,8 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { return } - if !rateLimit.Allow() { - http.Error(w, "auth rate limit exceeded", http.StatusTooManyRequests) + if !auth.rateLimit.Allow() { + WriteBlockPage(w, http.StatusTooManyRequests, "auth rate limit exceeded", "Try again", OIDCAuthInitPath) return } @@ -318,34 +325,39 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http // verify state state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState)) if err != nil { - http.Error(w, "missing state cookie", http.StatusBadRequest) + auth.clearCookie(w, r) + WriteBlockPage(w, http.StatusBadRequest, "missing state cookie", "Back to Login", OIDCAuthInitPath) return } if r.URL.Query().Get("state") != state.Value { - http.Error(w, "invalid oauth state", http.StatusBadRequest) + auth.clearCookie(w, r) + WriteBlockPage(w, http.StatusBadRequest, "invalid oauth state", "Back to Login", OIDCAuthInitPath) return } code := r.URL.Query().Get("code") oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r)) if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - httputils.LogError(r).Msg(fmt.Sprintf("failed to exchange token: %v", err)) + auth.clearCookie(w, r) + WriteBlockPage(w, http.StatusInternalServerError, "failed to exchange token", "Try again", OIDCAuthInitPath) + httputils.LogError(r).Msgf("failed to exchange token: %v", err) return } idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token) if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - httputils.LogError(r).Msg(fmt.Sprintf("failed to get ID token: %v", err)) + auth.clearCookie(w, r) + WriteBlockPage(w, http.StatusInternalServerError, "failed to get ID token", "Try again", OIDCAuthInitPath) + httputils.LogError(r).Msgf("failed to get ID token: %v", err) return } if oauth2Token.RefreshToken != "" { claims, err := parseClaims(idToken) if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - httputils.LogError(r).Msg(fmt.Sprintf("failed to parse claims: %v", err)) + auth.clearCookie(w, r) + WriteBlockPage(w, http.StatusInternalServerError, "failed to parse claims", "Try again", OIDCAuthInitPath) + httputils.LogError(r).Msgf("failed to parse claims: %v", err) return } session := newSession(claims.Username, claims.Groups) diff --git a/internal/common/env.go b/internal/common/env.go index f96f8a31..121aa16a 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -39,12 +39,14 @@ var ( DebugDisableAuth = env.GetEnvBool("DEBUG_DISABLE_AUTH", false) // OIDC Configuration. - OIDCIssuerURL = env.GetEnvString("OIDC_ISSUER_URL", "") - OIDCClientID = env.GetEnvString("OIDC_CLIENT_ID", "") - OIDCClientSecret = env.GetEnvString("OIDC_CLIENT_SECRET", "") - OIDCScopes = env.GetEnvCommaSep("OIDC_SCOPES", "openid, profile, email, groups") - OIDCAllowedUsers = env.GetEnvCommaSep("OIDC_ALLOWED_USERS", "") - OIDCAllowedGroups = env.GetEnvCommaSep("OIDC_ALLOWED_GROUPS", "") + OIDCIssuerURL = env.GetEnvString("OIDC_ISSUER_URL", "") + OIDCClientID = env.GetEnvString("OIDC_CLIENT_ID", "") + OIDCClientSecret = env.GetEnvString("OIDC_CLIENT_SECRET", "") + OIDCScopes = env.GetEnvCommaSep("OIDC_SCOPES", "openid, profile, email, groups") + OIDCAllowedUsers = env.GetEnvCommaSep("OIDC_ALLOWED_USERS", "") + OIDCAllowedGroups = env.GetEnvCommaSep("OIDC_ALLOWED_GROUPS", "") + OIDCRateLimit = env.GetEnvInt("OIDC_RATE_LIMIT", 10) + OIDCRateLimitPeriod = env.GetEnvDuation("OIDC_RATE_LIMIT_PERIOD", time.Second) // metrics configuration MetricsDisableCPU = env.GetEnvBool("METRICS_DISABLE_CPU", false) diff --git a/internal/config/events.go b/internal/config/events.go index 5bccc603..46755645 100644 --- a/internal/config/events.go +++ b/internal/config/events.go @@ -90,8 +90,7 @@ func Reload() gperr.Error { if err != nil { newState.Task().FinishAndWait(err) config.WorkingState.Store(GetState()) - logNotifyError("reload", err) - return gperr.New(ansi.Warning("using last config")).With(err) + return gperr.Wrap(err, ansi.Warning("using last config")) } // flush temporary log @@ -117,7 +116,7 @@ func WatchChanges() { configEventFlushInterval, OnConfigChange, func(err gperr.Error) { - logNotifyError("config reload", err) + logNotifyError("reload", err) }, ) eventQueue.Start(cfgWatcher.Events(t.Context())) diff --git a/internal/docker/client.go b/internal/docker/client.go index 19cd3158..f8fa0041 100644 --- a/internal/docker/client.go +++ b/internal/docker/client.go @@ -110,8 +110,6 @@ func Clients() map[string]*SharedClient { return clients } -var versionArg = client.WithAPIVersionNegotiation() - // NewClient creates a new Docker client connection to the specified host. // // Returns existing client if available. @@ -154,7 +152,6 @@ func NewClient(host string, unique ...bool) (*SharedClient, error) { opt = []client.Opt{ client.WithHost(agent.DockerHost), client.WithHTTPClient(cfg.NewHTTPClient()), - versionArg, } addr = "tcp://" + cfg.Addr dial = cfg.DialContext @@ -165,7 +162,6 @@ func NewClient(host string, unique ...bool) (*SharedClient, error) { case common.DockerHostFromEnv: opt = []client.Opt{ client.WithHostFromEnv(), - versionArg, } default: helper, err := connhelper.GetConnectionHelper(host) @@ -173,21 +169,13 @@ func NewClient(host string, unique ...bool) (*SharedClient, error) { log.Panic().Err(err).Msg("failed to get connection helper") } if helper != nil { - httpClient := &http.Client{ - Transport: &http.Transport{ - DialContext: helper.Dialer, - }, - } opt = []client.Opt{ - client.WithHTTPClient(httpClient), client.WithHost(helper.Host), - versionArg, client.WithDialContext(helper.Dialer), } } else { opt = []client.Opt{ client.WithHost(host), - versionArg, } } } diff --git a/internal/homepage/list_icons.go b/internal/homepage/list_icons.go index 5b131114..2c465ca3 100644 --- a/internal/homepage/list_icons.go +++ b/internal/homepage/list_icons.go @@ -14,6 +14,7 @@ import ( "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/serialization" httputils "github.com/yusing/goutils/http" + "github.com/yusing/goutils/intern" strutils "github.com/yusing/goutils/strings" "github.com/yusing/goutils/synk" "github.com/yusing/goutils/task" @@ -402,7 +403,7 @@ func UpdateSelfhstIcons(m IconMap) error { } icon := &IconMeta{ DisplayName: item.Name, - Tag: tag, + Tag: intern.Make(tag).Value(), SVG: item.SVG == "Yes", PNG: item.PNG == "Yes", WebP: item.WebP == "Yes", diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 74c341d7..6406de4a 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -117,7 +117,7 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce case errors.Is(err, auth.ErrMissingOAuthToken): amw.auth.HandleAuth(w, r) default: - auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) + auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), "Logout", auth.OIDCLogoutPath) } return false }