updated validation for middleware options

This commit is contained in:
yusing
2024-11-30 04:00:55 +08:00
parent edc1ad952d
commit 6e9b5cc113
9 changed files with 97 additions and 94 deletions

View File

@@ -8,7 +8,6 @@ import (
"io"
"net"
"net/http"
"net/url"
"slices"
"strings"
"time"
@@ -20,64 +19,49 @@ import (
type (
forwardAuth struct {
forwardAuthOpts
m *Middleware
client http.Client
m *Middleware
}
forwardAuthOpts struct {
Address string `json:"address"`
TrustForwardHeader bool `json:"trustForwardHeader"`
AuthResponseHeaders []string `json:"authResponseHeaders"`
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse"`
transport http.RoundTripper
Address string `validate:"url,required"`
TrustForwardHeader bool
AuthResponseHeaders []string
AddAuthCookiesToResponse []string
}
)
var ForwardAuth = &Middleware{withOptions: NewForwardAuthfunc}
var ForwardAuth = &Middleware{withOptions: NewForwardAuth}
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.Error) {
var faHTTPClient = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
}
func NewForwardAuth(optsRaw OptionsRaw) (*Middleware, E.Error) {
fa := new(forwardAuth)
if err := Deserialize(optsRaw, &fa.forwardAuthOpts); err != nil {
return nil, err
}
if _, err := url.Parse(fa.Address); err != nil {
return nil, E.From(err)
}
fa.m = &Middleware{
impl: fa,
before: fa.forward,
}
// TODO: use tr from reverse proxy
tr, ok := fa.transport.(*http.Transport)
if ok {
tr = tr.Clone()
} else {
tr = gphttp.DefaultTransport.Clone()
}
fa.client = http.Client{
CheckRedirect: func(r *Request, via []*Request) error {
return http.ErrUseLastResponse
},
Timeout: 30 * time.Second,
Transport: tr,
}
return fa.m, nil
}
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gphttp.RemoveHop(req.Header)
url := fa.Address
faReq, err := http.NewRequestWithContext(
req.Context(),
http.MethodGet,
fa.Address,
url,
nil,
)
if err != nil {
fa.m.AddTracef("new request err to %s", fa.Address).WithError(err)
fa.m.AddTracef("new request err to %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -89,9 +73,9 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
fa.setAuthHeaders(req, faReq)
fa.m.AddTraceRequest("forward auth request", faReq)
faResp, err := fa.client.Do(faReq)
faResp, err := faHTTPClient.Do(faReq)
if err != nil {
fa.m.AddTracef("failed to call %s", fa.Address).WithError(err)
fa.m.AddTracef("failed to call %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -99,7 +83,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
body, err := io.ReadAll(faResp.Body)
if err != nil {
fa.m.AddTracef("failed to read response body from %s", fa.Address).WithError(err)
fa.m.AddTracef("failed to read response body from %s", url).WithError(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -111,7 +95,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
redirectURL, err := faResp.Location()
if err != nil {
fa.m.AddTracef("failed to get location from %s", fa.Address).WithError(err).WithResponse(faResp)
fa.m.AddTracef("failed to get location from %s", url).WithError(err).WithResponse(faResp)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
@@ -122,7 +106,7 @@ func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Req
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
fa.m.AddTracef("failed to write response body from %s", fa.Address).WithError(err).WithResponse(faResp)
fa.m.AddTracef("failed to write response body from %s", url).WithError(err).WithResponse(faResp)
}
return
}