diff --git a/cmd/main.go b/cmd/main.go index 0fb6f3a3..c1a88eba 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -109,16 +109,16 @@ func main() { return } - if common.APIJWTSecret == nil { - logging.Warn().Msg("API JWT secret is empty, authentication is disabled") - } - cfg.Start() config.WatchChanges() - // Initialize authentication providers - if err := auth.Initialize(); err != nil { - logging.Warn().Err(err).Msg("Failed to initialize authentication providers") + if !auth.IsEnabled() { + logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") + } else { + // Initialize authentication providers + if err := auth.Initialize(); err != nil { + logging.Fatal().Err(err).Msg("Failed to initialize authentication providers") + } } sig := make(chan os.Signal, 1) diff --git a/internal/api/handler.go b/internal/api/handler.go index 512c5feb..cab7a1fb 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -22,9 +22,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { mux := ServeMux{http.NewServeMux()} mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) - mux.HandleFunc("POST", "/v1/login", auth.LoginHandler) - mux.HandleFunc("GET", "/v1/login/method", auth.AuthMethodHandler) - mux.HandleFunc("GET", "/v1/login/oidc", auth.OIDCLoginHandler) + mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler) + mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler) mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) diff --git a/internal/api/v1/auth/auth.go b/internal/api/v1/auth/auth.go index 23062db2..26876387 100644 --- a/internal/api/v1/auth/auth.go +++ b/internal/api/v1/auth/auth.go @@ -1,8 +1,6 @@ package auth import ( - "bytes" - "encoding/json" "fmt" "net/http" "time" @@ -25,51 +23,37 @@ type ( } ) -var ( - ErrInvalidUsername = E.New("invalid username") - ErrInvalidPassword = E.New("invalid password") -) - -func validatePassword(cred *Credentials) error { - if cred.Username != common.APIUser { - return ErrInvalidUsername.Subject(cred.Username) - } - if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { - return ErrInvalidPassword.Subject(cred.Password) +// Initialize sets up authentication providers. +func Initialize() error { + // Initialize OIDC if configured. + if common.OIDCIssuerURL != "" { + return InitOIDC( + common.OIDCIssuerURL, + common.OIDCClientID, + common.OIDCClientSecret, + common.OIDCRedirectURL, + ) } return nil } -func LoginHandler(w http.ResponseWriter, r *http.Request) { - var creds Credentials - err := json.NewDecoder(r.Body).Decode(&creds) - if err != nil { - U.HandleErr(w, r, err, http.StatusBadRequest) - return - } - if err := validatePassword(&creds); err != nil { - U.HandleErr(w, r, err, http.StatusUnauthorized) - return - } - if err := setAuthenticatedCookie(w, creds.Username); err != nil { - U.HandleErr(w, r, err, http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) +func IsEnabled() bool { + return common.APIJWTSecret != nil || common.OIDCIssuerURL != "" } -func AuthMethodHandler(w http.ResponseWriter, r *http.Request) { +// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration. +func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) { switch { + case oauthConfig != nil: + RedirectOIDC(w, r) + return case common.APIJWTSecret == nil: - U.WriteBody(w, []byte("skip")) - case common.OIDCIssuerURL != "": - U.WriteBody(w, []byte("oidc")) - case common.APIPasswordHash != nil: - U.WriteBody(w, []byte("password")) + http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) + return default: U.WriteBody(w, []byte("skip")) + w.WriteHeader(http.StatusOK) } - w.WriteHeader(http.StatusOK) } func setAuthenticatedCookie(w http.ResponseWriter, username string) error { @@ -86,57 +70,44 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error { return err } http.SetCookie(w, &http.Cookie{ - Name: "token", + Name: CookieToken, Value: tokenStr, Expires: expiresAt, HttpOnly: true, + Secure: true, SameSite: http.SameSiteStrictMode, Path: "/", }) return nil } +// LogoutHandler clear authentication cookie and redirect to login page. func LogoutHandler(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ - Name: "token", + Name: CookieToken, Value: "", Expires: time.Unix(0, 0), HttpOnly: true, + Secure: true, SameSite: http.SameSiteStrictMode, Path: "/", }) - w.Header().Set("location", "/login") - w.WriteHeader(http.StatusTemporaryRedirect) -} - -// Initialize sets up authentication providers. -func Initialize() error { - // Initialize OIDC if configured. - if common.OIDCIssuerURL != "" { - return InitOIDC( - common.OIDCIssuerURL, - common.OIDCClientID, - common.OIDCClientSecret, - common.OIDCRedirectURL, - ) - } - return nil + AuthRedirectHandler(w, r) } func RequireAuth(next http.HandlerFunc) http.HandlerFunc { - if common.IsDebugSkipAuth || common.APIJWTSecret == nil { - return next - } - - return func(w http.ResponseWriter, r *http.Request) { - if checkToken(w, r) { - next(w, r) + if IsEnabled() { + return func(w http.ResponseWriter, r *http.Request) { + if checkToken(w, r) { + next(w, r) + } } } + return next } func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { - tokenCookie, err := r.Cookie("token") + tokenCookie, err := r.Cookie(CookieToken) if err != nil { U.RespondError(w, E.New("missing token"), http.StatusUnauthorized) return false diff --git a/internal/api/v1/auth/cookies.go b/internal/api/v1/auth/cookies.go new file mode 100644 index 00000000..74315d68 --- /dev/null +++ b/internal/api/v1/auth/cookies.go @@ -0,0 +1,6 @@ +package auth + +const ( + CookieToken = "token" + CookieOauthState = "oauth_state" +) diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index 50b11e86..d5cffff5 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -4,10 +4,8 @@ import ( "context" "fmt" "net/http" - "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/golang-jwt/jwt/v5" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" @@ -47,8 +45,8 @@ func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error { return nil } -// OIDCLoginHandler initiates the OIDC login flow. -func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) { +// RedirectOIDC initiates the OIDC login flow. +func RedirectOIDC(w http.ResponseWriter, r *http.Request) { if oauthConfig == nil { U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) return @@ -56,7 +54,7 @@ func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) { state := common.GenerateRandomString(32) http.SetCookie(w, &http.Cookie{ - Name: "oauth_state", + Name: CookieOauthState, Value: state, MaxAge: 300, HttpOnly: true, @@ -87,7 +85,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { return } - state, err := r.Cookie("oauth_state") + state, err := r.Cookie(CookieOauthState) if err != nil { U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) return @@ -137,7 +135,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { // handleTestCallback handles OIDC callback in test environment. func handleTestCallback(w http.ResponseWriter, r *http.Request) { - state, err := r.Cookie("oauth_state") + state, err := r.Cookie(CookieOauthState) if err != nil { U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) return @@ -149,29 +147,10 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) { } // Create test JWT token - expiresAt := time.Now().Add(common.APIJWTTokenTTL) - jwtClaims := &Claims{ - Username: "test-user", - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expiresAt), - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS512, jwtClaims) - tokenStr, err := token.SignedString(common.APIJWTSecret) - if err != nil { + if err := setAuthenticatedCookie(w, "test-user"); err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return } - http.SetCookie(w, &http.Cookie{ - Name: "token", - Value: tokenStr, - Expires: expiresAt, - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - Path: "/", - }) - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index c5fb54a3..1e1f986d 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -68,10 +68,10 @@ func TestOIDCLoginHandler(t *testing.T) { oauthConfig = nil } - req := httptest.NewRequest(http.MethodGet, "/login/oidc", nil) + req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) w := httptest.NewRecorder() - OIDCLoginHandler(w, req) + RedirectOIDC(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) diff --git a/internal/api/v1/auth/userpass.go b/internal/api/v1/auth/userpass.go new file mode 100644 index 00000000..6d00e6a7 --- /dev/null +++ b/internal/api/v1/auth/userpass.go @@ -0,0 +1,45 @@ +package auth + +import ( + "bytes" + "encoding/json" + "net/http" + + U "github.com/yusing/go-proxy/internal/api/v1/utils" + "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" +) + +var ( + ErrInvalidUsername = E.New("invalid username") + ErrInvalidPassword = E.New("invalid password") +) + +func validatePassword(cred *Credentials) error { + if cred.Username != common.APIUser { + return ErrInvalidUsername.Subject(cred.Username) + } + if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { + return ErrInvalidPassword.Subject(cred.Password) + } + return nil +} + +// UserPassLoginHandler handles user login. +func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) { + var creds Credentials + err := json.NewDecoder(r.Body).Decode(&creds) + if err != nil { + U.HandleErr(w, r, err, http.StatusBadRequest) + return + } + if err := validatePassword(&creds); err != nil { + U.HandleErr(w, r, err, http.StatusUnauthorized) + return + } + if err := setAuthenticatedCookie(w, creds.Username); err != nil { + U.HandleErr(w, r, err, http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} diff --git a/internal/api/v1/utils/error.go b/internal/api/v1/utils/error.go index 3f095dbd..0b886893 100644 --- a/internal/api/v1/utils/error.go +++ b/internal/api/v1/utils/error.go @@ -7,7 +7,7 @@ import ( "github.com/yusing/go-proxy/internal/utils/strutils/ansi" ) -// HandleErr logs the error and returns an HTTP error response to the client. +// HandleErr logs the error and returns an error code to the client. // If code is specified, it will be used as the HTTP status code; otherwise, // http.StatusInternalServerError is used. // @@ -23,10 +23,14 @@ func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) { http.Error(w, http.StatusText(code[0]), code[0]) } +// RespondError returns error details to the client. +// If code is specified, it will be used as the HTTP status code; otherwise, +// http.StatusBadRequest is used. func RespondError(w http.ResponseWriter, err error, code ...int) { if len(code) == 0 { code = []int{http.StatusBadRequest} } + // strip ANSI color codes added from Error.WithSubject http.Error(w, ansi.StripANSI(err.Error()), code[0]) } diff --git a/internal/api/v1/utils/utils.go b/internal/api/v1/utils/utils.go index 94f9fb01..8cdfcaa5 100644 --- a/internal/api/v1/utils/utils.go +++ b/internal/api/v1/utils/utils.go @@ -11,7 +11,7 @@ import ( func WriteBody(w http.ResponseWriter, body []byte) { if _, err := w.Write(body); err != nil { - HandleErr(w, nil, err) + logging.Err(err).Msg("failed to write body") } } diff --git a/internal/common/env.go b/internal/common/env.go index 9e23ec58..97eff076 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -14,11 +14,10 @@ import ( var ( prefixes = []string{"GODOXY_", "GOPROXY_", ""} - IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test") - IsDebug = GetEnvBool("DEBUG", IsTest) - IsDebugSkipAuth = GetEnvBool("DEBUG_SKIP_AUTH", false) - IsTrace = GetEnvBool("TRACE", false) && IsDebug - IsProduction = !IsTest && !IsDebug + IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test") + IsDebug = GetEnvBool("DEBUG", IsTest) + IsTrace = GetEnvBool("TRACE", false) && IsDebug + IsProduction = !IsTest && !IsDebug ProxyHTTPAddr, ProxyHTTPHost, @@ -46,7 +45,7 @@ var ( APIUser = GetEnvString("API_USER", "admin") APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password")) - // OIDC Configuration + // OIDC Configuration. OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") diff --git a/internal/route/provider/docker_test.go b/internal/route/provider/docker_test.go index e077b8fc..823c9ea8 100644 --- a/internal/route/provider/docker_test.go +++ b/internal/route/provider/docker_test.go @@ -31,6 +31,7 @@ func makeEntries(cont *types.Container, dockerHostIP ...string) route.RawEntries } else { host = client.DefaultDockerHost } + p.name = "test" entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(cont, host))) entries.RangeAll(func(k string, v *route.RawEntry) { v.Finalize()