Merge branch 'main' into dev

This commit is contained in:
yusing
2025-12-22 10:45:44 +08:00
8 changed files with 48 additions and 45 deletions

View File

@@ -12,11 +12,12 @@ var blockPageHTML string
var blockPageTemplate = template.Must(template.New("block_page").Parse(blockPageHTML)) var blockPageTemplate = template.Must(template.New("block_page").Parse(blockPageHTML))
func WriteBlockPage(w http.ResponseWriter, status int, error string, logoutURL string) { func WriteBlockPage(w http.ResponseWriter, status int, errorMessage, actionText, actionURL string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
blockPageTemplate.Execute(w, map[string]string{ blockPageTemplate.Execute(w, map[string]string{
"StatusText": http.StatusText(status), "StatusText": http.StatusText(status),
"Error": error, "Error": errorMessage,
"LogoutURL": logoutURL, "ActionURL": actionURL,
"ActionText": actionText,
}) })
} }

View File

@@ -1,14 +1,14 @@
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Access Denied</title> <title>Access Denied</title>
</head> </head>
<body> <body>
<h1>{{.StatusText}}</h1> <h1>{{.StatusText}}</h1>
<p>{{.Error}}</p> <p>{{.Error}}</p>
<a href="{{.LogoutURL}}">Logout</a> <a href="{{.ActionURL}}">{{.ActionText}}</a>
</body> </body>
</html> </html>

View File

@@ -32,6 +32,8 @@ type (
allowedUsers []string allowedUsers []string
allowedGroups []string allowedGroups []string
rateLimit *rate.Limiter
onUnknownPathHandler http.HandlerFunc onUnknownPathHandler http.HandlerFunc
} }
@@ -123,6 +125,7 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, all
endSessionURL: endSessionURL, endSessionURL: endSessionURL,
allowedUsers: allowedUsers, allowedUsers: allowedUsers,
allowedGroups: allowedGroups, allowedGroups: allowedGroups,
rateLimit: rate.NewLimiter(rate.Every(common.OIDCRateLimitPeriod), common.OIDCRateLimit),
}, nil }, nil
} }
@@ -165,6 +168,7 @@ func NewOIDCProviderWithCustomClient(baseProvider *OIDCProvider, clientID, clien
endSessionURL: baseProvider.endSessionURL, endSessionURL: baseProvider.endSessionURL,
allowedUsers: baseProvider.allowedUsers, allowedUsers: baseProvider.allowedUsers,
allowedGroups: baseProvider.allowedGroups, allowedGroups: baseProvider.allowedGroups,
rateLimit: baseProvider.rateLimit,
}, nil }, nil
} }
@@ -228,9 +232,12 @@ func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
} }
} }
var rateLimit = rate.NewLimiter(rate.Every(time.Second), 1)
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
if !httputils.GetAccept(r.Header).AcceptHTML() {
http.Error(w, "authentication is required", http.StatusForbidden)
return
}
// check for session token // check for session token
sessionToken, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken)) sessionToken, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthSessionToken))
if err == nil { // session token exists if err == nil { // session token exists
@@ -250,8 +257,8 @@ func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if !rateLimit.Allow() { if !auth.rateLimit.Allow() {
http.Error(w, "auth rate limit exceeded", http.StatusTooManyRequests) WriteBlockPage(w, http.StatusTooManyRequests, "auth rate limit exceeded", "Try again", OIDCAuthInitPath)
return return
} }
@@ -318,34 +325,39 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http
// verify state // verify state
state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState)) state, err := r.Cookie(auth.getAppScopedCookieName(CookieOauthState))
if err != nil { if err != nil {
http.Error(w, "missing state cookie", http.StatusBadRequest) auth.clearCookie(w, r)
WriteBlockPage(w, http.StatusBadRequest, "missing state cookie", "Back to Login", OIDCAuthInitPath)
return return
} }
if r.URL.Query().Get("state") != state.Value { if r.URL.Query().Get("state") != state.Value {
http.Error(w, "invalid oauth state", http.StatusBadRequest) auth.clearCookie(w, r)
WriteBlockPage(w, http.StatusBadRequest, "invalid oauth state", "Back to Login", OIDCAuthInitPath)
return return
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r)) oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r))
if err != nil { if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) auth.clearCookie(w, r)
httputils.LogError(r).Msg(fmt.Sprintf("failed to exchange token: %v", err)) WriteBlockPage(w, http.StatusInternalServerError, "failed to exchange token", "Try again", OIDCAuthInitPath)
httputils.LogError(r).Msgf("failed to exchange token: %v", err)
return return
} }
idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token) idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token)
if err != nil { if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) auth.clearCookie(w, r)
httputils.LogError(r).Msg(fmt.Sprintf("failed to get ID token: %v", err)) WriteBlockPage(w, http.StatusInternalServerError, "failed to get ID token", "Try again", OIDCAuthInitPath)
httputils.LogError(r).Msgf("failed to get ID token: %v", err)
return return
} }
if oauth2Token.RefreshToken != "" { if oauth2Token.RefreshToken != "" {
claims, err := parseClaims(idToken) claims, err := parseClaims(idToken)
if err != nil { if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) auth.clearCookie(w, r)
httputils.LogError(r).Msg(fmt.Sprintf("failed to parse claims: %v", err)) WriteBlockPage(w, http.StatusInternalServerError, "failed to parse claims", "Try again", OIDCAuthInitPath)
httputils.LogError(r).Msgf("failed to parse claims: %v", err)
return return
} }
session := newSession(claims.Username, claims.Groups) session := newSession(claims.Username, claims.Groups)

