mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-28 03:51:08 +01:00
fix(middleware): restore SSE streaming for POST endpoints (regression in v0.27.0) (#206)
* fix(middleware): restore SSE streaming for POST endpoints Regression introduced in16935865(v0.27.0). Before that commit, LazyResponseModifier only buffered HTML responses and let everything else pass through via the IsBuffered() early return. The refactor replaced it with NewResponseModifier which unconditionally buffers all writes until FlushRelease() fires after the handler returns. That kills real-time streaming for any SSE endpoint that uses POST. The existing bypass at ServeHTTP line 193 only fires when the *request* carries Accept: text/event-stream. That works for browser EventSource (which always sets that header) but not for programmatic fetch() calls, which set Content-Type: application/json on the request and only emit Content-Type: text/event-stream on the *response*. Fix: introduce ssePassthroughWriter, a thin http.ResponseWriter wrapper that sits in front of the ResponseModifier. It watches for Content-Type: text/event-stream in the response headers at the moment WriteHeader or the first Write is called. Once detected it copies the buffered headers to the real writer and switches all subsequent writes to pass directly through with an immediate Flush(), bypassing the ResponseModifier buffer entirely. Also tighten the Accept header check from == to strings.Contains so that Accept: text/event-stream, */* is handled correctly. Reported against Dockhand (https://github.com/Finsys/dockhand) where container update progress, image pull logs and vulnerability scan output all stopped streaming after users upgraded to GoDoxy v0.27.0. GET SSE endpoints (container logs) continued to work because browsers send Accept: text/event-stream for EventSource connections. * fix(middleware): make Content-Type SSE check case-insensitive * refactor(middleware): extract Content-Type into a named constant * fix(middleware): enhance safe guard to avoid buffering SSE, WS and large bodies Reverts some changes in16935865and apply more rubust handling. Use a lazy response modifier that buffers only when the response is safe to mutate. This prevents middleware from intercepting websocket/SSE streams, encoded payloads, and non-text or oversized responses. Set a 4MB max buffered size and gate buffering via response headers (content type, transfer/content encoding, and content length). Skip mutation when a response is not buffered or mutation setup fails, and simplify chained response modifiers to operate on the same response. Also update the goutils submodule for max body limit support. --------- Co-authored-by: yusing <yusing.wys@gmail.com>
This commit is contained in:
committed by
GitHub
parent
79327e98bd
commit
1bd8b5a696
@@ -3,9 +3,11 @@ package middleware
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"mime"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
@@ -17,6 +19,12 @@ import (
|
||||
"github.com/yusing/goutils/http/reverseproxy"
|
||||
)
|
||||
|
||||
const (
|
||||
mimeEventStream = "text/event-stream"
|
||||
headerContentType = "Content-Type"
|
||||
maxModifiableBody = 4 * 1024 * 1024 // 4MB
|
||||
)
|
||||
|
||||
type (
|
||||
ReverseProxy = reverseproxy.ReverseProxy
|
||||
ProxyRequest = reverseproxy.ProxyRequest
|
||||
@@ -190,78 +198,112 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *
|
||||
}
|
||||
}
|
||||
|
||||
if httpheaders.IsWebsocket(r.Header) || r.Header.Get("Accept") == "text/event-stream" {
|
||||
if httpheaders.IsWebsocket(r.Header) || strings.Contains(strings.ToLower(r.Header.Get("Accept")), mimeEventStream) {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if exec, ok := m.impl.(ResponseModifier); ok {
|
||||
rm := httputils.NewResponseModifier(w)
|
||||
defer func() {
|
||||
_, err := rm.FlushRelease()
|
||||
if err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to flush response")
|
||||
}
|
||||
}()
|
||||
next(rm, r)
|
||||
|
||||
currentBody := rm.BodyReader()
|
||||
currentResp := &http.Response{
|
||||
StatusCode: rm.StatusCode(),
|
||||
Header: rm.Header(),
|
||||
ContentLength: int64(rm.ContentLength()),
|
||||
Body: currentBody,
|
||||
Request: r,
|
||||
}
|
||||
allowBodyModification := canModifyResponseBody(currentResp)
|
||||
respToModify := currentResp
|
||||
if !allowBodyModification {
|
||||
shadow := *currentResp
|
||||
shadow.Body = eofReader{}
|
||||
respToModify = &shadow
|
||||
}
|
||||
if err := exec.modifyResponse(respToModify); err != nil {
|
||||
log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response")
|
||||
}
|
||||
|
||||
// override the response status code
|
||||
rm.WriteHeader(respToModify.StatusCode)
|
||||
|
||||
// overriding the response header
|
||||
maps.Copy(rm.Header(), respToModify.Header)
|
||||
|
||||
// override the content length and body if changed
|
||||
if respToModify.Body != currentBody {
|
||||
if allowBodyModification {
|
||||
if err := rm.SetBody(respToModify.Body); err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to set response body")
|
||||
}
|
||||
} else {
|
||||
respToModify.Body.Close()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
exec, ok := m.impl.(ResponseModifier)
|
||||
if !ok {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
lrm := httputils.NewLazyResponseModifier(w, canBufferAndModifyResponseBody)
|
||||
lrm.SetMaxBufferedBytes(maxModifiableBody)
|
||||
defer func() {
|
||||
_, err := lrm.FlushRelease()
|
||||
if err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to flush response")
|
||||
}
|
||||
}()
|
||||
next(lrm, r)
|
||||
|
||||
// Skip modification if response wasn't buffered
|
||||
if !lrm.IsBuffered() {
|
||||
return
|
||||
}
|
||||
|
||||
rm := lrm.ResponseModifier()
|
||||
currentBody := rm.BodyReader()
|
||||
currentResp := &http.Response{
|
||||
StatusCode: rm.StatusCode(),
|
||||
Header: rm.Header(),
|
||||
ContentLength: int64(rm.ContentLength()),
|
||||
Body: currentBody,
|
||||
Request: r,
|
||||
}
|
||||
respToModify := currentResp
|
||||
if err := exec.modifyResponse(respToModify); err != nil {
|
||||
log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response")
|
||||
return // skip modification if failed
|
||||
}
|
||||
|
||||
// override the response status code
|
||||
rm.WriteHeader(respToModify.StatusCode)
|
||||
|
||||
// overriding the response header
|
||||
maps.Copy(rm.Header(), respToModify.Header)
|
||||
|
||||
// override the body if changed
|
||||
if respToModify.Body != currentBody {
|
||||
err := rm.SetBody(respToModify.Body)
|
||||
if err != nil {
|
||||
m.LogError(r).Err(err).Msg("failed to set response body")
|
||||
return // skip modification if failed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func canModifyResponseBody(resp *http.Response) bool {
|
||||
if hasNonIdentityEncoding(resp.TransferEncoding) {
|
||||
// canBufferAndModifyResponseBody checks if the response body can be buffered and modified.
|
||||
//
|
||||
// A body can be buffered and modified if:
|
||||
// - The response is not a websocket and is not an event stream
|
||||
// - The response has identity transfer encoding
|
||||
// - The response has identity content encoding
|
||||
// - The response has a content length
|
||||
// - The content length is less than 4MB
|
||||
// - The content type is text-like
|
||||
func canBufferAndModifyResponseBody(respHeader http.Header) bool {
|
||||
if httpheaders.IsWebsocket(respHeader) {
|
||||
return false
|
||||
}
|
||||
if hasNonIdentityEncoding(resp.Header.Values("Transfer-Encoding")) {
|
||||
contentType := respHeader.Get("Content-Type")
|
||||
if contentType == "" { // safe default: skip if no content type
|
||||
return false
|
||||
}
|
||||
if hasNonIdentityEncoding(resp.Header.Values("Content-Encoding")) {
|
||||
contentType = strings.ToLower(contentType)
|
||||
if strings.Contains(contentType, mimeEventStream) {
|
||||
return false
|
||||
}
|
||||
return isTextLikeMediaType(string(httputils.GetContentType(resp.Header)))
|
||||
// strip charset or any other parameters
|
||||
contentType, _, err := mime.ParseMediaType(contentType)
|
||||
if err != nil { // skip if invalid content type
|
||||
return false
|
||||
}
|
||||
if hasNonIdentityEncoding(respHeader.Values("Transfer-Encoding")) {
|
||||
return false
|
||||
}
|
||||
if hasNonIdentityEncoding(respHeader.Values("Content-Encoding")) {
|
||||
return false
|
||||
}
|
||||
if contentLengthRaw := respHeader.Get("Content-Length"); contentLengthRaw != "" {
|
||||
contentLength, err := strconv.ParseInt(contentLengthRaw, 10, 64)
|
||||
if err != nil || contentLength >= maxModifiableBody {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if !isTextLikeMediaType(contentType) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func hasNonIdentityEncoding(values []string) bool {
|
||||
for _, value := range values {
|
||||
for _, token := range strings.Split(value, ",") {
|
||||
if strings.TrimSpace(token) == "" || strings.EqualFold(strings.TrimSpace(token), "identity") {
|
||||
for token := range strings.SplitSeq(value, ",") {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" || strings.EqualFold(token, "identity") {
|
||||
continue
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -47,23 +47,53 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error {
|
||||
if len(m.modResps) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowBodyModification := canModifyResponseBody(resp)
|
||||
for i, mr := range m.modResps {
|
||||
respToModify := resp
|
||||
if !allowBodyModification {
|
||||
shadow := *resp
|
||||
shadow.Body = eofReader{}
|
||||
respToModify = &shadow
|
||||
}
|
||||
if err := mr.modifyResponse(respToModify); err != nil {
|
||||
if err := modifyResponseWithBodyRewriteGate(mr, resp); err != nil {
|
||||
return gperr.PrependSubject(err, strconv.Itoa(i))
|
||||
}
|
||||
if !allowBodyModification {
|
||||
resp.StatusCode = respToModify.StatusCode
|
||||
if respToModify.Header != nil {
|
||||
maps.Copy(resp.Header, respToModify.Header)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func modifyResponseWithBodyRewriteGate(mr ResponseModifier, resp *http.Response) error {
|
||||
originalBody := resp.Body
|
||||
originalContentLength := resp.ContentLength
|
||||
allowBodyRewrite := canBufferAndModifyResponseBody(responseHeaderForBodyRewriteGate(resp))
|
||||
|
||||
if err := mr.modifyResponse(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allowBodyRewrite || resp.Body == originalBody {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.Body != nil {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
return fmt.Errorf("close rewritten body: %w", err)
|
||||
}
|
||||
}
|
||||
if originalBody == nil || originalBody == http.NoBody {
|
||||
resp.Body = http.NoBody
|
||||
} else {
|
||||
resp.Body = originalBody
|
||||
}
|
||||
resp.ContentLength = originalContentLength
|
||||
if originalContentLength >= 0 {
|
||||
resp.Header.Set("Content-Length", strconv.FormatInt(originalContentLength, 10))
|
||||
} else {
|
||||
resp.Header.Del("Content-Length")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responseHeaderForBodyRewriteGate(resp *http.Response) http.Header {
|
||||
h := resp.Header.Clone()
|
||||
if len(resp.TransferEncoding) > 0 && len(h.Values("Transfer-Encoding")) == 0 {
|
||||
h["Transfer-Encoding"] = append([]string(nil), resp.TransferEncoding...)
|
||||
}
|
||||
if resp.ContentLength >= 0 && h.Get("Content-Length") == "" {
|
||||
h.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10))
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -127,9 +128,101 @@ func TestMiddlewareResponseRewriteGate(t *testing.T) {
|
||||
respStatus: http.StatusOK,
|
||||
})
|
||||
expect.NoError(t, err)
|
||||
expect.Equal(t, result.ResponseStatus, 418)
|
||||
expect.Equal(t, result.ResponseStatus, http.StatusTeapot)
|
||||
expect.Equal(t, result.ResponseHeaders.Get("X-Rewrite"), "1")
|
||||
expect.Equal(t, string(result.Data), tc.expectBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareResponseRewriteGateServeHTTP(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"status_code": 418,
|
||||
"header_key": "X-Rewrite",
|
||||
"header_val": "1",
|
||||
"body": "rewritten-body",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respHeaders http.Header
|
||||
respBody string
|
||||
expectStatusCode int
|
||||
expectHeader string
|
||||
expectBody string
|
||||
}{
|
||||
{
|
||||
name: "allow_body_rewrite_for_html",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"text/html; charset=utf-8"},
|
||||
},
|
||||
respBody: "<html><body>original</body></html>",
|
||||
expectStatusCode: http.StatusTeapot,
|
||||
expectHeader: "1",
|
||||
expectBody: "rewritten-body",
|
||||
},
|
||||
{
|
||||
name: "block_body_rewrite_for_binary_content",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"application/octet-stream"},
|
||||
},
|
||||
respBody: "binary",
|
||||
expectStatusCode: http.StatusOK,
|
||||
expectHeader: "",
|
||||
expectBody: "binary",
|
||||
},
|
||||
{
|
||||
name: "block_body_rewrite_for_transfer_encoded_html",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"text/html"},
|
||||
"Transfer-Encoding": []string{"chunked"},
|
||||
},
|
||||
respBody: "<html><body>original</body></html>",
|
||||
expectStatusCode: http.StatusOK,
|
||||
expectHeader: "",
|
||||
expectBody: "<html><body>original</body></html>",
|
||||
},
|
||||
{
|
||||
name: "block_body_rewrite_for_content_encoded_html",
|
||||
respHeaders: http.Header{
|
||||
"Content-Type": []string{"text/html"},
|
||||
"Content-Encoding": []string{"gzip"},
|
||||
},
|
||||
respBody: "<html><body>original</body></html>",
|
||||
expectStatusCode: http.StatusOK,
|
||||
expectHeader: "",
|
||||
expectBody: "<html><body>original</body></html>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mid, err := responseRewrite.New(opts)
|
||||
expect.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
next := func(w http.ResponseWriter, _ *http.Request) {
|
||||
for key, values := range tc.respHeaders {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(tc.respBody))
|
||||
}
|
||||
|
||||
mid.ServeHTTP(next, rw, req)
|
||||
|
||||
resp := rw.Result()
|
||||
defer resp.Body.Close()
|
||||
data, readErr := io.ReadAll(resp.Body)
|
||||
expect.NoError(t, readErr)
|
||||
|
||||
expect.Equal(t, resp.StatusCode, tc.expectStatusCode)
|
||||
expect.Equal(t, resp.Header.Get("X-Rewrite"), tc.expectHeader)
|
||||
expect.Equal(t, string(data), tc.expectBody)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user