Files

156 lines
4.4 KiB
Go

package middleware
import (
"bytes"
"io"
"net/http"
"strconv"
"github.com/PuerkitoBio/goquery"
"github.com/rs/zerolog/log"
httputils "github.com/yusing/goutils/http"
ioutils "github.com/yusing/goutils/io"
"github.com/yusing/goutils/synk"
"golang.org/x/net/html"
)
// modifyHTML holds the settings for the HTML-modifying middleware: which
// element(s) to target and what markup to inject there.
type modifyHTML struct {
	// Target is the CSS selector that picks the element(s) to modify.
	Target string
	// HTML is the markup to inject.
	HTML string
	// Replace, when true, replaces the target element with HTML instead of
	// appending HTML to it.
	Replace bool
}
// ModifyHTML is the middleware value produced by NewMiddleware for the
// modifyHTML implementation.
var ModifyHTML = NewMiddleware[modifyHTML]()
// before overwrites the request's Accept-Encoding header with "identity" and
// always returns true. The ResponseWriter is not used.
//
// NOTE(review): presumably this asks the upstream for an uncompressed body so
// that modifyResponse can parse it as HTML, and true presumably lets the
// request continue through the chain — confirm against the middleware
// framework.
func (m *modifyHTML) before(_ http.ResponseWriter, req *http.Request) bool {
	headers := req.Header
	headers.Set("Accept-Encoding", "identity")
	return true
}
// readerWithRelease wraps b in an io.ReadCloser that reads from b and is
// registered with a hook that calls release(b).
//
// NOTE(review): the hook is presumably run by ioutils.NewHookReadCloser when
// the returned reader is closed — confirm in that package.
func readerWithRelease(b []byte, release func([]byte)) io.ReadCloser {
	body := io.NopCloser(bytes.NewReader(b))
	onDone := func() {
		release(b)
	}
	return ioutils.NewHookReadCloser(body, onDone)
}
// eofReader is an empty body: every Read reports io.EOF and Close is a no-op.
type eofReader struct{}

// Compile-time check that eofReader can be used as a response body.
var _ io.ReadCloser = eofReader{}

// Read never fills the buffer; it always returns 0 and io.EOF.
func (eofReader) Read(_ []byte) (int, error) {
	return 0, io.EOF
}

