fix(middlewares): correctly bypass middlewares with response rules

This commit is contained in:
yusing
2025-10-28 20:44:46 +08:00
parent 098fb7e62d
commit 1797a222cd
3 changed files with 103 additions and 2 deletions

View File

@@ -30,8 +30,8 @@ func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNex
return c.modReq.before(w, r)
}
func (c *checkBypass) modifyResponse(w http.ResponseWriter, resp *http.Response) error {
if c.modRes == nil || c.bypass.ShouldBypass(w, resp.Request) {
func (c *checkBypass) modifyResponse(resp *http.Response) error {
if c.modRes == nil || c.bypass.ShouldBypass(rules.ResponseAsRW(resp), resp.Request) {
return nil
}
return c.modRes.modifyResponse(resp)

View File

@@ -138,6 +138,82 @@ func TestReverseProxyBypass(t *testing.T) {
}
}
func TestBypassResponse(t *testing.T) {
t.Run("req_rules", func(t *testing.T) {
mr, err := ModifyResponse.New(map[string]any{
"bypass": []string{"path glob(/test/*) | path /api"},
"set_headers": map[string]string{
"Test-Header": "test-value",
},
})
expect.NoError(t, err)
tests := []struct {
name string
path string
expectBypass bool
}{
{"bypass", "/test/123", true},
{"bypass2", "/test/123/456", true},
{"bypass3", "/api", true},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com"+test.path, nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("test")),
Request: req,
Header: make(http.Header),
}
mErr := mr.ModifyResponse(resp)
expect.NoError(t, mErr)
if test.expectBypass {
expect.Equal(t, resp.Header.Get("Test-Header"), "")
} else {
expect.Equal(t, resp.Header.Get("Test-Header"), "test-value")
}
})
}
})
t.Run("res_rules", func(t *testing.T) {
mr, err := ModifyResponse.New(map[string]any{
"bypass": []string{"status 200"},
"set_headers": map[string]string{
"Test-Header": "test-value",
},
})
expect.NoError(t, err)
tests := []struct {
name string
statusCode int
expectBypass bool
}{
{"bypass", 200, true},
{"no_bypass", 201, false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resp := &http.Response{
StatusCode: test.statusCode,
Body: io.NopCloser(strings.NewReader("test")),
Header: make(http.Header),
}
mErr := mr.ModifyResponse(resp)
expect.NoError(t, mErr)
if test.expectBypass {
expect.Equal(t, resp.Header.Get("Test-Header"), "")
} else {
expect.Equal(t, resp.Header.Get("Test-Header"), "test-value")
}
})
}
})
}
func TestEntrypointBypassRoute(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("test"))