diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index 26284245..0c19520e 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -9,9 +9,9 @@ import ( "github.com/bytedance/sonic" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/yusing/godoxy/internal/route/rules" "github.com/yusing/godoxy/internal/serialization" gperr "github.com/yusing/goutils/errs" - httputils "github.com/yusing/goutils/http" "github.com/yusing/goutils/http/reverseproxy" ) @@ -184,17 +184,42 @@ func (m *Middleware) ModifyResponse(resp *http.Response) error { } func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) { - if exec, ok := m.impl.(ResponseModifier); ok { - w = httputils.NewModifyResponseWriter(w, r, func(resp *http.Response) error { - return exec.modifyResponse(resp) - }) - } if exec, ok := m.impl.(RequestModifier); ok { if proceed := exec.before(w, r); !proceed { return } } - next(w, r) + + if exec, ok := m.impl.(ResponseModifier); ok { + rm := rules.NewResponseModifier(w) + defer rm.FlushRelease() + next(rm, r) + + currentBody := rm.BodyReader() + currentResp := &http.Response{ + StatusCode: rm.StatusCode(), + Header: rm.Header(), + ContentLength: int64(rm.ContentLength()), + Body: currentBody, + Request: r, + } + if err := exec.modifyResponse(currentResp); err != nil { + log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response") + } + + // override the response status code + rm.WriteHeader(currentResp.StatusCode) + + // overriding the response header is not necessary + // modifyResponse is supposed to write to Header directly instead of assigning a new header map) + + // override the content length and body if changed + if currentResp.Body != currentBody { + rm.SetBody(currentResp.Body) + } + } else { + next(w, r) + } } func (m *Middleware) LogWarn(req *http.Request) *zerolog.Event { diff --git a/internal/route/rules/response_modifier.go b/internal/route/rules/response_modifier.go index 965f19f0..9d3683c8 100644 --- a/internal/route/rules/response_modifier.go +++ b/internal/route/rules/response_modifier.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "errors" + "fmt" "io" "net" "net/http" @@ -110,6 +111,15 @@ func (rm *ResponseModifier) WriteHeader(code int) { rm.statusCode = code } +// BodyReader returns a reader for the response body. +// Every call to this function will return a new reader that starts from the beginning of the buffer. +func (rm *ResponseModifier) BodyReader() io.ReadCloser { + if rm.buf == nil { + return io.NopCloser(bytes.NewReader(nil)) + } + return io.NopCloser(bytes.NewReader(rm.buf.Bytes())) +} + func (rm *ResponseModifier) ResetBody() { if rm.buf == nil { return @@ -117,6 +127,21 @@ func (rm *ResponseModifier) ResetBody() { rm.buf.Reset() } +func (rm *ResponseModifier) SetBody(r io.ReadCloser) error { + if rm.buf == nil { + rm.buf = rm.bufPool.GetBuffer() + } else { + rm.buf.Reset() + } + + _, err := io.Copy(rm.buf, r) + if err != nil { + return fmt.Errorf("failed to copy body: %w", err) + } + r.Close() + return nil +} + func (rm *ResponseModifier) ContentLength() int { if rm.buf == nil { return 0