diff --git a/internal/net/gphttp/middleware/bypass.go b/internal/net/gphttp/middleware/bypass.go index df3cb61f..c598dff2 100644 --- a/internal/net/gphttp/middleware/bypass.go +++ b/internal/net/gphttp/middleware/bypass.go @@ -122,10 +122,6 @@ func (c *checkBypass) modifyResponse(resp *http.Response) error { return c.modRes.modifyResponse(resp) } -func (c *checkBypass) requiresBodyRewrite() bool { - return requiresBodyRewrite(c.modRes) -} - func (m *Middleware) withCheckBypass() any { if len(m.Bypass) > 0 { modReq, _ := m.impl.(RequestModifier) diff --git a/internal/net/gphttp/middleware/custom_error_page.go b/internal/net/gphttp/middleware/custom_error_page.go index 379d28db..905a89ee 100644 --- a/internal/net/gphttp/middleware/custom_error_page.go +++ b/internal/net/gphttp/middleware/custom_error_page.go @@ -20,9 +20,7 @@ var CustomErrorPage = NewMiddleware[customErrorPage]() const StaticFilePathPrefix = "/$gperrorpage/" -func (customErrorPage) requiresBodyRewrite() bool { - return true -} +func (customErrorPage) isBodyResponseModifier() {} // before implements RequestModifier. func (customErrorPage) before(w http.ResponseWriter, r *http.Request) (proceed bool) { diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index 817c401d..3c86ced6 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -54,7 +54,11 @@ type ( RequestModifier interface { before(w http.ResponseWriter, r *http.Request) (proceed bool) } - ResponseModifier interface{ modifyResponse(r *http.Response) error } + ResponseModifier interface{ modifyResponse(r *http.Response) error } + BodyResponseModifier interface { + ResponseModifier + isBodyResponseModifier() + } MiddlewareWithSetup interface{ setup() } MiddlewareFinalizer interface{ finalize() } MiddlewareFinalizerWithError interface { @@ -208,8 +212,15 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * next(w, r) return } + isBodyModifier := isBodyResponseModifier(exec) - lrm := httputils.NewLazyResponseModifier(w, canBufferAndModifyResponseBody) + shouldBuffer := canBufferAndModifyResponseBody + if !isBodyModifier { + // Header-only response modifiers do not need body rewrite capability checks. + // We still respect max buffer limits and may fall back to passthrough for large bodies. + shouldBuffer = func(http.Header) bool { return true } + } + lrm := httputils.NewLazyResponseModifier(w, shouldBuffer) lrm.SetMaxBufferedBytes(maxModifiableBody) defer func() { _, err := lrm.FlushRelease() @@ -225,6 +236,9 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * } rm := lrm.ResponseModifier() + if rm.IsPassthrough() { + return + } currentBody := rm.BodyReader() currentResp := &http.Response{ StatusCode: rm.StatusCode(), @@ -246,7 +260,7 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * maps.Copy(rm.Header(), respToModify.Header) // override the body if changed - if respToModify.Body != currentBody { + if isBodyModifier && respToModify.Body != currentBody { err := rm.SetBody(respToModify.Body) if err != nil { m.LogError(r).Err(err).Msg("failed to set response body") diff --git a/internal/net/gphttp/middleware/middleware_chain.go b/internal/net/gphttp/middleware/middleware_chain.go index b9c1781c..dbd646ef 100644 --- a/internal/net/gphttp/middleware/middleware_chain.go +++ b/internal/net/gphttp/middleware/middleware_chain.go @@ -1,7 +1,6 @@ package middleware import ( - "fmt" "net/http" "strconv" @@ -9,12 +8,9 @@ import ( ) type middlewareChain struct { - befores []RequestModifier - modResps []ResponseModifier -} - -type bodyRewriteRequired interface { - requiresBodyRewrite() bool + befores []RequestModifier + respHeader []ResponseModifier + respBody []ResponseModifier } // TODO: check conflict or duplicates. @@ -27,7 +23,11 @@ func NewMiddlewareChain(name string, chain []*Middleware) *Middleware { chainMid.befores = append(chainMid.befores, before) } if mr, ok := comp.impl.(ResponseModifier); ok { - chainMid.modResps = append(chainMid.modResps, mr) + if isBodyResponseModifier(mr) { + chainMid.respBody = append(chainMid.respBody, mr) + } else { + chainMid.respHeader = append(chainMid.respHeader, mr) + } } } return m @@ -48,55 +48,32 @@ func (m *middlewareChain) before(w http.ResponseWriter, r *http.Request) (procee // modifyResponse implements ResponseModifier. func (m *middlewareChain) modifyResponse(resp *http.Response) error { - if len(m.modResps) == 0 { - return nil - } - for i, mr := range m.modResps { - if err := modifyResponseWithBodyRewriteGate(mr, resp); err != nil { + for i, mr := range m.respHeader { + if err := mr.modifyResponse(resp); err != nil { return gperr.PrependSubject(err, strconv.Itoa(i)) } } - return nil -} - -func modifyResponseWithBodyRewriteGate(mr ResponseModifier, resp *http.Response) error { - originalBody := resp.Body - originalContentLength := resp.ContentLength - allowBodyRewrite := canBufferAndModifyResponseBody(responseHeaderForBodyRewriteGate(resp)) - if !allowBodyRewrite && requiresBodyRewrite(mr) { + if len(m.respBody) == 0 || !canBufferAndModifyResponseBody(responseHeaderForBodyRewriteGate(resp)) { return nil } - - if err := mr.modifyResponse(resp); err != nil { - return err - } - - if allowBodyRewrite || resp.Body == originalBody { - return nil - } - - if resp.Body != nil { - if err := resp.Body.Close(); err != nil { - return fmt.Errorf("close rewritten body: %w", err) + headerLen := len(m.respHeader) + for i, mr := range m.respBody { + if err := mr.modifyResponse(resp); err != nil { + return gperr.PrependSubject(err, strconv.Itoa(i+headerLen)) } } - if originalBody == nil || originalBody == http.NoBody { - resp.Body = http.NoBody - } else { - resp.Body = originalBody - } - resp.ContentLength = originalContentLength - if originalContentLength >= 0 { - resp.Header.Set("Content-Length", strconv.FormatInt(originalContentLength, 10)) - } else { - resp.Header.Del("Content-Length") - } return nil } -func requiresBodyRewrite(mr ResponseModifier) bool { - required, ok := mr.(bodyRewriteRequired) - return ok && required.requiresBodyRewrite() +func isBodyResponseModifier(mr ResponseModifier) bool { + if chain, ok := mr.(*middlewareChain); ok { + return len(chain.respBody) > 0 + } + if bypass, ok := mr.(*checkBypass); ok { + return isBodyResponseModifier(bypass.modRes) + } + _, ok := mr.(BodyResponseModifier) + return ok } func responseHeaderForBodyRewriteGate(resp *http.Response) http.Header { diff --git a/internal/net/gphttp/middleware/middleware_test.go b/internal/net/gphttp/middleware/middleware_test.go index 340e5872..6ae2ea2d 100644 --- a/internal/net/gphttp/middleware/middleware_test.go +++ b/internal/net/gphttp/middleware/middleware_test.go @@ -1,7 +1,6 @@ package middleware import ( - "errors" "io" "net/http" "net/http/httptest" @@ -17,50 +16,37 @@ type testPriority struct { } var test = NewMiddleware[testPriority]() -var responseRewrite = NewMiddleware[testResponseRewrite]() +var responseHeaderRewrite = NewMiddleware[testHeaderRewrite]() +var responseBodyRewrite = NewMiddleware[testBodyRewrite]() func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool { w.Header().Add("Test-Value", strconv.Itoa(t.Value)) return true } -type testResponseRewrite struct { +type testHeaderRewrite struct { StatusCode int `json:"status_code"` HeaderKey string `json:"header_key"` HeaderVal string `json:"header_val"` - Body string `json:"body"` } -type closeSensitiveBody struct { - data []byte - offset int - closed bool -} - -func (b *closeSensitiveBody) Read(p []byte) (int, error) { - if b.closed { - return 0, errors.New("http: read on closed response body") - } - if b.offset >= len(b.data) { - return 0, io.EOF - } - n := copy(p, b.data[b.offset:]) - b.offset += n - return n, nil -} - -func (b *closeSensitiveBody) Close() error { - b.closed = true +func (t testHeaderRewrite) modifyResponse(resp *http.Response) error { + resp.StatusCode = t.StatusCode + resp.Header.Set(t.HeaderKey, t.HeaderVal) return nil } -func (t testResponseRewrite) modifyResponse(resp *http.Response) error { - resp.StatusCode = t.StatusCode - resp.Header.Set(t.HeaderKey, t.HeaderVal) +type testBodyRewrite struct { + Body string `json:"body"` +} + +func (t testBodyRewrite) modifyResponse(resp *http.Response) error { resp.Body = io.NopCloser(strings.NewReader(t.Body)) return nil } +func (testBodyRewrite) isBodyResponseModifier() {} + func TestMiddlewarePriority(t *testing.T) { priorities := []int{4, 7, 9, 0} chain := make([]*Middleware, len(priorities)) @@ -78,50 +64,66 @@ func TestMiddlewarePriority(t *testing.T) { } func TestMiddlewareResponseRewriteGate(t *testing.T) { - opts := OptionsRaw{ + headerOpts := OptionsRaw{ "status_code": 418, "header_key": "X-Rewrite", "header_val": "1", - "body": "rewritten-body", } + bodyOpts := OptionsRaw{ + "body": "rewritten-body", + } + headerMid, err := responseHeaderRewrite.New(headerOpts) + expect.NoError(t, err) + bodyMid, err := responseBodyRewrite.New(bodyOpts) + expect.NoError(t, err) tests := []struct { - name string - respHeaders http.Header - respBody []byte - expectBody string + name string + respHeaders http.Header + respBody []byte + expectStatus int + expectHeader string + expectBody string }{ { name: "allow_body_rewrite_for_html", respHeaders: http.Header{ "Content-Type": []string{"text/html; charset=utf-8"}, }, - respBody: []byte("original"), - expectBody: "rewritten-body", + respBody: []byte("original"), + expectStatus: http.StatusTeapot, + expectHeader: "1", + expectBody: "rewritten-body", }, { name: "allow_body_rewrite_for_json", respHeaders: http.Header{ "Content-Type": []string{"application/json"}, }, - respBody: []byte(`{"message":"original"}`), - expectBody: "rewritten-body", + respBody: []byte(`{"message":"original"}`), + expectStatus: http.StatusTeapot, + expectHeader: "1", + expectBody: "rewritten-body", }, { name: "allow_body_rewrite_for_yaml", respHeaders: http.Header{ "Content-Type": []string{"application/yaml"}, }, - respBody: []byte("k: v"), - expectBody: "rewritten-body", + respBody: []byte("k: v"), + expectStatus: http.StatusTeapot, + expectHeader: "1", + expectBody: "rewritten-body", }, { name: "block_body_rewrite_for_binary_content", respHeaders: http.Header{ "Content-Type": []string{"application/octet-stream"}, }, - respBody: []byte("binary"), - expectBody: "binary", + respBody: []byte("binary"), + expectStatus: http.StatusTeapot, + expectHeader: "1", + expectBody: "binary", }, { name: "block_body_rewrite_for_transfer_encoded_html", @@ -129,8 +131,10 @@ func TestMiddlewareResponseRewriteGate(t *testing.T) { "Content-Type": []string{"text/html"}, "Transfer-Encoding": []string{"chunked"}, }, - respBody: []byte("original"), - expectBody: "original", + respBody: []byte("original"), + expectStatus: http.StatusTeapot, + expectHeader: "1", + expectBody: "original", }, { name: "block_body_rewrite_for_content_encoded_html", @@ -138,34 +142,42 @@ func TestMiddlewareResponseRewriteGate(t *testing.T) { "Content-Type": []string{"text/html"}, "Content-Encoding": []string{"gzip"}, }, - respBody: []byte("original"), - expectBody: "original", + respBody: []byte("original"), + expectStatus: http.StatusTeapot, + expectHeader: "1", + expectBody: "original", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - result, err := newMiddlewareTest(responseRewrite, &testArgs{ - middlewareOpt: opts, - respHeaders: tc.respHeaders, - respBody: tc.respBody, - respStatus: http.StatusOK, + result, err := newMiddlewaresTest([]*Middleware{headerMid, bodyMid}, &testArgs{ + respHeaders: tc.respHeaders, + respBody: tc.respBody, + respStatus: http.StatusOK, }) expect.NoError(t, err) - expect.Equal(t, result.ResponseStatus, http.StatusTeapot) - expect.Equal(t, result.ResponseHeaders.Get("X-Rewrite"), "1") + expect.Equal(t, result.ResponseStatus, tc.expectStatus) + expect.Equal(t, result.ResponseHeaders.Get("X-Rewrite"), tc.expectHeader) expect.Equal(t, string(result.Data), tc.expectBody) }) } } func TestMiddlewareResponseRewriteGateServeHTTP(t *testing.T) { - opts := OptionsRaw{ + headerOpts := OptionsRaw{ "status_code": 418, "header_key": "X-Rewrite", "header_val": "1", - "body": "rewritten-body", } + bodyOpts := OptionsRaw{ + "body": "rewritten-body", + } + headerMid, err := responseHeaderRewrite.New(headerOpts) + expect.NoError(t, err) + bodyMid, err := responseBodyRewrite.New(bodyOpts) + expect.NoError(t, err) + mid := NewMiddlewareChain("test", []*Middleware{headerMid, bodyMid}) tests := []struct { name string @@ -221,9 +233,6 @@ func TestMiddlewareResponseRewriteGateServeHTTP(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - mid, err := responseRewrite.New(opts) - expect.NoError(t, err) - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) rw := httptest.NewRecorder() @@ -251,33 +260,17 @@ func TestMiddlewareResponseRewriteGateServeHTTP(t *testing.T) { } } -func TestMiddlewareResponseRewriteGateSkipsBodyRewriterWhenRewriteBlocked(t *testing.T) { - originalBody := &closeSensitiveBody{ - data: []byte("original"), - } - req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) - resp := &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{ +func TestThemedSkipsBodyRewriteWhenRewriteBlocked(t *testing.T) { + result, err := newMiddlewareTest(Themed, &testArgs{ + middlewareOpt: OptionsRaw{ + "theme": DarkTheme, + }, + respHeaders: http.Header{ "Content-Type": []string{"text/html; charset=utf-8"}, "Transfer-Encoding": []string{"chunked"}, }, - Body: originalBody, - ContentLength: -1, - TransferEncoding: []string{"chunked"}, - Request: req, - } - - themedMid, err := Themed.New(OptionsRaw{ - "theme": DarkTheme, + respBody: []byte("original"), }) expect.NoError(t, err) - - respMod, ok := themedMid.impl.(ResponseModifier) - expect.True(t, ok) - expect.NoError(t, modifyResponseWithBodyRewriteGate(respMod, resp)) - - data, err := io.ReadAll(resp.Body) - expect.NoError(t, err) - expect.Equal(t, string(data), "original") + expect.Equal(t, string(result.Data), "original") } diff --git a/internal/net/gphttp/middleware/modify_html.go b/internal/net/gphttp/middleware/modify_html.go index ede67a54..ea609daa 100644 --- a/internal/net/gphttp/middleware/modify_html.go +++ b/internal/net/gphttp/middleware/modify_html.go @@ -22,9 +22,7 @@ type modifyHTML struct { var ModifyHTML = NewMiddleware[modifyHTML]() -func (*modifyHTML) requiresBodyRewrite() bool { - return true -} +func (*modifyHTML) isBodyResponseModifier() {} func (m *modifyHTML) before(_ http.ResponseWriter, req *http.Request) bool { req.Header.Set("Accept-Encoding", "identity") diff --git a/internal/net/gphttp/middleware/themed.go b/internal/net/gphttp/middleware/themed.go index 17ef4c6c..c8c32707 100644 --- a/internal/net/gphttp/middleware/themed.go +++ b/internal/net/gphttp/middleware/themed.go @@ -54,9 +54,7 @@ func (m *themed) modifyResponse(resp *http.Response) error { return m.m.modifyResponse(resp) } -func (*themed) requiresBodyRewrite() bool { - return true -} +func (*themed) isBodyResponseModifier() {} func (m *themed) finalize() error { m.m.Target = "body"