refactor(forwardauth): finalize middleware implementation with better headers handling

This commit is contained in:
yusing
2025-09-02 22:58:13 +08:00
parent c550255458
commit a2d4c468cd

View File

@@ -2,13 +2,14 @@ package middleware
import ( import (
"context" "context"
"errors"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/rs/zerolog/log" "github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
"github.com/yusing/go-proxy/internal/route/routes" "github.com/yusing/go-proxy/internal/route/routes"
"github.com/yusing/go-proxy/internal/utils"
) )
type ( type (
@@ -17,10 +18,11 @@ type (
} }
ForwardAuthMiddlewareOpts struct { ForwardAuthMiddlewareOpts struct {
ForwardAuthRoute string `json:"forwardauth_route"` // default: "tinyauth" Route string `json:"route" validate:"required"` // route name (alias), default: "tinyauth"
ForwardAuthLogin string `json:"forwardauth_login"` // the redirect login path, e.g. "/login?redirect_uri=" AuthEndpoint string `json:"auth_endpoint" validate:"required,uri"` // default: "/api/auth/nginx"
ForwardAuthEndpoint string `json:"forwardauth_endpoint"` // default: "/api/auth/nginx" AuthResponseHeaders []string `json:"headers"` // additional headers to forward from auth server to upstream, e.g. ["Remote-User", "Remote-Name"]
ForwardAuthHeaders []string `json:"forwardauth_headers"` // additional headers to forward from auth server to upstream, e.g. ["Remote-User", "Remote-Name"]
httpClient *http.Client
} }
) )
@@ -28,81 +30,103 @@ var ForwardAuth = NewMiddleware[forwardAuthMiddleware]()
func (m *forwardAuthMiddleware) setup() { func (m *forwardAuthMiddleware) setup() {
m.ForwardAuthMiddlewareOpts = ForwardAuthMiddlewareOpts{ m.ForwardAuthMiddlewareOpts = ForwardAuthMiddlewareOpts{
ForwardAuthRoute: "tinyauth", Route: "tinyauth",
ForwardAuthLogin: "/login?redirect_uri=", AuthEndpoint: "/api/auth/traefik",
ForwardAuthEndpoint: "/api/auth/nginx", AuthResponseHeaders: []string{"Remote-User", "Remote-Name", "Remote-Email", "Remote-Groups"},
ForwardAuthHeaders: []string{"Remote-User", "Remote-Name", "Remote-Email", "Remote-Groups"}, httpClient: &http.Client{
Timeout: 5 * time.Second,
// do not follow redirects, we handle them in the middleware
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
},
} }
} }
// before implements RequestModifier. // before implements RequestModifier.
func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
route, ok := routes.HTTP.Get(m.ForwardAuthRoute) route, ok := routes.HTTP.Get(m.Route)
if !ok { if !ok {
log.Warn().Str("route", m.ForwardAuthRoute).Msg("forwardauth route not found") ForwardAuth.LogWarn(r).Str("route", m.Route).Msg("forwardauth route not found")
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return false return false
} }
forwardAuthUrl := *route.TargetURL() forwardAuthURL := *route.TargetURL()
forwardAuthUrl.Path = m.ForwardAuthEndpoint forwardAuthURL.Path = m.AuthEndpoint
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, forwardAuthUrl.String(), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, forwardAuthURL.String(), nil)
if err != nil { if err != nil {
log.Err(err).Msg("failed to create request") ForwardAuth.LogError(r).Err(err).Msg("failed to create request")
w.WriteHeader(http.StatusInternalServerError)
return false return false
} }
xff, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
xff = r.RemoteAddr
}
proto := "http"
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
proto = "https"
}
req.Header = r.Header.Clone() req.Header = r.Header.Clone()
req.Header.Set("X-Forwarded-Proto", r.Proto) req.Header.Set("X-Forwarded-For", xff)
req.Header.Set("X-Forwarded-Proto", proto)
req.Header.Set("X-Forwarded-Host", r.Host) req.Header.Set("X-Forwarded-Host", r.Host)
req.Header.Set("X-Forwarded-Uri", r.URL.RequestURI()) req.Header.Set("X-Forwarded-Uri", r.URL.RequestURI())
resp, err := http.DefaultClient.Do(req) //nolint:gosec resp, err := m.httpClient.Do(req)
if err != nil { if err != nil {
log.Err(err).Msg("failed to connect to forwardauth server") ForwardAuth.LogError(r).Err(err).Msg("failed to connect to forwardauth server")
w.WriteHeader(http.StatusInternalServerError)
return false return false
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == http.StatusOK { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
for _, h := range m.ForwardAuthHeaders { body, release, err := utils.ReadAllBody(resp)
if v := resp.Header.Get(h); v != "" { defer release()
w.Header().Set(h, v)
}
}
return true
}
if resp.StatusCode == http.StatusUnauthorized {
host, _, err := net.SplitHostPort(r.Host)
if err != nil { if err != nil {
host = r.Host ForwardAuth.LogError(r).Err(err).Msg("failed to read response body")
w.WriteHeader(http.StatusInternalServerError)
return false
} }
scheme := "http://" httpheaders.CopyHeader(w.Header(), resp.Header)
if r.TLS != nil { httpheaders.RemoveHopByHopHeaders(w.Header())
scheme = "https://"
}
parts := strings.Split(host, ".") loc, err := resp.Location()
if len(parts) > 2 { if err != nil {
host = strings.Join(parts[len(parts)-2:], ".") if !errors.Is(err, http.ErrNoLocation) {
ForwardAuth.LogError(r).Err(err).Msg("failed to get location")
w.WriteHeader(http.StatusInternalServerError)
return false
}
} else if loc := loc.String(); loc != "" {
r.Header.Set("Location", loc)
} }
w.WriteHeader(resp.StatusCode)
redirectUrl := scheme + m.ForwardAuthRoute + "." + host + m.ForwardAuthLogin + scheme + r.Host + r.URL.RequestURI() _, err = w.Write(body)
http.Redirect(w, r, redirectUrl, http.StatusPermanentRedirect) if err != nil {
ForwardAuth.LogError(r).Err(err).Msg("failed to write response body")
}
return false return false
} }
if resp.StatusCode == http.StatusForbidden { for _, h := range m.AuthResponseHeaders {
http.Error(w, "Forbidden", http.StatusForbidden) if v := resp.Header.Get(h); v != "" {
return false // NOTE: need to set the header to the original request to forward to upstream
r.Header.Set(h, v)
}
} }
return true
return false
} }