mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 16:58:31 +02:00
cleanup and simplify middleware implementations, refactor some other code
This commit is contained in:
@@ -12,16 +12,17 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type (
|
||||
forwardAuth struct {
|
||||
forwardAuthOpts
|
||||
m *Middleware
|
||||
ForwardAuthOpts
|
||||
*Tracer
|
||||
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
|
||||
}
|
||||
forwardAuthOpts struct {
|
||||
ForwardAuthOpts struct {
|
||||
Address string `validate:"url,required"`
|
||||
TrustForwardHeader bool
|
||||
AuthResponseHeaders []string
|
||||
@@ -29,36 +30,30 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
|
||||
var ForwardAuth = NewMiddleware[forwardAuth]()
|
||||
|
||||
var faHTTPClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: func(r *Request, via []*Request) error {
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||
fa := new(forwardAuth)
|
||||
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fa.m = &Middleware{
|
||||
impl: fa,
|
||||
before: fa.forward,
|
||||
}
|
||||
return fa.m, nil
|
||||
// setup implements MiddlewareWithSetup.
|
||||
func (fa *forwardAuth) setup() {
|
||||
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
// before implements RequestModifier.
|
||||
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
|
||||
gphttp.RemoveHop(req.Header)
|
||||
|
||||
// Construct original URL for the redirect
|
||||
// scheme := "http"
|
||||
// if req.TLS != nil {
|
||||
// scheme = "https"
|
||||
// }
|
||||
// originalURL := scheme + "://" + req.Host + req.RequestURI
|
||||
scheme := "http"
|
||||
if req.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
originalURL := scheme + "://" + req.Host + req.RequestURI
|
||||
|
||||
url := fa.Address
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
@@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("new request err to %s", url).WithError(err)
|
||||
fa.AddTracef("new request err to %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
|
||||
fa.setAuthHeaders(req, faReq)
|
||||
// Set headers needed by Authentik
|
||||
// faReq.Header.Set("X-Original-URL", originalURL)
|
||||
fa.m.AddTraceRequest("forward auth request", faReq)
|
||||
faReq.Header.Set("X-Original-Url", originalURL)
|
||||
fa.AddTraceRequest("forward auth request", faReq)
|
||||
|
||||
faResp, err := faHTTPClient.Do(faReq)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to call %s", url).WithError(err)
|
||||
fa.AddTracef("failed to call %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||
fa.AddTracef("failed to read response body from %s", url).WithError(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
|
||||
fa.m.AddTraceResponse("forward auth response", faResp)
|
||||
fa.AddTraceResponse("forward auth response", faResp)
|
||||
gphttp.CopyHeader(w.Header(), faResp.Header)
|
||||
gphttp.RemoveHop(w.Header())
|
||||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
w.Header().Set("Location", redirectURL.String())
|
||||
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
|
||||
fa.AddTracef("%s", "redirect to "+redirectURL.String())
|
||||
}
|
||||
|
||||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
|
||||
|
||||
authCookies := faResp.Cookies()
|
||||
|
||||
if len(authCookies) == 0 {
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
if len(authCookies) > 0 {
|
||||
fa.reqCookiesMap.Store(req, authCookies)
|
||||
}
|
||||
|
||||
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
|
||||
fa.setAuthCookies(resp, authCookies)
|
||||
return nil
|
||||
}), req)
|
||||
return true
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
|
||||
// modifyResponse implements ResponseModifier.
|
||||
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
|
||||
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
|
||||
fa.setAuthCookies(resp, cookies)
|
||||
fa.reqCookiesMap.Delete(resp.Request)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
|
||||
if len(fa.AddAuthCookiesToResponse) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
|
||||
}
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
|
||||
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
if fa.TrustForwardHeader {
|
||||
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {
|
||||
|
||||
Reference in New Issue
Block a user