mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-24 01:08:31 +02:00
refactor(forwardauth): finalize middleware implementation with better headers handling
This commit is contained in:
@@ -2,13 +2,14 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/yusing/go-proxy/internal/net/gphttp/httpheaders"
|
||||||
"github.com/yusing/go-proxy/internal/route/routes"
|
"github.com/yusing/go-proxy/internal/route/routes"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@@ -17,10 +18,11 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
ForwardAuthMiddlewareOpts struct {
|
ForwardAuthMiddlewareOpts struct {
|
||||||
ForwardAuthRoute string `json:"forwardauth_route"` // default: "tinyauth"
|
Route string `json:"route" validate:"required"` // route name (alias), default: "tinyauth"
|
||||||
ForwardAuthLogin string `json:"forwardauth_login"` // the redirect login path, e.g. "/login?redirect_uri="
|
AuthEndpoint string `json:"auth_endpoint" validate:"required,uri"` // default: "/api/auth/nginx"
|
||||||
ForwardAuthEndpoint string `json:"forwardauth_endpoint"` // default: "/api/auth/nginx"
|
AuthResponseHeaders []string `json:"headers"` // additional headers to forward from auth server to upstream, e.g. ["Remote-User", "Remote-Name"]
|
||||||
ForwardAuthHeaders []string `json:"forwardauth_headers"` // additional headers to forward from auth server to upstream, e.g. ["Remote-User", "Remote-Name"]
|
|
||||||
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,81 +30,103 @@ var ForwardAuth = NewMiddleware[forwardAuthMiddleware]()
|
|||||||
|
|
||||||
func (m *forwardAuthMiddleware) setup() {
|
func (m *forwardAuthMiddleware) setup() {
|
||||||
m.ForwardAuthMiddlewareOpts = ForwardAuthMiddlewareOpts{
|
m.ForwardAuthMiddlewareOpts = ForwardAuthMiddlewareOpts{
|
||||||
ForwardAuthRoute: "tinyauth",
|
Route: "tinyauth",
|
||||||
ForwardAuthLogin: "/login?redirect_uri=",
|
AuthEndpoint: "/api/auth/traefik",
|
||||||
ForwardAuthEndpoint: "/api/auth/nginx",
|
AuthResponseHeaders: []string{"Remote-User", "Remote-Name", "Remote-Email", "Remote-Groups"},
|
||||||
ForwardAuthHeaders: []string{"Remote-User", "Remote-Name", "Remote-Email", "Remote-Groups"},
|
httpClient: &http.Client{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
// do not follow redirects, we handle them in the middleware
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// before implements RequestModifier.
|
// before implements RequestModifier.
|
||||||
func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
func (m *forwardAuthMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
route, ok := routes.HTTP.Get(m.ForwardAuthRoute)
|
route, ok := routes.HTTP.Get(m.Route)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Warn().Str("route", m.ForwardAuthRoute).Msg("forwardauth route not found")
|
ForwardAuth.LogWarn(r).Str("route", m.Route).Msg("forwardauth route not found")
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardAuthUrl := *route.TargetURL()
|
forwardAuthURL := *route.TargetURL()
|
||||||
forwardAuthUrl.Path = m.ForwardAuthEndpoint
|
forwardAuthURL.Path = m.AuthEndpoint
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, forwardAuthUrl.String(), nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, forwardAuthURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msg("failed to create request")
|
ForwardAuth.LogError(r).Err(err).Msg("failed to create request")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xff, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
xff = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := "http"
|
||||||
|
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||||||
|
proto = "https"
|
||||||
|
}
|
||||||
|
|
||||||
req.Header = r.Header.Clone()
|
req.Header = r.Header.Clone()
|
||||||
req.Header.Set("X-Forwarded-Proto", r.Proto)
|
req.Header.Set("X-Forwarded-For", xff)
|
||||||
|
req.Header.Set("X-Forwarded-Proto", proto)
|
||||||
req.Header.Set("X-Forwarded-Host", r.Host)
|
req.Header.Set("X-Forwarded-Host", r.Host)
|
||||||
req.Header.Set("X-Forwarded-Uri", r.URL.RequestURI())
|
req.Header.Set("X-Forwarded-Uri", r.URL.RequestURI())
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req) //nolint:gosec
|
resp, err := m.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).Msg("failed to connect to forwardauth server")
|
ForwardAuth.LogError(r).Err(err).Msg("failed to connect to forwardauth server")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
for _, h := range m.ForwardAuthHeaders {
|
body, release, err := utils.ReadAllBody(resp)
|
||||||
if v := resp.Header.Get(h); v != "" {
|
defer release()
|
||||||
w.Header().Set(h, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusUnauthorized {
|
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
host = r.Host
|
ForwardAuth.LogError(r).Err(err).Msg("failed to read response body")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
scheme := "http://"
|
httpheaders.CopyHeader(w.Header(), resp.Header)
|
||||||
if r.TLS != nil {
|
httpheaders.RemoveHopByHopHeaders(w.Header())
|
||||||
scheme = "https://"
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(host, ".")
|
loc, err := resp.Location()
|
||||||
if len(parts) > 2 {
|
if err != nil {
|
||||||
host = strings.Join(parts[len(parts)-2:], ".")
|
if !errors.Is(err, http.ErrNoLocation) {
|
||||||
|
ForwardAuth.LogError(r).Err(err).Msg("failed to get location")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} else if loc := loc.String(); loc != "" {
|
||||||
|
r.Header.Set("Location", loc)
|
||||||
}
|
}
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
redirectUrl := scheme + m.ForwardAuthRoute + "." + host + m.ForwardAuthLogin + scheme + r.Host + r.URL.RequestURI()
|
_, err = w.Write(body)
|
||||||
http.Redirect(w, r, redirectUrl, http.StatusPermanentRedirect)
|
if err != nil {
|
||||||
|
ForwardAuth.LogError(r).Err(err).Msg("failed to write response body")
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusForbidden {
|
for _, h := range m.AuthResponseHeaders {
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
if v := resp.Header.Get(h); v != "" {
|
||||||
return false
|
// NOTE: need to set the header to the original request to forward to upstream
|
||||||
|
r.Header.Set(h, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user