From 429b0d9ce8f2cf0f5312bcdb28dd343de624f698 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 4 Dec 2025 17:26:15 +0800 Subject: [PATCH] fix(http): enhance Content-Length handling in ResponseModifier - Introduced origContentLength and bodyModified fields to track original content length and body modification status. - Updated ContentLength and ContentLengthStr methods to return accurate content length based on body modification state. - Adjusted Write and FlushRelease methods to ensure proper handling of Content-Length header. - Modified middleware to use the new ContentLengthStr method. --- internal/net/gphttp/middleware/middleware.go | 4 +- internal/route/rules/response_modifier.go | 70 ++++++++++++++------ internal/route/rules/vars_static.go | 2 +- 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index 0c19520e..b34da314 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -210,8 +210,8 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * // override the response status code rm.WriteHeader(currentResp.StatusCode) - // overriding the response header is not necessary - // modifyResponse is supposed to write to Header directly instead of assigning a new header map) + // overriding the response header + maps.Copy(rm.Header(), currentResp.Header) // override the content length and body if changed if currentResp.Body != currentBody { diff --git a/internal/route/rules/response_modifier.go b/internal/route/rules/response_modifier.go index 9d3683c8..b112f8a8 100644 --- a/internal/route/rules/response_modifier.go +++ b/internal/route/rules/response_modifier.go @@ -23,6 +23,9 @@ type ResponseModifier struct { statusCode int shared Cache + origContentLength int64 // from http.Response in ResponseAsRW + bodyModified bool + hijacked bool errs gperr.Builder @@ -64,8 +67,9 @@ func (r responseAsRW) Header() http.Header { func ResponseAsRW(resp *http.Response) *ResponseModifier { return &ResponseModifier{ - statusCode: resp.StatusCode, - w: responseAsRW{resp}, + statusCode: resp.StatusCode, + w: responseAsRW{resp}, + origContentLength: resp.ContentLength, } } @@ -121,6 +125,9 @@ func (rm *ResponseModifier) BodyReader() io.ReadCloser { } func (rm *ResponseModifier) ResetBody() { + if !rm.bodyModified { + return + } if rm.buf == nil { return } @@ -134,6 +141,8 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error { rm.buf.Reset() } + rm.bodyModified = true + _, err := io.Copy(rm.buf, r) if err != nil { return fmt.Errorf("failed to copy body: %w", err) @@ -143,12 +152,26 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error { } func (rm *ResponseModifier) ContentLength() int { - if rm.buf == nil { - return 0 + if !rm.bodyModified { + if rm.origContentLength > 0 { + return int(rm.origContentLength) + } + contentLength, _ := strconv.Atoi(rm.ContentLengthStr()) + return contentLength } return rm.buf.Len() } +func (rm *ResponseModifier) ContentLengthStr() string { + if !rm.bodyModified { + if rm.origContentLength > 0 { + return strconv.FormatInt(rm.origContentLength, 10) + } + return rm.w.Header().Get("Content-Length") + } + return strconv.Itoa(rm.buf.Len()) +} + func (rm *ResponseModifier) Content() []byte { if rm.buf == nil { return nil @@ -172,6 +195,11 @@ func (rm *ResponseModifier) Response() Response { } func (rm *ResponseModifier) Write(b []byte) (int, error) { + if len(b) == 0 { + return 0, nil + } + + rm.bodyModified = true if rm.buf == nil { rm.buf = rm.bufPool.GetBuffer() } @@ -200,26 +228,24 @@ func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (rm *ResponseModifier) FlushRelease() (int, error) { n := 0 if !rm.hijacked { - h := rm.w.Header() - // for k := range h { - // if strings.EqualFold(k, "content-length") { - // h.Del(k) - // } - // } - contentLength := rm.ContentLength() - h.Set("Content-Length", strconv.Itoa(rm.ContentLength())) - h.Del("Transfer-Encoding") - h.Del("Trailer") + if rm.bodyModified { + h := rm.w.Header() + h.Set("Content-Length", rm.ContentLengthStr()) + h.Del("Transfer-Encoding") + h.Del("Trailer") + } rm.w.WriteHeader(rm.StatusCode()) - if contentLength > 0 { - nn, werr := rm.w.Write(rm.Content()) - n += nn - if werr != nil { - rm.errs.Addf("write error: %w", werr) - } - if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { - rm.errs.Addf("flush error: %w", err) + if rm.bodyModified { + if content := rm.Content(); len(content) > 0 { + nn, werr := rm.w.Write(content) + n += nn + if werr != nil { + rm.errs.Addf("write error: %w", werr) + } + if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { + rm.errs.Addf("flush error: %w", err) + } } } } diff --git a/internal/route/rules/vars_static.go b/internal/route/rules/vars_static.go index d1434919..ca52c1ee 100644 --- a/internal/route/rules/vars_static.go +++ b/internal/route/rules/vars_static.go @@ -88,7 +88,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{ var staticRespVarSubsMap = map[string]respVarGetter{ VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") }, - VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) }, + VarRespContentLen: func(resp *ResponseModifier) string { return resp.ContentLengthStr() }, VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) }, }