refactor(http): proper ResponseWriter and headers handling across files

This commit is contained in:
yusing
2025-10-28 14:43:10 +08:00
parent f29b69ff3b
commit d4dfec8293
3 changed files with 34 additions and 3 deletions

View File

@@ -45,13 +45,25 @@ func (m *modifyHTML) modifyResponse(resp *http.Response) error {
return nil return nil
} }
// Skip modification for streaming/chunked responses to avoid blocking reads
// Unknown content length or any transfer encoding indicates streaming.
if resp.ContentLength < 0 || len(resp.TransferEncoding) > 0 {
return nil
}
// NOTE: do not put it in the defer, it will be used as resp.Body // NOTE: do not put it in the defer, it will be used as resp.Body
content, release, err := httputils.ReadAllBody(resp) content, release, err := httputils.ReadAllBody(resp)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
log.Err(err).Str("url", fullURL(resp.Request)).Msg("failed to read response body") log.Err(err).Str("url", fullURL(resp.Request)).Msg("failed to read response body")
// Fail open: do not abort the response. Return an empty body safely.
resp.ContentLength = 0
resp.Header.Set("Content-Length", "0")
resp.Header.Del("Transfer-Encoding")
resp.Header.Del("Trailer")
resp.Header.Del("Content-Encoding")
resp.Body = eofReader{} resp.Body = eofReader{}
return err return nil
} }
doc, err := goquery.NewDocumentFromReader(bytes.NewReader(content)) doc, err := goquery.NewDocumentFromReader(bytes.NewReader(content))
@@ -92,6 +104,9 @@ func (m *modifyHTML) modifyResponse(resp *http.Response) error {
release(content) release(content)
resp.ContentLength = int64(buf.Len()) resp.ContentLength = int64(buf.Len())
resp.Header.Set("Content-Length", strconv.Itoa(buf.Len())) resp.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
resp.Header.Del("Transfer-Encoding")
resp.Header.Del("Trailer")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Content-Type", "text/html; charset=utf-8") resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Body = readerWithRelease(buf.Bytes(), func(_ []byte) { resp.Body = readerWithRelease(buf.Bytes(), func(_ []byte) {
pool.PutBuffer(buf) pool.PutBuffer(buf)

View File

@@ -158,6 +158,8 @@ func (rm *ResponseModifier) FlushRelease() (int, error) {
// } // }
contentLength := rm.ContentLength() contentLength := rm.ContentLength()
h.Set("Content-Length", strconv.Itoa(rm.ContentLength())) h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
h.Del("Transfer-Encoding")
h.Del("Trailer")
rm.w.WriteHeader(rm.StatusCode()) rm.w.WriteHeader(rm.StatusCode())
if contentLength > 0 { if contentLength > 0 {

View File

@@ -6,7 +6,7 @@ import (
"net/http" "net/http"
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
gperr "github.com/yusing/goutils/errs" "github.com/rs/zerolog/log"
) )
type ( type (
@@ -89,6 +89,11 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
if defaultRule.IsResponseRule() { if defaultRule.IsResponseRule() {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
rm := NewResponseModifier(w) rm := NewResponseModifier(w)
defer func() {
if _, err := rm.FlushRelease(); err != nil {
logError(err, r)
}
}()
w = rm w = rm
up(w, r) up(w, r)
err := defaultRule.Do.exec.Handle(w, r) err := defaultRule.Do.exec.Handle(w, r)
@@ -99,6 +104,11 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
rm := NewResponseModifier(w) rm := NewResponseModifier(w)
defer func() {
if _, err := rm.FlushRelease(); err != nil {
logError(err, r)
}
}()
w = rm w = rm
err := defaultRule.Do.exec.Handle(w, r) err := defaultRule.Do.exec.Handle(w, r)
if err == nil { if err == nil {
@@ -128,7 +138,7 @@ func (rules Rules) BuildHandler(up http.HandlerFunc) http.HandlerFunc {
rm := NewResponseModifier(w) rm := NewResponseModifier(w)
defer func() { defer func() {
if _, err := rm.FlushRelease(); err != nil { if _, err := rm.FlushRelease(); err != nil {
gperr.LogError("error executing rules", err) logError(err, r)
} }
}() }()
@@ -252,3 +262,7 @@ func (rule *Rule) Check(w http.ResponseWriter, r *http.Request) bool {
func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error { func (rule *Rule) Handle(w http.ResponseWriter, r *http.Request) error {
return rule.Do.exec.Handle(w, r) return rule.Do.exec.Handle(w, r)
} }
func logError(err error, r *http.Request) {
log.Err(err).Str("method", r.Method).Str("url", r.Host+r.URL.Path).Msg("error executing rules")
}