View File

@@ -39,12 +39,14 @@ var (
DebugDisableAuth = env.GetEnvBool("DEBUG_DISABLE_AUTH", false) DebugDisableAuth = env.GetEnvBool("DEBUG_DISABLE_AUTH", false)
// OIDC Configuration. // OIDC Configuration.
OIDCIssuerURL = env.GetEnvString("OIDC_ISSUER_URL", "") OIDCIssuerURL = env.GetEnvString("OIDC_ISSUER_URL", "")
OIDCClientID = env.GetEnvString("OIDC_CLIENT_ID", "") OIDCClientID = env.GetEnvString("OIDC_CLIENT_ID", "")
OIDCClientSecret = env.GetEnvString("OIDC_CLIENT_SECRET", "") OIDCClientSecret = env.GetEnvString("OIDC_CLIENT_SECRET", "")
OIDCScopes = env.GetEnvCommaSep("OIDC_SCOPES", "openid, profile, email, groups") OIDCScopes = env.GetEnvCommaSep("OIDC_SCOPES", "openid, profile, email, groups")
OIDCAllowedUsers = env.GetEnvCommaSep("OIDC_ALLOWED_USERS", "") OIDCAllowedUsers = env.GetEnvCommaSep("OIDC_ALLOWED_USERS", "")
OIDCAllowedGroups = env.GetEnvCommaSep("OIDC_ALLOWED_GROUPS", "") OIDCAllowedGroups = env.GetEnvCommaSep("OIDC_ALLOWED_GROUPS", "")
OIDCRateLimit = env.GetEnvInt("OIDC_RATE_LIMIT", 10)
OIDCRateLimitPeriod = env.GetEnvDuation("OIDC_RATE_LIMIT_PERIOD", time.Second)
// metrics configuration // metrics configuration
MetricsDisableCPU = env.GetEnvBool("METRICS_DISABLE_CPU", false) MetricsDisableCPU = env.GetEnvBool("METRICS_DISABLE_CPU", false)

View File

@@ -90,8 +90,7 @@ func Reload() gperr.Error {
if err != nil { if err != nil {
newState.Task().FinishAndWait(err) newState.Task().FinishAndWait(err)
config.WorkingState.Store(GetState()) config.WorkingState.Store(GetState())
logNotifyError("reload", err) return gperr.Wrap(err, ansi.Warning("using last config"))
return gperr.New(ansi.Warning("using last config")).With(err)
} }
// flush temporary log // flush temporary log
@@ -117,7 +116,7 @@ func WatchChanges() {
configEventFlushInterval, configEventFlushInterval,
OnConfigChange, OnConfigChange,
func(err gperr.Error) { func(err gperr.Error) {
logNotifyError("config reload", err) logNotifyError("reload", err)
}, },
) )
eventQueue.Start(cfgWatcher.Events(t.Context())) eventQueue.Start(cfgWatcher.Events(t.Context()))

View File

@@ -110,8 +110,6 @@ func Clients() map[string]*SharedClient {
return clients return clients
} }
var versionArg = client.WithAPIVersionNegotiation()
// NewClient creates a new Docker client connection to the specified host. // NewClient creates a new Docker client connection to the specified host.
// //
// Returns existing client if available. // Returns existing client if available.
@@ -154,7 +152,6 @@ func NewClient(host string, unique ...bool) (*SharedClient, error) {
opt = []client.Opt{ opt = []client.Opt{
client.WithHost(agent.DockerHost), client.WithHost(agent.DockerHost),
client.WithHTTPClient(cfg.NewHTTPClient()), client.WithHTTPClient(cfg.NewHTTPClient()),
versionArg,
} }
addr = "tcp://" + cfg.Addr addr = "tcp://" + cfg.Addr
dial = cfg.DialContext dial = cfg.DialContext
@@ -165,7 +162,6 @@ func NewClient(host string, unique ...bool) (*SharedClient, error) {
case common.DockerHostFromEnv: case common.DockerHostFromEnv:
opt = []client.Opt{ opt = []client.Opt{
client.WithHostFromEnv(), client.WithHostFromEnv(),
versionArg,
} }
default: default:
helper, err := connhelper.GetConnectionHelper(host) helper, err := connhelper.GetConnectionHelper(host)
@@ -173,21 +169,13 @@ func NewClient(host string, unique ...bool) (*SharedClient, error) {
log.Panic().Err(err).Msg("failed to get connection helper") log.Panic().Err(err).Msg("failed to get connection helper")
} }
if helper != nil { if helper != nil {
httpClient := &http.Client{
Transport: &http.Transport{
DialContext: helper.Dialer,
},
}
opt = []client.Opt{ opt = []client.Opt{
client.WithHTTPClient(httpClient),
client.WithHost(helper.Host), client.WithHost(helper.Host),
versionArg,
client.WithDialContext(helper.Dialer), client.WithDialContext(helper.Dialer),
} }
} else { } else {
opt = []client.Opt{ opt = []client.Opt{
client.WithHost(host), client.WithHost(host),
versionArg,
} }
} }
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/serialization" "github.com/yusing/godoxy/internal/serialization"
httputils "github.com/yusing/goutils/http" httputils "github.com/yusing/goutils/http"
"github.com/yusing/goutils/intern"
strutils "github.com/yusing/goutils/strings" strutils "github.com/yusing/goutils/strings"
"github.com/yusing/goutils/synk" "github.com/yusing/goutils/synk"
"github.com/yusing/goutils/task" "github.com/yusing/goutils/task"
@@ -402,7 +403,7 @@ func UpdateSelfhstIcons(m IconMap) error {
} }
icon := &IconMeta{ icon := &IconMeta{
DisplayName: item.Name, DisplayName: item.Name,
Tag: tag, Tag: intern.Make(tag).Value(),
SVG: item.SVG == "Yes", SVG: item.SVG == "Yes",
PNG: item.PNG == "Yes", PNG: item.PNG == "Yes",
WebP: item.WebP == "Yes", WebP: item.WebP == "Yes",

View File

@@ -117,7 +117,7 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
case errors.Is(err, auth.ErrMissingOAuthToken): case errors.Is(err, auth.ErrMissingOAuthToken):
amw.auth.HandleAuth(w, r) amw.auth.HandleAuth(w, r)
default: default:
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), "Logout", auth.OIDCLogoutPath)
} }
return false return false
} }