diff --git a/internal/net/gphttp/middleware/bypass.go b/internal/net/gphttp/middleware/bypass.go index 04a88ce3..b8baf776 100644 --- a/internal/net/gphttp/middleware/bypass.go +++ b/internal/net/gphttp/middleware/bypass.go @@ -30,8 +30,8 @@ func (c *checkBypass) before(w http.ResponseWriter, r *http.Request) (proceedNex return c.modReq.before(w, r) } -func (c *checkBypass) modifyResponse(w http.ResponseWriter, resp *http.Response) error { - if c.modRes == nil || c.bypass.ShouldBypass(w, resp.Request) { +func (c *checkBypass) modifyResponse(resp *http.Response) error { + if c.modRes == nil || c.bypass.ShouldBypass(rules.ResponseAsRW(resp), resp.Request) { return nil } return c.modRes.modifyResponse(resp) diff --git a/internal/net/gphttp/middleware/bypass_test.go b/internal/net/gphttp/middleware/bypass_test.go index 74aec5ba..309a3554 100644 --- a/internal/net/gphttp/middleware/bypass_test.go +++ b/internal/net/gphttp/middleware/bypass_test.go @@ -138,6 +138,82 @@ func TestReverseProxyBypass(t *testing.T) { } } +func TestBypassResponse(t *testing.T) { + t.Run("req_rules", func(t *testing.T) { + mr, err := ModifyResponse.New(map[string]any{ + "bypass": []string{"path glob(/test/*) | path /api"}, + "set_headers": map[string]string{ + "Test-Header": "test-value", + }, + }) + expect.NoError(t, err) + + tests := []struct { + name string + path string + expectBypass bool + }{ + {"bypass", "/test/123", true}, + {"bypass2", "/test/123/456", true}, + {"bypass3", "/api", true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://example.com"+test.path, nil) + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("test")), + Request: req, + Header: make(http.Header), + } + mErr := mr.ModifyResponse(resp) + expect.NoError(t, mErr) + if test.expectBypass { + expect.Equal(t, resp.Header.Get("Test-Header"), "") + } else { + expect.Equal(t, resp.Header.Get("Test-Header"), "test-value") + } + }) + } + }) + t.Run("res_rules", func(t *testing.T) { + mr, err := ModifyResponse.New(map[string]any{ + "bypass": []string{"status 200"}, + "set_headers": map[string]string{ + "Test-Header": "test-value", + }, + }) + expect.NoError(t, err) + + tests := []struct { + name string + statusCode int + expectBypass bool + }{ + {"bypass", 200, true}, + {"no_bypass", 201, false}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resp := &http.Response{ + StatusCode: test.statusCode, + Body: io.NopCloser(strings.NewReader("test")), + Header: make(http.Header), + } + mErr := mr.ModifyResponse(resp) + expect.NoError(t, mErr) + if test.expectBypass { + expect.Equal(t, resp.Header.Get("Test-Header"), "") + } else { + expect.Equal(t, resp.Header.Get("Test-Header"), "test-value") + } + }) + } + }) +} + func TestEntrypointBypassRoute(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("test")) diff --git a/internal/route/rules/response_modifier.go b/internal/route/rules/response_modifier.go index 07a1e6a5..35571bf9 100644 --- a/internal/route/rules/response_modifier.go +++ b/internal/route/rules/response_modifier.go @@ -4,10 +4,12 @@ import ( "bufio" "bytes" "errors" + "io" "net" "net/http" "strconv" + "github.com/rs/zerolog/log" gperr "github.com/yusing/goutils/errs" "github.com/yusing/goutils/synk" ) @@ -43,6 +45,29 @@ func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier { } } +type responseAsRW struct { + resp *http.Response +} + +func (r responseAsRW) WriteHeader(code int) { + log.Error().Msg("write header after response has been created") +} + +func (r responseAsRW) Write(b []byte) (int, error) { + return 0, io.ErrClosedPipe +} + +func (r responseAsRW) Header() http.Header { + return r.resp.Header +} + +func ResponseAsRW(resp *http.Response) *ResponseModifier { + return &ResponseModifier{ + statusCode: resp.StatusCode, + w: responseAsRW{resp}, + } +} + // GetInitResponseModifier returns the response modifier for the given response writer. // If the response writer is already wrapped, it will return the wrapped response modifier. // Otherwise, it will return a new response modifier.