mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 16:58:31 +02:00
fix(middlewares): correctly bypass middlewares with response rules
This commit is contained in:
@@ -30,8 +30,8 @@ func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNex
|
|||||||
return c.modReq.before(w, r)
|
return c.modReq.before(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *checkBypass) modifyResponse(w http.ResponseWriter, resp *http.Response) error {
|
func (c *checkBypass) modifyResponse(resp *http.Response) error {
|
||||||
if c.modRes == nil || c.bypass.ShouldBypass(w, resp.Request) {
|
if c.modRes == nil || c.bypass.ShouldBypass(rules.ResponseAsRW(resp), resp.Request) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.modRes.modifyResponse(resp)
|
return c.modRes.modifyResponse(resp)
|
||||||
|
|||||||
@@ -138,6 +138,82 @@ func TestReverseProxyBypass(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBypassResponse(t *testing.T) {
|
||||||
|
t.Run("req_rules", func(t *testing.T) {
|
||||||
|
mr, err := ModifyResponse.New(map[string]any{
|
||||||
|
"bypass": []string{"path glob(/test/*) | path /api"},
|
||||||
|
"set_headers": map[string]string{
|
||||||
|
"Test-Header": "test-value",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
expect.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
expectBypass bool
|
||||||
|
}{
|
||||||
|
{"bypass", "/test/123", true},
|
||||||
|
{"bypass2", "/test/123/456", true},
|
||||||
|
{"bypass3", "/api", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("test")),
|
||||||
|
Request: req,
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
mErr := mr.ModifyResponse(resp)
|
||||||
|
expect.NoError(t, mErr)
|
||||||
|
if test.expectBypass {
|
||||||
|
expect.Equal(t, resp.Header.Get("Test-Header"), "")
|
||||||
|
} else {
|
||||||
|
expect.Equal(t, resp.Header.Get("Test-Header"), "test-value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("res_rules", func(t *testing.T) {
|
||||||
|
mr, err := ModifyResponse.New(map[string]any{
|
||||||
|
"bypass": []string{"status 200"},
|
||||||
|
"set_headers": map[string]string{
|
||||||
|
"Test-Header": "test-value",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
expect.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
statusCode int
|
||||||
|
expectBypass bool
|
||||||
|
}{
|
||||||
|
{"bypass", 200, true},
|
||||||
|
{"no_bypass", 201, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: test.statusCode,
|
||||||
|
Body: io.NopCloser(strings.NewReader("test")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
mErr := mr.ModifyResponse(resp)
|
||||||
|
expect.NoError(t, mErr)
|
||||||
|
if test.expectBypass {
|
||||||
|
expect.Equal(t, resp.Header.Get("Test-Header"), "")
|
||||||
|
} else {
|
||||||
|
expect.Equal(t, resp.Header.Get("Test-Header"), "test-value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestEntrypointBypassRoute(t *testing.T) {
|
func TestEntrypointBypassRoute(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("test"))
|
w.Write([]byte("test"))
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
"github.com/yusing/goutils/synk"
|
"github.com/yusing/goutils/synk"
|
||||||
)
|
)
|
||||||
@@ -43,6 +45,29 @@ func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type responseAsRW struct {
|
||||||
|
resp *http.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r responseAsRW) WriteHeader(code int) {
|
||||||
|
log.Error().Msg("write header after response has been created")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r responseAsRW) Write(b []byte) (int, error) {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r responseAsRW) Header() http.Header {
|
||||||
|
return r.resp.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResponseAsRW(resp *http.Response) *ResponseModifier {
|
||||||
|
return &ResponseModifier{
|
||||||
|
statusCode: resp.StatusCode,
|
||||||
|
w: responseAsRW{resp},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetInitResponseModifier returns the response modifier for the given response writer.
|
// GetInitResponseModifier returns the response modifier for the given response writer.
|
||||||
// If the response writer is already wrapped, it will return the wrapped response modifier.
|
// If the response writer is already wrapped, it will return the wrapped response modifier.
|
||||||
// Otherwise, it will return a new response modifier.
|
// Otherwise, it will return a new response modifier.
|
||||||
|
|||||||
Reference in New Issue
Block a user