diff --git a/internal/net/gphttp/middleware/forwardauth.go b/internal/net/gphttp/middleware/forwardauth.go index c105285d..7b16ee00 100644 --- a/internal/net/gphttp/middleware/forwardauth.go +++ b/internal/net/gphttp/middleware/forwardauth.go @@ -2,13 +2,14 @@ package middleware import ( "context" + "errors" "net" "net/http" - "strings" "time" - "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders" "github.com/yusing/go-proxy/internal/route/routes" + "github.com/yusing/go-proxy/internal/utils" ) type ( @@ -17,10 +18,11 @@ type ( } ForwardAuthMiddlewareOpts struct { - ForwardAuthRoute string `json:"forwardauth_route"` // default: "tinyauth" - ForwardAuthLogin string `json:"forwardauth_login"` // the redirect login path, e.g. "/login?redirect_uri=" - ForwardAuthEndpoint string `json:"forwardauth_endpoint"` // default: "/api/auth/nginx" - ForwardAuthHeaders []string `json:"forwardauth_headers"` // additional headers to forward from auth server to upstream, e.g. ["Remote-User", "Remote-Name"] + Route string `json:"route" validate:"required"` // route name (alias), default: "tinyauth" + AuthEndpoint string `json:"auth_endpoint" validate:"required,uri"` // default: "/api/auth/nginx" + AuthResponseHeaders []string `json:"headers"` // additional headers to forward from auth server to upstream, e.g. ["Remote-User", "Remote-Name"] + + httpClient *http.Client } ) @@ -28,81 +30,103 @@ var ForwardAuth = NewMiddleware[forwardAuthMiddleware]() func (m *forwardAuthMiddleware) setup() { m.ForwardAuthMiddlewareOpts = ForwardAuthMiddlewareOpts{ - ForwardAuthRoute: "tinyauth", - ForwardAuthLogin: "/login?redirect_uri=", - ForwardAuthEndpoint: "/api/auth/nginx", - ForwardAuthHeaders: []string{"Remote-User", "Remote-Name", "Remote-Email", "Remote-Groups"}, + Route: "tinyauth", + AuthEndpoint: "/api/auth/traefik", + AuthResponseHeaders: []string{"Remote-User", "Remote-Name", "Remote-Email", "Remote-Groups"}, + httpClient: &http.Client{ + Timeout: 5 * time.Second, + // do not follow redirects, we handle them in the middleware + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, } } // before implements RequestModifier. func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - route, ok := routes.HTTP.Get(m.ForwardAuthRoute) + route, ok := routes.HTTP.Get(m.Route) if !ok { - log.Warn().Str("route", m.ForwardAuthRoute).Msg("forwardauth route not found") + ForwardAuth.LogWarn(r).Str("route", m.Route).Msg("forwardauth route not found") w.WriteHeader(http.StatusInternalServerError) return false } - forwardAuthUrl := *route.TargetURL() - forwardAuthUrl.Path = m.ForwardAuthEndpoint + forwardAuthURL := *route.TargetURL() + forwardAuthURL.Path = m.AuthEndpoint - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, forwardAuthUrl.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, forwardAuthURL.String(), nil) if err != nil { - log.Err(err).Msg("failed to create request") + ForwardAuth.LogError(r).Err(err).Msg("failed to create request") + w.WriteHeader(http.StatusInternalServerError) return false } + xff, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + xff = r.RemoteAddr + } + + proto := "http" + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + proto = "https" + } + req.Header = r.Header.Clone() - req.Header.Set("X-Forwarded-Proto", r.Proto) + req.Header.Set("X-Forwarded-For", xff) + req.Header.Set("X-Forwarded-Proto", proto) req.Header.Set("X-Forwarded-Host", r.Host) req.Header.Set("X-Forwarded-Uri", r.URL.RequestURI()) - resp, err := http.DefaultClient.Do(req) //nolint:gosec + resp, err := m.httpClient.Do(req) if err != nil { - log.Err(err).Msg("failed to connect to forwardauth server") + ForwardAuth.LogError(r).Err(err).Msg("failed to connect to forwardauth server") + w.WriteHeader(http.StatusInternalServerError) return false } defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { - for _, h := range m.ForwardAuthHeaders { - if v := resp.Header.Get(h); v != "" { - w.Header().Set(h, v) - } - } - return true - } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + body, release, err := utils.ReadAllBody(resp) + defer release() - if resp.StatusCode == http.StatusUnauthorized { - host, _, err := net.SplitHostPort(r.Host) if err != nil { - host = r.Host + ForwardAuth.LogError(r).Err(err).Msg("failed to read response body") + w.WriteHeader(http.StatusInternalServerError) + return false } - scheme := "http://" - if r.TLS != nil { - scheme = "https://" - } + httpheaders.CopyHeader(w.Header(), resp.Header) + httpheaders.RemoveHopByHopHeaders(w.Header()) - parts := strings.Split(host, ".") - if len(parts) > 2 { - host = strings.Join(parts[len(parts)-2:], ".") + loc, err := resp.Location() + if err != nil { + if !errors.Is(err, http.ErrNoLocation) { + ForwardAuth.LogError(r).Err(err).Msg("failed to get location") + w.WriteHeader(http.StatusInternalServerError) + return false + } + } else if loc := loc.String(); loc != "" { + r.Header.Set("Location", loc) } + w.WriteHeader(resp.StatusCode) - redirectUrl := scheme + m.ForwardAuthRoute + "." + host + m.ForwardAuthLogin + scheme + r.Host + r.URL.RequestURI() - http.Redirect(w, r, redirectUrl, http.StatusPermanentRedirect) + _, err = w.Write(body) + if err != nil { + ForwardAuth.LogError(r).Err(err).Msg("failed to write response body") + } return false } - if resp.StatusCode == http.StatusForbidden { - http.Error(w, "Forbidden", http.StatusForbidden) - return false + for _, h := range m.AuthResponseHeaders { + if v := resp.Header.Get(h); v != "" { + // NOTE: need to set the header to the original request to forward to upstream + r.Header.Set(h, v) + } } - - return false + return true }