// Close does nothing and always returns nil.
func (eofReader) Close() error {
	return nil
}
// modifyResponse implements ResponseModifier.
//
// It buffers an HTML response body, injects m.HTML at the element(s) matched
// by the CSS selector m.Target, and replaces resp.Body with the re-rendered
// document, updating the length, encoding and content-type headers to match.
//
// The response is left as received for HEAD requests, for status codes that
// never carry a body (1xx, 204, 304), and for non-HTML content types. When the
// document cannot be parsed or the selector matches nothing, the buffered
// original body is put back unchanged.
//
// A body read error is logged and answered with an empty body (fail open, nil
// is returned). A render error is logged, the original body is restored, and
// the error is returned.
func (m *modifyHTML) modifyResponse(resp *http.Response) error {
	// Skip HEAD requests - no body to modify
	if resp.Request.Method == http.MethodHead {
		return nil
	}

	// Skip statuses that never carry a body. An empty body still parses into
	// an implied <html>/<head>/<body> skeleton, so a selector such as "body"
	// would match and content would be injected into e.g. a 304 response.
	switch code := resp.StatusCode; {
	case code >= 100 && code < 200,
		code == http.StatusNoContent,
		code == http.StatusNotModified:
		return nil
	}

	// including text/html and application/xhtml+xml
	if !httputils.GetContentType(resp.Header).IsHTML() {
		return nil
	}

	// Skip modification for streaming/chunked responses to avoid blocking reads
	// Unknown content length or any transfer encoding indicates streaming.
	// if resp.ContentLength < 0 || len(resp.TransferEncoding) > 0 {
	// log.Debug().Str("url", fullURL(resp.Request)).Strs("transfer-encoding", resp.TransferEncoding).Msg("skipping modification for streaming/chunked response")
	// return nil
	// }

	// NOTE: do not put it in the defer, it will be used as resp.Body
	content, release, err := httputils.ReadAllBody(resp)
	resp.Body.Close()
	if err != nil {
		log.Err(err).Str("url", fullURL(resp.Request)).Msg("failed to read response body")
		// Fail open: do not abort the response. Return an empty body safely.
		// NOTE(review): release is not called on this path — confirm that
		// ReadAllBody does not hand out a pooled buffer together with an error.
		resp.ContentLength = 0
		resp.Header.Set("Content-Length", "0")
		resp.Header.Del("Transfer-Encoding")
		resp.Header.Del("Trailer")
		resp.Header.Del("Content-Encoding")
		resp.Body = eofReader{}
		return nil
	}

	doc, err := goquery.NewDocumentFromReader(bytes.NewReader(content))
	if err != nil {
		// invalid html, restore the original body
		resp.Body = readerWithRelease(content, release)
		log.Err(err).Str("url", fullURL(resp.Request)).Msg("invalid html found")
		return nil
	}

	ele := doc.Find(m.Target)
	if ele.Length() == 0 {
		// no target found, restore the original body
		resp.Body = readerWithRelease(content, release)
		return nil
	}

	if m.Replace {
		// replace all matching elements
		ele.ReplaceWithHtml(m.HTML)
	} else {
		// append to the first matching element
		ele.First().AppendHtml(m.HTML)
	}

	pool := synk.GetUnsizedBytesPool()
	buf := pool.GetBuffer()
	err = buildHTML(doc, buf)
	if err != nil {
		pool.PutBuffer(buf)
		log.Err(err).Str("url", fullURL(resp.Request)).Msg("failed to build html")
		// rendering failed, restore the original body
		resp.Body = readerWithRelease(content, release)
		return err
	}

	// The original bytes are no longer needed once the document is rendered.
	release(content)

	// The body is now a freshly rendered document of known length, so the
	// framing and encoding headers of the upstream response no longer apply.
	resp.ContentLength = int64(buf.Len())
	resp.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
	resp.Header.Del("Transfer-Encoding")
	resp.Header.Del("Trailer")
	resp.Header.Del("Content-Encoding")
	resp.Header.Set("Content-Type", "text/html; charset=utf-8")

	// The buffer goes back to the pool through the release hook rather than
	// here, because its bytes back the new response body.
	resp.Body = readerWithRelease(buf.Bytes(), func(_ []byte) {
		pool.PutBuffer(buf)
	})
	return nil
}
// buildHTML serializes the document s into buf.
//
// When the document holds more than one <head> element, the element children
// of every additional <head> are moved into the first one and the additional
// <head> elements are removed. Afterwards each child of s.Nodes[0] is rendered
// into buf in order; the first render error is returned.
//
// Adapted from (*goquery.Selection).Html().
func buildHTML(s *goquery.Document, buf *bytes.Buffer) error {
	heads := s.Find("head")
	if total := heads.Length(); total > 1 {
		primary := heads.First()
		extras := heads.Slice(1, total)

		// Fold the contents of each surplus <head> into the first one...
		extras.Each(func(_ int, extra *goquery.Selection) {
			primary.AppendSelection(extra.Children())
		})
		// ...then drop the now-redundant <head> elements themselves.
		extras.Remove()
	}

	if len(s.Nodes) == 0 {
		return nil
	}

	for node := s.Nodes[0].FirstChild; node != nil; node = node.NextSibling {
		if err := html.Render(buf, node); err != nil {
			return err
		}
	}
	return nil
}
// fullURL returns the request's host followed by its request-target, for use
// in log messages.
//
// req.Host and req.RequestURI are used when set. RequestURI is only populated
// on requests received by a server, and Host may be empty on requests built by
// a client, so an empty field is filled in from req.URL when it is available.
// A nil request yields an empty string.
func fullURL(req *http.Request) string {
	if req == nil {
		return ""
	}

	host, uri := req.Host, req.RequestURI
	if req.URL != nil {
		if host == "" {
			host = req.URL.Host
		}
		if uri == "" {
			uri = req.URL.RequestURI()
		}
	}
	return host + uri
}