fix(http): enhance Content-Length handling in ResponseModifier

- Introduced origContentLength and bodyModified fields to track original content length and body modification status.
- Updated ContentLength and ContentLengthStr methods to return accurate content length based on body modification state.
- Adjusted Write and FlushRelease methods to ensure proper handling of Content-Length header.
- Modified middleware to use the new ContentLengthStr method.
This commit is contained in:
yusing
2025-12-04 17:26:15 +08:00
parent 9cdc985fb0
commit c098fef615
3 changed files with 51 additions and 25 deletions

View File

@@ -210,8 +210,8 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
// override the response status code
rm.WriteHeader(currentResp.StatusCode)
// overriding the response header is not necessary
// modifyResponse is supposed to write to Header directly instead of assigning a new header map)
// overriding the response header
maps.Copy(rm.Header(), currentResp.Header)
// override the content length and body if changed
if currentResp.Body != currentBody {

View File

@@ -23,6 +23,9 @@ type ResponseModifier struct {
statusCode int
shared Cache
origContentLength int64 // from http.Response in ResponseAsRW
bodyModified bool
hijacked bool
errs gperr.Builder
@@ -64,8 +67,9 @@ func (r responseAsRW) Header() http.Header {
func ResponseAsRW(resp *http.Response) *ResponseModifier {
return &ResponseModifier{
statusCode: resp.StatusCode,
w: responseAsRW{resp},
statusCode: resp.StatusCode,
w: responseAsRW{resp},
origContentLength: resp.ContentLength,
}
}
@@ -121,6 +125,9 @@ func (rm *ResponseModifier) BodyReader() io.ReadCloser {
}
func (rm *ResponseModifier) ResetBody() {
if !rm.bodyModified {
return
}
if rm.buf == nil {
return
}
@@ -134,6 +141,8 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
rm.buf.Reset()
}
rm.bodyModified = true
_, err := io.Copy(rm.buf, r)
if err != nil {
return fmt.Errorf("failed to copy body: %w", err)
@@ -143,12 +152,26 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
}
func (rm *ResponseModifier) ContentLength() int {
if rm.buf == nil {
return 0
if !rm.bodyModified {
if rm.origContentLength > 0 {
return int(rm.origContentLength)
}
contentLength, _ := strconv.Atoi(rm.ContentLengthStr())
return contentLength
}
return rm.buf.Len()
}
func (rm *ResponseModifier) ContentLengthStr() string {
if !rm.bodyModified {
if rm.origContentLength > 0 {
return strconv.FormatInt(rm.origContentLength, 10)
}
return rm.w.Header().Get("Content-Length")
}
return strconv.Itoa(rm.buf.Len())
}
func (rm *ResponseModifier) Content() []byte {
if rm.buf == nil {
return nil
@@ -172,6 +195,11 @@ func (rm *ResponseModifier) Response() Response {
}
func (rm *ResponseModifier) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
rm.bodyModified = true
if rm.buf == nil {
rm.buf = rm.bufPool.GetBuffer()
}
@@ -200,26 +228,24 @@ func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (rm *ResponseModifier) FlushRelease() (int, error) {
n := 0
if !rm.hijacked {
h := rm.w.Header()
// for k := range h {
// if strings.EqualFold(k, "content-length") {
// h.Del(k)
// }
// }
contentLength := rm.ContentLength()
h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
h.Del("Transfer-Encoding")
h.Del("Trailer")
if rm.bodyModified {
h := rm.w.Header()
h.Set("Content-Length", rm.ContentLengthStr())
h.Del("Transfer-Encoding")
h.Del("Trailer")
}
rm.w.WriteHeader(rm.StatusCode())
if contentLength > 0 {
nn, werr := rm.w.Write(rm.Content())
n += nn
if werr != nil {
rm.errs.Addf("write error: %w", werr)
}
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
rm.errs.Addf("flush error: %w", err)
if rm.bodyModified {
if content := rm.Content(); len(content) > 0 {
nn, werr := rm.w.Write(content)
n += nn
if werr != nil {
rm.errs.Addf("write error: %w", werr)
}
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
rm.errs.Addf("flush error: %w", err)
}
}
}
}

View File

@@ -88,7 +88,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
var staticRespVarSubsMap = map[string]respVarGetter{
VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") },
VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) },
VarRespContentLen: func(resp *ResponseModifier) string { return resp.ContentLengthStr() },
VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
}