fix(http): enhance Content-Length handling in ResponseModifier

- Introduced origContentLength and bodyModified fields to track original content length and body modification status.
- Updated ContentLength and ContentLengthStr methods to return accurate content length based on body modification state.
- Adjusted Write and FlushRelease methods to ensure proper handling of Content-Length header.
- Modified middleware to use the new ContentLengthStr method.
This commit is contained in:
yusing
2025-12-04 17:26:15 +08:00
parent a9adf79551
commit 429b0d9ce8
3 changed files with 51 additions and 25 deletions

View File

@@ -210,8 +210,8 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
// override the response status code // override the response status code
rm.WriteHeader(currentResp.StatusCode) rm.WriteHeader(currentResp.StatusCode)
// overriding the response header is not necessary // overriding the response header
// modifyResponse is supposed to write to Header directly instead of assigning a new header map) maps.Copy(rm.Header(), currentResp.Header)
// override the content length and body if changed // override the content length and body if changed
if currentResp.Body != currentBody { if currentResp.Body != currentBody {

View File

@@ -23,6 +23,9 @@ type ResponseModifier struct {
statusCode int statusCode int
shared Cache shared Cache
origContentLength int64 // from http.Response in ResponseAsRW
bodyModified bool
hijacked bool hijacked bool
errs gperr.Builder errs gperr.Builder
@@ -64,8 +67,9 @@ func (r responseAsRW) Header() http.Header {
func ResponseAsRW(resp *http.Response) *ResponseModifier { func ResponseAsRW(resp *http.Response) *ResponseModifier {
return &ResponseModifier{ return &ResponseModifier{
statusCode: resp.StatusCode, statusCode: resp.StatusCode,
w: responseAsRW{resp}, w: responseAsRW{resp},
origContentLength: resp.ContentLength,
} }
} }
@@ -121,6 +125,9 @@ func (rm *ResponseModifier) BodyReader() io.ReadCloser {
} }
func (rm *ResponseModifier) ResetBody() { func (rm *ResponseModifier) ResetBody() {
if !rm.bodyModified {
return
}
if rm.buf == nil { if rm.buf == nil {
return return
} }
@@ -134,6 +141,8 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
rm.buf.Reset() rm.buf.Reset()
} }
rm.bodyModified = true
_, err := io.Copy(rm.buf, r) _, err := io.Copy(rm.buf, r)
if err != nil { if err != nil {
return fmt.Errorf("failed to copy body: %w", err) return fmt.Errorf("failed to copy body: %w", err)
@@ -143,12 +152,26 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
} }
func (rm *ResponseModifier) ContentLength() int { func (rm *ResponseModifier) ContentLength() int {
if rm.buf == nil { if !rm.bodyModified {
return 0 if rm.origContentLength > 0 {
return int(rm.origContentLength)
}
contentLength, _ := strconv.Atoi(rm.ContentLengthStr())
return contentLength
} }
return rm.buf.Len() return rm.buf.Len()
} }
func (rm *ResponseModifier) ContentLengthStr() string {
if !rm.bodyModified {
if rm.origContentLength > 0 {
return strconv.FormatInt(rm.origContentLength, 10)
}
return rm.w.Header().Get("Content-Length")
}
return strconv.Itoa(rm.buf.Len())
}
func (rm *ResponseModifier) Content() []byte { func (rm *ResponseModifier) Content() []byte {
if rm.buf == nil { if rm.buf == nil {
return nil return nil
@@ -172,6 +195,11 @@ func (rm *ResponseModifier) Response() Response {
} }
func (rm *ResponseModifier) Write(b []byte) (int, error) { func (rm *ResponseModifier) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
rm.bodyModified = true
if rm.buf == nil { if rm.buf == nil {
rm.buf = rm.bufPool.GetBuffer() rm.buf = rm.bufPool.GetBuffer()
} }
@@ -200,26 +228,24 @@ func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (rm *ResponseModifier) FlushRelease() (int, error) { func (rm *ResponseModifier) FlushRelease() (int, error) {
n := 0 n := 0
if !rm.hijacked { if !rm.hijacked {
h := rm.w.Header() if rm.bodyModified {
// for k := range h { h := rm.w.Header()
// if strings.EqualFold(k, "content-length") { h.Set("Content-Length", rm.ContentLengthStr())
// h.Del(k) h.Del("Transfer-Encoding")
// } h.Del("Trailer")
// } }
contentLength := rm.ContentLength()
h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
h.Del("Transfer-Encoding")
h.Del("Trailer")
rm.w.WriteHeader(rm.StatusCode()) rm.w.WriteHeader(rm.StatusCode())
if contentLength > 0 { if rm.bodyModified {
nn, werr := rm.w.Write(rm.Content()) if content := rm.Content(); len(content) > 0 {
n += nn nn, werr := rm.w.Write(content)
if werr != nil { n += nn
rm.errs.Addf("write error: %w", werr) if werr != nil {
} rm.errs.Addf("write error: %w", werr)
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) { }
rm.errs.Addf("flush error: %w", err) if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
rm.errs.Addf("flush error: %w", err)
}
} }
} }
} }

View File

@@ -88,7 +88,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
var staticRespVarSubsMap = map[string]respVarGetter{ var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") }, VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") },
VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) }, VarRespContentLen: func(resp *ResponseModifier) string { return resp.ContentLengthStr() },
VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) }, VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
} }