mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
fix(http): enhance Content-Length handling in ResponseModifier
- Introduced origContentLength and bodyModified fields to track original content length and body modification status. - Updated ContentLength and ContentLengthStr methods to return accurate content length based on body modification state. - Adjusted Write and FlushRelease methods to ensure proper handling of Content-Length header. - Modified middleware to use the new ContentLengthStr method.
This commit is contained in:
@@ -210,8 +210,8 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
||||
// override the response status code
|
||||
rm.WriteHeader(currentResp.StatusCode)
|
||||
|
||||
// overriding the response header is not necessary
|
||||
// modifyResponse is supposed to write to Header directly instead of assigning a new header map)
|
||||
// overriding the response header
|
||||
maps.Copy(rm.Header(), currentResp.Header)
|
||||
|
||||
// override the content length and body if changed
|
||||
if currentResp.Body != currentBody {
|
||||
|
||||
@@ -23,6 +23,9 @@ type ResponseModifier struct {
|
||||
statusCode int
|
||||
shared Cache
|
||||
|
||||
origContentLength int64 // from http.Response in ResponseAsRW
|
||||
bodyModified bool
|
||||
|
||||
hijacked bool
|
||||
|
||||
errs gperr.Builder
|
||||
@@ -64,8 +67,9 @@ func (r responseAsRW) Header() http.Header {
|
||||
|
||||
func ResponseAsRW(resp *http.Response) *ResponseModifier {
|
||||
return &ResponseModifier{
|
||||
statusCode: resp.StatusCode,
|
||||
w: responseAsRW{resp},
|
||||
statusCode: resp.StatusCode,
|
||||
w: responseAsRW{resp},
|
||||
origContentLength: resp.ContentLength,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,6 +125,9 @@ func (rm *ResponseModifier) BodyReader() io.ReadCloser {
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) ResetBody() {
|
||||
if !rm.bodyModified {
|
||||
return
|
||||
}
|
||||
if rm.buf == nil {
|
||||
return
|
||||
}
|
||||
@@ -134,6 +141,8 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
|
||||
rm.buf.Reset()
|
||||
}
|
||||
|
||||
rm.bodyModified = true
|
||||
|
||||
_, err := io.Copy(rm.buf, r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to copy body: %w", err)
|
||||
@@ -143,12 +152,26 @@ func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) ContentLength() int {
|
||||
if rm.buf == nil {
|
||||
return 0
|
||||
if !rm.bodyModified {
|
||||
if rm.origContentLength > 0 {
|
||||
return int(rm.origContentLength)
|
||||
}
|
||||
contentLength, _ := strconv.Atoi(rm.ContentLengthStr())
|
||||
return contentLength
|
||||
}
|
||||
return rm.buf.Len()
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) ContentLengthStr() string {
|
||||
if !rm.bodyModified {
|
||||
if rm.origContentLength > 0 {
|
||||
return strconv.FormatInt(rm.origContentLength, 10)
|
||||
}
|
||||
return rm.w.Header().Get("Content-Length")
|
||||
}
|
||||
return strconv.Itoa(rm.buf.Len())
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) Content() []byte {
|
||||
if rm.buf == nil {
|
||||
return nil
|
||||
@@ -172,6 +195,11 @@ func (rm *ResponseModifier) Response() Response {
|
||||
}
|
||||
|
||||
func (rm *ResponseModifier) Write(b []byte) (int, error) {
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
rm.bodyModified = true
|
||||
if rm.buf == nil {
|
||||
rm.buf = rm.bufPool.GetBuffer()
|
||||
}
|
||||
@@ -200,26 +228,24 @@ func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
func (rm *ResponseModifier) FlushRelease() (int, error) {
|
||||
n := 0
|
||||
if !rm.hijacked {
|
||||
h := rm.w.Header()
|
||||
// for k := range h {
|
||||
// if strings.EqualFold(k, "content-length") {
|
||||
// h.Del(k)
|
||||
// }
|
||||
// }
|
||||
contentLength := rm.ContentLength()
|
||||
h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
|
||||
h.Del("Transfer-Encoding")
|
||||
h.Del("Trailer")
|
||||
if rm.bodyModified {
|
||||
h := rm.w.Header()
|
||||
h.Set("Content-Length", rm.ContentLengthStr())
|
||||
h.Del("Transfer-Encoding")
|
||||
h.Del("Trailer")
|
||||
}
|
||||
rm.w.WriteHeader(rm.StatusCode())
|
||||
|
||||
if contentLength > 0 {
|
||||
nn, werr := rm.w.Write(rm.Content())
|
||||
n += nn
|
||||
if werr != nil {
|
||||
rm.errs.Addf("write error: %w", werr)
|
||||
}
|
||||
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
rm.errs.Addf("flush error: %w", err)
|
||||
if rm.bodyModified {
|
||||
if content := rm.Content(); len(content) > 0 {
|
||||
nn, werr := rm.w.Write(content)
|
||||
n += nn
|
||||
if werr != nil {
|
||||
rm.errs.Addf("write error: %w", werr)
|
||||
}
|
||||
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
||||
rm.errs.Addf("flush error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ var staticReqVarSubsMap = map[string]reqVarGetter{
|
||||
|
||||
var staticRespVarSubsMap = map[string]respVarGetter{
|
||||
VarRespContentType: func(resp *ResponseModifier) string { return resp.Header().Get("Content-Type") },
|
||||
VarRespContentLen: func(resp *ResponseModifier) string { return strconv.Itoa(resp.ContentLength()) },
|
||||
VarRespContentLen: func(resp *ResponseModifier) string { return resp.ContentLengthStr() },
|
||||
VarRespStatusCode: func(resp *ResponseModifier) string { return strconv.Itoa(resp.StatusCode()) },
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user