diff --git a/goutils b/goutils index 4f468fdc..4912690d 160000 --- a/goutils +++ b/goutils @@ -1 +1 @@ -Subproject commit 4f468fdce8af642e7c53bbb34b0b536d2b9206e2 +Subproject commit 4912690d409dd2466e405b2ca8570ccf87759bb9 diff --git a/internal/net/gphttp/middleware/middleware.go b/internal/net/gphttp/middleware/middleware.go index 54651327..817c401d 100644 --- a/internal/net/gphttp/middleware/middleware.go +++ b/internal/net/gphttp/middleware/middleware.go @@ -3,9 +3,11 @@ package middleware import ( "fmt" "maps" + "mime" "net/http" "reflect" "sort" + "strconv" "strings" "github.com/bytedance/sonic" @@ -17,6 +19,12 @@ import ( "github.com/yusing/goutils/http/reverseproxy" ) +const ( + mimeEventStream = "text/event-stream" + headerContentType = "Content-Type" + maxModifiableBody = 4 * 1024 * 1024 // 4MB +) + type ( ReverseProxy = reverseproxy.ReverseProxy ProxyRequest = reverseproxy.ProxyRequest @@ -190,78 +198,112 @@ func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r * } } - if httpheaders.IsWebsocket(r.Header) || r.Header.Get("Accept") == "text/event-stream" { + if httpheaders.IsWebsocket(r.Header) || strings.Contains(strings.ToLower(r.Header.Get("Accept")), mimeEventStream) { next(w, r) return } - if exec, ok := m.impl.(ResponseModifier); ok { - rm := httputils.NewResponseModifier(w) - defer func() { - _, err := rm.FlushRelease() - if err != nil { - m.LogError(r).Err(err).Msg("failed to flush response") - } - }() - next(rm, r) - - currentBody := rm.BodyReader() - currentResp := &http.Response{ - StatusCode: rm.StatusCode(), - Header: rm.Header(), - ContentLength: int64(rm.ContentLength()), - Body: currentBody, - Request: r, - } - allowBodyModification := canModifyResponseBody(currentResp) - respToModify := currentResp - if !allowBodyModification { - shadow := *currentResp - shadow.Body = eofReader{} - respToModify = &shadow - } - if err := exec.modifyResponse(respToModify); err != nil { - log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response") - } - - // override the response status code - rm.WriteHeader(respToModify.StatusCode) - - // overriding the response header - maps.Copy(rm.Header(), respToModify.Header) - - // override the content length and body if changed - if respToModify.Body != currentBody { - if allowBodyModification { - if err := rm.SetBody(respToModify.Body); err != nil { - m.LogError(r).Err(err).Msg("failed to set response body") - } - } else { - respToModify.Body.Close() - } - } - } else { + exec, ok := m.impl.(ResponseModifier) + if !ok { next(w, r) + return + } + + lrm := httputils.NewLazyResponseModifier(w, canBufferAndModifyResponseBody) + lrm.SetMaxBufferedBytes(maxModifiableBody) + defer func() { + _, err := lrm.FlushRelease() + if err != nil { + m.LogError(r).Err(err).Msg("failed to flush response") + } + }() + next(lrm, r) + + // Skip modification if response wasn't buffered + if !lrm.IsBuffered() { + return + } + + rm := lrm.ResponseModifier() + currentBody := rm.BodyReader() + currentResp := &http.Response{ + StatusCode: rm.StatusCode(), + Header: rm.Header(), + ContentLength: int64(rm.ContentLength()), + Body: currentBody, + Request: r, + } + respToModify := currentResp + if err := exec.modifyResponse(respToModify); err != nil { + log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response") + return // skip modification if failed + } + + // override the response status code + rm.WriteHeader(respToModify.StatusCode) + + // overriding the response header + maps.Copy(rm.Header(), respToModify.Header) + + // override the body if changed + if respToModify.Body != currentBody { + err := rm.SetBody(respToModify.Body) + if err != nil { + m.LogError(r).Err(err).Msg("failed to set response body") + return // skip modification if failed + } } } -func canModifyResponseBody(resp *http.Response) bool { - if hasNonIdentityEncoding(resp.TransferEncoding) { +// canBufferAndModifyResponseBody checks if the response body can be buffered and modified. +// +// A body can be buffered and modified if: +// - The response is not a websocket and is not an event stream +// - The response has identity transfer encoding +// - The response has identity content encoding +// - The response has a content length +// - The content length is less than 4MB +// - The content type is text-like +func canBufferAndModifyResponseBody(respHeader http.Header) bool { + if httpheaders.IsWebsocket(respHeader) { return false } - if hasNonIdentityEncoding(resp.Header.Values("Transfer-Encoding")) { + contentType := respHeader.Get("Content-Type") + if contentType == "" { // safe default: skip if no content type return false } - if hasNonIdentityEncoding(resp.Header.Values("Content-Encoding")) { + contentType = strings.ToLower(contentType) + if strings.Contains(contentType, mimeEventStream) { return false } - return isTextLikeMediaType(string(httputils.GetContentType(resp.Header))) + // strip charset or any other parameters + contentType, _, err := mime.ParseMediaType(contentType) + if err != nil { // skip if invalid content type + return false + } + if hasNonIdentityEncoding(respHeader.Values("Transfer-Encoding")) { + return false + } + if hasNonIdentityEncoding(respHeader.Values("Content-Encoding")) { + return false + } + if contentLengthRaw := respHeader.Get("Content-Length"); contentLengthRaw != "" { + contentLength, err := strconv.ParseInt(contentLengthRaw, 10, 64) + if err != nil || contentLength >= maxModifiableBody { + return false + } + } + if !isTextLikeMediaType(contentType) { + return false + } + return true } func hasNonIdentityEncoding(values []string) bool { for _, value := range values { - for _, token := range strings.Split(value, ",") { - if strings.TrimSpace(token) == "" || strings.EqualFold(strings.TrimSpace(token), "identity") { + for token := range strings.SplitSeq(value, ",") { + token = strings.TrimSpace(token) + if token == "" || strings.EqualFold(token, "identity") { continue } return true diff --git a/internal/net/gphttp/middleware/middleware_chain.go b/internal/net/gphttp/middleware/middleware_chain.go index e6e4e801..6d192201 100644 --- a/internal/net/gphttp/middleware/middleware_chain.go +++ b/internal/net/gphttp/middleware/middleware_chain.go @@ -1,7 +1,7 @@ package middleware import ( - "maps" + "fmt" "net/http" "strconv" @@ -47,23 +47,53 @@ func (m *middlewareChain) modifyResponse(resp *http.Response) error { if len(m.modResps) == 0 { return nil } - allowBodyModification := canModifyResponseBody(resp) for i, mr := range m.modResps { - respToModify := resp - if !allowBodyModification { - shadow := *resp - shadow.Body = eofReader{} - respToModify = &shadow - } - if err := mr.modifyResponse(respToModify); err != nil { + if err := modifyResponseWithBodyRewriteGate(mr, resp); err != nil { return gperr.PrependSubject(err, strconv.Itoa(i)) } - if !allowBodyModification { - resp.StatusCode = respToModify.StatusCode - if respToModify.Header != nil { - maps.Copy(resp.Header, respToModify.Header) - } - } } return nil } + +func modifyResponseWithBodyRewriteGate(mr ResponseModifier, resp *http.Response) error { + originalBody := resp.Body + originalContentLength := resp.ContentLength + allowBodyRewrite := canBufferAndModifyResponseBody(responseHeaderForBodyRewriteGate(resp)) + + if err := mr.modifyResponse(resp); err != nil { + return err + } + + if allowBodyRewrite || resp.Body == originalBody { + return nil + } + + if resp.Body != nil { + if err := resp.Body.Close(); err != nil { + return fmt.Errorf("close rewritten body: %w", err) + } + } + if originalBody == nil || originalBody == http.NoBody { + resp.Body = http.NoBody + } else { + resp.Body = originalBody + } + resp.ContentLength = originalContentLength + if originalContentLength >= 0 { + resp.Header.Set("Content-Length", strconv.FormatInt(originalContentLength, 10)) + } else { + resp.Header.Del("Content-Length") + } + return nil +} + +func responseHeaderForBodyRewriteGate(resp *http.Response) http.Header { + h := resp.Header.Clone() + if len(resp.TransferEncoding) > 0 && len(h.Values("Transfer-Encoding")) == 0 { + h["Transfer-Encoding"] = append([]string(nil), resp.TransferEncoding...) + } + if resp.ContentLength >= 0 && h.Get("Content-Length") == "" { + h.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) + } + return h +} diff --git a/internal/net/gphttp/middleware/middleware_test.go b/internal/net/gphttp/middleware/middleware_test.go index 1242a63a..0b212561 100644 --- a/internal/net/gphttp/middleware/middleware_test.go +++ b/internal/net/gphttp/middleware/middleware_test.go @@ -3,6 +3,7 @@ package middleware import ( "io" "net/http" + "net/http/httptest" "strconv" "strings" "testing" @@ -127,9 +128,101 @@ func TestMiddlewareResponseRewriteGate(t *testing.T) { respStatus: http.StatusOK, }) expect.NoError(t, err) - expect.Equal(t, result.ResponseStatus, 418) + expect.Equal(t, result.ResponseStatus, http.StatusTeapot) expect.Equal(t, result.ResponseHeaders.Get("X-Rewrite"), "1") expect.Equal(t, string(result.Data), tc.expectBody) }) } } + +func TestMiddlewareResponseRewriteGateServeHTTP(t *testing.T) { + opts := OptionsRaw{ + "status_code": 418, + "header_key": "X-Rewrite", + "header_val": "1", + "body": "rewritten-body", + } + + tests := []struct { + name string + respHeaders http.Header + respBody string + expectStatusCode int + expectHeader string + expectBody string + }{ + { + name: "allow_body_rewrite_for_html", + respHeaders: http.Header{ + "Content-Type": []string{"text/html; charset=utf-8"}, + }, + respBody: "original", + expectStatusCode: http.StatusTeapot, + expectHeader: "1", + expectBody: "rewritten-body", + }, + { + name: "block_body_rewrite_for_binary_content", + respHeaders: http.Header{ + "Content-Type": []string{"application/octet-stream"}, + }, + respBody: "binary", + expectStatusCode: http.StatusOK, + expectHeader: "", + expectBody: "binary", + }, + { + name: "block_body_rewrite_for_transfer_encoded_html", + respHeaders: http.Header{ + "Content-Type": []string{"text/html"}, + "Transfer-Encoding": []string{"chunked"}, + }, + respBody: "original", + expectStatusCode: http.StatusOK, + expectHeader: "", + expectBody: "original", + }, + { + name: "block_body_rewrite_for_content_encoded_html", + respHeaders: http.Header{ + "Content-Type": []string{"text/html"}, + "Content-Encoding": []string{"gzip"}, + }, + respBody: "original", + expectStatusCode: http.StatusOK, + expectHeader: "", + expectBody: "original", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mid, err := responseRewrite.New(opts) + expect.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + rw := httptest.NewRecorder() + + next := func(w http.ResponseWriter, _ *http.Request) { + for key, values := range tc.respHeaders { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(tc.respBody)) + } + + mid.ServeHTTP(next, rw, req) + + resp := rw.Result() + defer resp.Body.Close() + data, readErr := io.ReadAll(resp.Body) + expect.NoError(t, readErr) + + expect.Equal(t, resp.StatusCode, tc.expectStatusCode) + expect.Equal(t, resp.Header.Get("X-Rewrite"), tc.expectHeader) + expect.Equal(t, string(data), tc.expectBody) + }) + } +}