get back aa6fafd5, accidentally reverted in 03cad9f3

This commit is contained in:
yusing
2024-10-07 00:06:29 +08:00
parent de7805f281
commit 929b7f7059
3 changed files with 55 additions and 31 deletions

View File

@@ -54,7 +54,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
}
// TODO: use tr from reverse proxy
tr, ok := fa.forwardAuthOpts.transport.(*http.Transport)
tr, ok := fa.transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
@@ -81,7 +81,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -89,12 +89,13 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
gpHTTP.CopyHeader(faReq.Header, req.Header)
gpHTTP.RemoveHop(faReq.Header)
gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
faReq.Header = gpHTTP.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := fa.client.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -102,30 +103,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTracef("status %d", faResp.StatusCode)
fa.m.AddTraceResponse("forward auth response", faResp)
gpHTTP.CopyHeader(w.Header(), faResp.Header)
gpHTTP.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String())
fa.m.AddTracef("redirect to %q", redirectURL.String()).WithResponse(faResp)
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp)
}
return
}