small update on reverse proxy and xforwarded middlewares

This commit is contained in:
yusing
2024-12-01 05:04:57 +08:00
parent a4f44348ef
commit 863bb3f474
7 changed files with 76 additions and 67 deletions

View File

@@ -53,6 +53,13 @@ func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
// scheme := "http"
// if req.TLS != nil {
// scheme = "https"
// }
// originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
req.Context(),
@@ -71,6 +78,8 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
@@ -100,7 +109,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String()).WithResponse(faResp)
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
@@ -160,54 +169,54 @@ func (fa *forwardAuth) setAuthCookies(resp *Response, authCookies []*Cookie) {
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[xForwardedFor]; ok {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
faReq.Header.Set(xForwardedFor, clientIP)
faReq.Header.Set(gphttp.HeaderXForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
xMethod := req.Header.Get(gphttp.HeaderXForwardedMethod)
switch {
case xMethod != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedMethod, xMethod)
faReq.Header.Set(gphttp.HeaderXForwardedMethod, xMethod)
case req.Method != "":
faReq.Header.Set(xForwardedMethod, req.Method)
faReq.Header.Set(gphttp.HeaderXForwardedMethod, req.Method)
default:
faReq.Header.Del(xForwardedMethod)
faReq.Header.Del(gphttp.HeaderXForwardedMethod)
}
xfp := req.Header.Get(xForwardedProto)
xfp := req.Header.Get(gphttp.HeaderXForwardedProto)
switch {
case xfp != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedProto, xfp)
faReq.Header.Set(gphttp.HeaderXForwardedProto, xfp)
case req.TLS != nil:
faReq.Header.Set(xForwardedProto, "https")
faReq.Header.Set(gphttp.HeaderXForwardedProto, "https")
default:
faReq.Header.Set(xForwardedProto, "http")
faReq.Header.Set(gphttp.HeaderXForwardedProto, "http")
}
if xfp := req.Header.Get(xForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(xForwardedPort, xfp)
if xfp := req.Header.Get(gphttp.HeaderXForwardedPort); xfp != "" && fa.TrustForwardHeader {
faReq.Header.Set(gphttp.HeaderXForwardedPort, xfp)
}
xfh := req.Header.Get(xForwardedHost)
xfh := req.Header.Get(gphttp.HeaderXForwardedHost)
switch {
case xfh != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedHost, xfh)
faReq.Header.Set(gphttp.HeaderXForwardedHost, xfh)
case req.Host != "":
faReq.Header.Set(xForwardedHost, req.Host)
faReq.Header.Set(gphttp.HeaderXForwardedHost, req.Host)
default:
faReq.Header.Del(xForwardedHost)
faReq.Header.Del(gphttp.HeaderXForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
xfURI := req.Header.Get(gphttp.HeaderXForwardedURI)
switch {
case xfURI != "" && fa.TrustForwardHeader:
faReq.Header.Set(xForwardedURI, xfURI)
faReq.Header.Set(gphttp.HeaderXForwardedURI, xfURI)
case req.URL.RequestURI() != "":
faReq.Header.Set(xForwardedURI, req.URL.RequestURI())
faReq.Header.Set(gphttp.HeaderXForwardedURI, req.URL.RequestURI())
default:
faReq.Header.Del(xForwardedURI)
faReq.Header.Del(gphttp.HeaderXForwardedURI)
}
}