Files
godoxy/internal/net/gphttp/middleware/modify_html.go
yusing dd33980d18 fix(middleware): allow HTML rewrite for chunked and unknown-length bodies
Relax response-body gating so HTML and XHTML can be buffered when
Transfer-Encoding is chunked-only, or when Content-Length is missing,
while still rejecting non-identity encodings that are not chunked HTML
and other non-HTML cases.

Update modifyHTML to cap reads for unknown length, splice the original
stream back when the cap is hit, and document the behavior in the
package README. Extend tests for themed middleware and the rewrite gate.
2026-04-23 17:23:39 +08:00

180 lines
4.9 KiB
Go

package middleware
import (
"bytes"
"io"
"net/http"
"strconv"
"github.com/PuerkitoBio/goquery"
"github.com/rs/zerolog/log"
httputils "github.com/yusing/goutils/http"
ioutils "github.com/yusing/goutils/io"
"github.com/yusing/goutils/synk"
"golang.org/x/net/html"
)
type modifyHTML struct {
Target string // css selector
HTML string // html to inject
Replace bool // replace the target element with the new html instead of appending it
}
var ModifyHTML = NewMiddleware[modifyHTML]()
func (*modifyHTML) isBodyResponseModifier() {}
func (m *modifyHTML) before(_ http.ResponseWriter, req *http.Request) bool {
req.Header.Set("Accept-Encoding", "identity")
return true
}
func readerWithRelease(b []byte, release func([]byte)) io.ReadCloser {
return ioutils.NewHookReadCloser(io.NopCloser(bytes.NewReader(b)), func() {
release(b)
})
}
type eofReader struct{}
func (eofReader) Read([]byte) (int, error) { return 0, io.EOF }
func (eofReader) Close() error { return nil }
// modifyResponse implements ResponseModifier.
func (m *modifyHTML) modifyResponse(resp *http.Response) error {
// Skip HEAD requests - no body to modify
if resp.Request.Method == http.MethodHead {
return nil
}
// including text/html and application/xhtml+xml
if !httputils.GetContentType(resp.Header).IsHTML() {
return nil
}
// NOTE: do not put it in the defer, it will be used as resp.Body
var (
content []byte
release func([]byte)
err error
)
if resp.ContentLength >= int64(maxModifiableBody) {
return nil
}
originalBody := resp.Body
if resp.ContentLength < 0 {
// tmp swap Body to LimitedReader
limited := io.LimitedReader{R: originalBody, N: int64(maxModifiableBody) + 1}
resp.Body = io.NopCloser(&limited)
content, release, err = httputils.ReadAllBody(resp)
// Successfully read N bytes
if err == nil && limited.N == 0 {
fullReader := io.NopCloser(io.MultiReader(bytes.NewReader(content), originalBody))
onClose := func() {
release(content)
_ = originalBody.Close()
}
resp.Body = ioutils.NewHookReadCloser(fullReader, onClose)
resp.ContentLength = -1
resp.Header.Del("Content-Length")
return nil
}
} else {
content, release, err = httputils.ReadAllBody(resp)
}
_ = originalBody.Close()
if err != nil {
log.Err(err).Str("url", fullURL(resp.Request)).Msg("failed to read response body")
// Fail open: do not abort the response. Return an empty body safely.
resp.ContentLength = 0
resp.Header.Set("Content-Length", "0")
resp.Header.Del("Transfer-Encoding")
resp.Header.Del("Trailer")
resp.Header.Del("Content-Encoding")
resp.Body = eofReader{}
return nil
}
doc, err := goquery.NewDocumentFromReader(bytes.NewReader(content))
if err != nil {
// invalid html, restore the original body
resp.Body = readerWithRelease(content, release)
log.Err(err).Str("url", fullURL(resp.Request)).Msg("invalid html found")
return nil
}
ele := doc.Find(m.Target)
if ele.Length() == 0 {
// no target found, restore the original body
resp.Body = readerWithRelease(content, release)
return nil
}
if m.Replace {
// replace all matching elements
ele.ReplaceWithHtml(m.HTML)
} else {
// append to the first matching element
ele.First().AppendHtml(m.HTML)
}
pool := synk.GetUnsizedBytesPool()
buf := pool.GetBuffer()
err = buildHTML(doc, buf)
if err != nil {
pool.PutBuffer(buf)
log.Err(err).Str("url", fullURL(resp.Request)).Msg("failed to build html")
// invalid html, restore the original body
resp.Body = readerWithRelease(content, release)
return err
}
release(content)
resp.ContentLength = int64(buf.Len())
resp.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
resp.Header.Del("Transfer-Encoding")
resp.Header.Del("Trailer")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Body = readerWithRelease(buf.Bytes(), func(_ []byte) {
pool.PutBuffer(buf)
})
return nil
}
// copied and modified from (*goquery.Selection).Html()
func buildHTML(s *goquery.Document, buf *bytes.Buffer) error {
// Merge all head nodes into one
headNodes := s.Find("head")
if headNodes.Length() > 1 {
// Get the first head node to merge everything into
firstHead := headNodes.First()
// Merge content from all other head nodes into the first one
headNodes.Slice(1, headNodes.Length()).Each(func(i int, otherHead *goquery.Selection) {
// Move all children from other head nodes to the first head
otherHead.Children().Each(func(j int, child *goquery.Selection) {
firstHead.AppendSelection(child)
})
})
// Remove the duplicate head nodes (keep only the first one)
headNodes.Slice(1, headNodes.Length()).Remove()
}
if len(s.Nodes) > 0 {
for c := s.Nodes[0].FirstChild; c != nil; c = c.NextSibling {
err := html.Render(buf, c)
if err != nil {
return err
}
}
}
return nil
}
func fullURL(req *http.Request) string {
return req.Host + req.RequestURI
}