refactor(modifyhtml): improved memory management and response body handling

This commit is contained in:
yusing
2025-09-02 22:16:09 +08:00
parent 245dba034e
commit 4513c221d5

View File

@@ -9,19 +9,23 @@ import (
"github.com/PuerkitoBio/goquery" "github.com/PuerkitoBio/goquery"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
gphttp "github.com/yusing/go-proxy/internal/net/gphttp" gphttp "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/synk" "github.com/yusing/go-proxy/internal/utils/synk"
"golang.org/x/net/html" "golang.org/x/net/html"
) )
type modifyHTML struct { type modifyHTML struct {
Target string // css selector Target string // css selector
HTML string // html to inject HTML string // html to inject
Replace bool // replace the target element with the new html instead of appending it Replace bool // replace the target element with the new html instead of appending it
bytesPool *synk.BytesPool
} }
var ModifyHTML = NewMiddleware[modifyHTML]() var ModifyHTML = NewMiddleware[modifyHTML]()
var bytePool = synk.GetBytesPool() func (m *modifyHTML) setup() {
m.bytesPool = synk.GetBytesPool()
}
func (m *modifyHTML) before(_ http.ResponseWriter, req *http.Request) bool { func (m *modifyHTML) before(_ http.ResponseWriter, req *http.Request) bool {
req.Header.Set("Accept-Encoding", "") req.Header.Set("Accept-Encoding", "")
@@ -35,7 +39,8 @@ func (m *modifyHTML) modifyResponse(resp *http.Response) error {
return nil return nil
} }
content, err := io.ReadAll(resp.Body) // NOTE: do not put it in the defer, it will be used as resp.Body
content, release, err := utils.ReadAllBody(resp)
if err != nil { if err != nil {
resp.Body.Close() resp.Body.Close()
return err return err
@@ -65,21 +70,20 @@ func (m *modifyHTML) modifyResponse(resp *http.Response) error {
ele.First().AppendHtml(m.HTML) ele.First().AppendHtml(m.HTML)
} }
h, err := buildHTML(doc) buf := bytes.NewBuffer(content[:0])
err = buildHTML(m, doc, buf)
if err != nil { if err != nil {
return err return err
} }
resp.ContentLength = int64(len(h)) resp.ContentLength = int64(buf.Len())
resp.Header.Set("Content-Length", strconv.Itoa(len(h))) resp.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
resp.Header.Set("Content-Type", "text/html; charset=utf-8") resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Body = io.NopCloser(bytes.NewReader(h)) resp.Body = utils.NewHookCloser(io.NopCloser(bytes.NewReader(buf.Bytes())), release)
return nil return nil
} }
// copied and modified from (*goquery.Selection).Html() // copied and modified from (*goquery.Selection).Html()
func buildHTML(s *goquery.Document) (ret []byte, err error) { func buildHTML(m *modifyHTML, s *goquery.Document, buf *bytes.Buffer) error {
buf := bytes.NewBuffer(bytePool.Get())
// Merge all head nodes into one // Merge all head nodes into one
headNodes := s.Find("head") headNodes := s.Find("head")
if headNodes.Length() > 1 { if headNodes.Length() > 1 {
@@ -100,14 +104,13 @@ func buildHTML(s *goquery.Document) (ret []byte, err error) {
if len(s.Nodes) > 0 { if len(s.Nodes) > 0 {
for c := s.Nodes[0].FirstChild; c != nil; c = c.NextSibling { for c := s.Nodes[0].FirstChild; c != nil; c = c.NextSibling {
err = html.Render(buf, c) err := html.Render(buf, c)
if err != nil { if err != nil {
return return err
} }
} }
ret = buf.Bytes()
} }
return return nil
} }
func fullURL(req *http.Request) string { func fullURL(req *http.Request) string {