cleanup and simplify middleware implementations, refactor some other code

This commit is contained in:
yusing
2024-12-16 10:19:14 +08:00
parent 8a9cb2527e
commit 59f4eaf3ea
34 changed files with 641 additions and 720 deletions

View File

@@ -12,16 +12,17 @@ import (
"strings"
"time"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type (
forwardAuth struct {
forwardAuthOpts
m *Middleware
ForwardAuthOpts
*Tracer
reqCookiesMap F.Map[*http.Request, []*http.Cookie]
}
forwardAuthOpts struct {
ForwardAuthOpts struct {
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
@@ -29,36 +30,30 @@ type (
}
)
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
var ForwardAuth = NewMiddleware[forwardAuth]()
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error {
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err
}
fa.m = &Middleware{
impl: fa,
before: fa.forward,
}
return fa.m, nil
// setup implements MiddlewareWithSetup.
func (fa *forwardAuth) setup() {
fa.reqCookiesMap = F.NewMapOf[*http.Request, []*http.Cookie]()
}
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
// before implements RequestModifier.
func (fa *forwardAuth) before(w http.ResponseWriter, req *http.Request) (proceed bool) {
gphttp.RemoveHop(req.Header)
// Construct original URL for the redirect
// scheme := "http"
// if req.TLS != nil {
// scheme = "https"
// }
// originalURL := scheme + "://" + req.Host + req.RequestURI
scheme := "http"
if req.TLS != nil {
scheme = "https"
}
originalURL := scheme + "://" + req.Host + req.RequestURI
url := fa.Address
faReq, err := http.NewRequestWithContext(
@@ -68,7 +63,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", url).WithError(err)
fa.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -79,12 +74,12 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
faReq.Header = gphttp.FilterHeaders(faReq.Header, fa.AuthResponseHeaders)
fa.setAuthHeaders(req, faReq)
// Set headers needed by Authentik
// faReq.Header.Set("X-Original-URL", originalURL)
fa.m.AddTraceRequest("forward auth request", faReq)
faReq.Header.Set("X-Original-Url", originalURL)
fa.AddTraceRequest("forward auth request", faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", url).WithError(err)
fa.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -92,30 +87,30 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
fa.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if faResp.StatusCode < http.StatusOK || faResp.StatusCode >= http.StatusMultipleChoices {
fa.m.AddTraceResponse("forward auth response", faResp)
fa.AddTraceResponse("forward auth response", faResp)
gphttp.CopyHeader(w.Header(), faResp.Header)
gphttp.RemoveHop(w.Header())
redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
fa.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
w.Header().Set("Location", redirectURL.String())
fa.m.AddTracef("%s", "redirect to "+redirectURL.String())
fa.AddTracef("%s", "redirect to "+redirectURL.String())
}
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
fa.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}
@@ -132,18 +127,22 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
authCookies := faResp.Cookies()
if len(authCookies) == 0 {
next.ServeHTTP(w, req)
return
if len(authCookies) > 0 {
fa.reqCookiesMap.Store(req, authCookies)
}
next.ServeHTTP(gphttp.NewModifyResponseWriter(w, req, func(resp *http.Response) error {
fa.setAuthCookies(resp, authCookies)
return nil
}), req)
return true
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie) {
// modifyResponse implements ResponseModifier.
func (fa *forwardAuth) modifyResponse(resp *http.Response) error {
if cookies, ok := fa.reqCookiesMap.Load(resp.Request); ok {
fa.setAuthCookies(resp, cookies)
fa.reqCookiesMap.Delete(resp.Request)
}
return nil
}
func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*http.Cookie) {
if len(fa.AddAuthCookiesToResponse) == 0 {
return
}
@@ -166,7 +165,7 @@ func (fa *forwardAuth) setAuthCookies(resp *http.Response, authCookies []*Cookie
}
}
func (fa *forwardAuth) setAuthHeaders(req, faReq *Request) {
func (fa *forwardAuth) setAuthHeaders(req, faReq *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if fa.TrustForwardHeader {
if prior, ok := req.Header[gphttp.HeaderXForwardedFor]; ok {