diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 7e96e860..84feec8c 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -317,20 +317,23 @@ func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http code := r.URL.Query().Get("code") oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r)) if err != nil { - gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + gphttp.LogError(r).Msg(fmt.Sprintf("failed to exchange token: %v", err)) return } idTokenJWT, idToken, err := auth.getIDToken(r.Context(), oauth2Token) if err != nil { - gphttp.ServerError(w, r, err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + gphttp.LogError(r).Msg(fmt.Sprintf("failed to get ID token: %v", err)) return } if oauth2Token.RefreshToken != "" { claims, err := parseClaims(idToken) if err != nil { - gphttp.ServerError(w, r, err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + gphttp.LogError(r).Msg(fmt.Sprintf("failed to parse claims: %v", err)) return } session := newSession(claims.Username, claims.Groups) diff --git a/internal/auth/userpass.go b/internal/auth/userpass.go index cdbf1ec3..16db865e 100644 --- a/internal/auth/userpass.go +++ b/internal/auth/userpass.go @@ -121,7 +121,8 @@ func (auth *UserPassAuth) PostAuthCallbackHandler(w http.ResponseWriter, r *http } token, err := auth.NewToken() if err != nil { - gphttp.ServerError(w, r, err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + gphttp.LogError(r).Msg(fmt.Sprintf("failed to generate token: %v", err)) return } SetTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL) diff --git a/internal/net/gphttp/error.go b/internal/net/gphttp/error.go deleted file mode 100644 index 8300a5ba..00000000 --- a/internal/net/gphttp/error.go +++ /dev/null @@ -1,32 +0,0 @@ -package gphttp - -import ( - "context" - "errors" - "net/http" - "syscall" - - "github.com/yusing/godoxy/internal/net/gphttp/httpheaders" -) - -// ServerError is for handling server errors. -// -// It logs the error and returns http.StatusInternalServerError to the client. -// Status code can be specified as an argument. -func ServerError(w http.ResponseWriter, r *http.Request, err error, code ...int) { - switch { - case err == nil, - errors.Is(err, context.Canceled), - errors.Is(err, syscall.EPIPE), - errors.Is(err, syscall.ECONNRESET): - return - } - LogError(r).Msg(err.Error()) - if httpheaders.IsWebsocket(r.Header) { - return - } - if len(code) == 0 { - code = []int{http.StatusInternalServerError} - } - http.Error(w, http.StatusText(code[0]), code[0]) -} diff --git a/internal/net/gphttp/server/utils.go b/internal/net/gphttp/server/utils.go index 76e59836..57f9a7fd 100644 --- a/internal/net/gphttp/server/utils.go +++ b/internal/net/gphttp/server/utils.go @@ -1,15 +1,19 @@ package server import ( + "context" + "errors" "log" "log/slog" "net/http" + "syscall" "github.com/quic-go/quic-go/http3" "github.com/rs/zerolog" slogzerolog "github.com/samber/slog-zerolog/v2" "github.com/yusing/godoxy/internal/common" "github.com/yusing/godoxy/internal/net/gphttp" + "github.com/yusing/goutils/http/httpheaders" ) func advertiseHTTP3(handler http.Handler, h3 *http3.Server) http.Handler { @@ -17,7 +21,17 @@ func advertiseHTTP3(handler http.Handler, h3 *http3.Server) http.Handler { if r.ProtoMajor < 3 { err := h3.SetQUICHeaders(w.Header()) if err != nil { - gphttp.ServerError(w, r, err) + switch { + case errors.Is(err, context.Canceled), + errors.Is(err, syscall.EPIPE), + errors.Is(err, syscall.ECONNRESET): + return + } + gphttp.LogError(r).Msg(err.Error()) + if httpheaders.IsWebsocket(r.Header) { + return + } + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } }