mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-17 23:14:21 +01:00
Refactor ServeHTTP to handle response body mutations correctly by:
- using ResponseModifier to capture the response before modification;
- reading the body content and allowing middleware to modify it;
- writing the modified body back if it changed during modification;
- enforcing the proper order: RequestModifier before next(), ResponseModifier after it.
Previously, httputils.NewModifyResponseWriter did not correctly handle body mutations. The new implementation captures the full response, allows modification via modifyResponse, and writes any body changes back. Also add BodyReader() and SetBody() methods to ResponseModifier to support reading and replacing the response body content.
241 lines
5.4 KiB
Go
241 lines
5.4 KiB
Go
package rules
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
gperr "github.com/yusing/goutils/errs"
|
|
"github.com/yusing/goutils/synk"
|
|
)
|
|
|
|
type ResponseModifier struct {
|
|
bufPool synk.UnsizedBytesPool
|
|
|
|
w http.ResponseWriter
|
|
buf *bytes.Buffer
|
|
statusCode int
|
|
shared Cache
|
|
|
|
hijacked bool
|
|
|
|
errs gperr.Builder
|
|
}
|
|
|
|
type Response struct {
|
|
StatusCode int
|
|
Header http.Header
|
|
}
|
|
|
|
func unwrapResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|
for {
|
|
switch ww := w.(type) {
|
|
case *ResponseModifier:
|
|
return ww
|
|
case interface{ Unwrap() http.ResponseWriter }:
|
|
w = ww.Unwrap()
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// responseAsRW adapts an already-received *http.Response to the
// http.ResponseWriter interface so that a ResponseModifier can wrap it.
// Only Header is functional; WriteHeader and Write refuse to touch the response.
type responseAsRW struct {
	// resp is the wrapped response whose header map is exposed by Header.
	resp *http.Response
}
|
|
|
|
func (r responseAsRW) WriteHeader(code int) {
|
|
log.Error().Msg("write header after response has been created")
|
|
}
|
|
|
|
func (r responseAsRW) Write(b []byte) (int, error) {
|
|
return 0, io.ErrClosedPipe
|
|
}
|
|
|
|
func (r responseAsRW) Header() http.Header {
|
|
return r.resp.Header
|
|
}
|
|
|
|
func ResponseAsRW(resp *http.Response) *ResponseModifier {
|
|
return &ResponseModifier{
|
|
statusCode: resp.StatusCode,
|
|
w: responseAsRW{resp},
|
|
}
|
|
}
|
|
|
|
// GetInitResponseModifier returns the response modifier for the given response writer.
|
|
// If the response writer is already wrapped, it will return the wrapped response modifier.
|
|
// Otherwise, it will return a new response modifier.
|
|
func GetInitResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|
if rm := unwrapResponseModifier(w); rm != nil {
|
|
return rm
|
|
}
|
|
return NewResponseModifier(w)
|
|
}
|
|
|
|
// GetSharedData returns the shared data for the given response writer.
|
|
// It will initialize the shared data if not initialized.
|
|
func GetSharedData(w http.ResponseWriter) Cache {
|
|
rm := GetInitResponseModifier(w)
|
|
if rm.shared == nil {
|
|
rm.shared = NewCache()
|
|
}
|
|
return rm.shared
|
|
}
|
|
|
|
// NewResponseModifier returns a new response modifier for the given response writer.
|
|
//
|
|
// It should only be called once, at the very beginning of the request.
|
|
func NewResponseModifier(w http.ResponseWriter) *ResponseModifier {
|
|
return &ResponseModifier{
|
|
bufPool: synk.GetUnsizedBytesPool(),
|
|
w: w,
|
|
}
|
|
}
|
|
|
|
func (rm *ResponseModifier) BufPool() synk.UnsizedBytesPool {
|
|
return rm.bufPool
|
|
}
|
|
|
|
// func (rm *ResponseModifier) Unwrap() http.ResponseWriter {
|
|
// return rm.w
|
|
// }
|
|
|
|
func (rm *ResponseModifier) WriteHeader(code int) {
|
|
rm.statusCode = code
|
|
}
|
|
|
|
// BodyReader returns a reader for the response body.
|
|
// Every call to this function will return a new reader that starts from the beginning of the buffer.
|
|
func (rm *ResponseModifier) BodyReader() io.ReadCloser {
|
|
if rm.buf == nil {
|
|
return io.NopCloser(bytes.NewReader(nil))
|
|
}
|
|
return io.NopCloser(bytes.NewReader(rm.buf.Bytes()))
|
|
}
|
|
|
|
func (rm *ResponseModifier) ResetBody() {
|
|
if rm.buf == nil {
|
|
return
|
|
}
|
|
rm.buf.Reset()
|
|
}
|
|
|
|
func (rm *ResponseModifier) SetBody(r io.ReadCloser) error {
|
|
if rm.buf == nil {
|
|
rm.buf = rm.bufPool.GetBuffer()
|
|
} else {
|
|
rm.buf.Reset()
|
|
}
|
|
|
|
_, err := io.Copy(rm.buf, r)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to copy body: %w", err)
|
|
}
|
|
r.Close()
|
|
return nil
|
|
}
|
|
|
|
func (rm *ResponseModifier) ContentLength() int {
|
|
if rm.buf == nil {
|
|
return 0
|
|
}
|
|
return rm.buf.Len()
|
|
}
|
|
|
|
func (rm *ResponseModifier) Content() []byte {
|
|
if rm.buf == nil {
|
|
return nil
|
|
}
|
|
return rm.buf.Bytes()
|
|
}
|
|
|
|
func (rm *ResponseModifier) StatusCode() int {
|
|
if rm.statusCode == 0 {
|
|
return http.StatusOK
|
|
}
|
|
return rm.statusCode
|
|
}
|
|
|
|
func (rm *ResponseModifier) Header() http.Header {
|
|
return rm.w.Header()
|
|
}
|
|
|
|
func (rm *ResponseModifier) Response() Response {
|
|
return Response{StatusCode: rm.StatusCode(), Header: rm.Header()}
|
|
}
|
|
|
|
func (rm *ResponseModifier) Write(b []byte) (int, error) {
|
|
if rm.buf == nil {
|
|
rm.buf = rm.bufPool.GetBuffer()
|
|
}
|
|
return rm.buf.Write(b)
|
|
}
|
|
|
|
// AppendError appends an error to the response modifier
|
|
// the error will be formatted as "rule <rule.Name> error: <err>"
|
|
//
|
|
// It will be aggregated and returned in FlushRelease.
|
|
func (rm *ResponseModifier) AppendError(rule Rule, err error) {
|
|
rm.errs.Addf("rule %q error: %w", rule.Name, err)
|
|
}
|
|
|
|
func (rm *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
if hijacker, ok := rm.w.(http.Hijacker); ok {
|
|
rm.hijacked = true
|
|
return hijacker.Hijack()
|
|
}
|
|
return nil, nil, errors.New("hijack not supported")
|
|
}
|
|
|
|
// FlushRelease flushes the response modifier and releases the resources
|
|
// it returns the number of bytes written and the aggregated error
|
|
// if there is any error (rule errors or write error), it will be returned
|
|
func (rm *ResponseModifier) FlushRelease() (int, error) {
|
|
n := 0
|
|
if !rm.hijacked {
|
|
h := rm.w.Header()
|
|
// for k := range h {
|
|
// if strings.EqualFold(k, "content-length") {
|
|
// h.Del(k)
|
|
// }
|
|
// }
|
|
contentLength := rm.ContentLength()
|
|
h.Set("Content-Length", strconv.Itoa(rm.ContentLength()))
|
|
h.Del("Transfer-Encoding")
|
|
h.Del("Trailer")
|
|
rm.w.WriteHeader(rm.StatusCode())
|
|
|
|
if contentLength > 0 {
|
|
nn, werr := rm.w.Write(rm.Content())
|
|
n += nn
|
|
if werr != nil {
|
|
rm.errs.Addf("write error: %w", werr)
|
|
}
|
|
if err := http.NewResponseController(rm.w).Flush(); err != nil && !errors.Is(err, http.ErrNotSupported) {
|
|
rm.errs.Addf("flush error: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// release the buffer and reset the pointers
|
|
if rm.buf != nil {
|
|
rm.bufPool.PutBuffer(rm.buf)
|
|
rm.buf = nil
|
|
}
|
|
|
|
// release the shared data
|
|
if rm.shared != nil {
|
|
rm.shared.Release()
|
|
rm.shared = nil
|
|
}
|
|
|
|
return n, rm.errs.Error()
|
|
}
|