mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-21 00:29:03 +01:00
Replace the rewrite requirement check with a BodyResponseModifier marker and treat header and body modifiers separately. This ensures header/status rewrites still apply when body rewriting is blocked (for binary, encoded, or chunked responses), while body changes are skipped safely. It also avoids body reset/close side effects and returns early on passthrough responses. Update middleware tests to cover split header/body behavior and themed middleware body-skip scenarios.
406 lines
10 KiB
Go
406 lines
10 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"maps"
|
|
"mime"
|
|
"net/http"
|
|
"reflect"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/yusing/godoxy/internal/serialization"
|
|
httputils "github.com/yusing/goutils/http"
|
|
"github.com/yusing/goutils/http/httpheaders"
|
|
"github.com/yusing/goutils/http/reverseproxy"
|
|
)
|
|
|
|
const (
	// mimeEventStream is the media type used to recognise server-sent event
	// streams in Accept and Content-Type headers.
	mimeEventStream = "text/event-stream"
	// headerContentType is the canonical name of the Content-Type header.
	headerContentType = "Content-Type"
	// maxModifiableBody is the upper bound, in bytes, on a response body that
	// is buffered for modification (4MB).
	maxModifiableBody = 4 << 20
)
|
|
|
|
type (
|
|
ReverseProxy = reverseproxy.ReverseProxy
|
|
ProxyRequest = reverseproxy.ProxyRequest
|
|
|
|
ImplNewFunc = func() any
|
|
OptionsRaw = map[string]any
|
|
|
|
commonOptions = struct {
|
|
// priority is only applied for ReverseProxy.
|
|
//
|
|
// Middleware compose follows the order of the slice
|
|
//
|
|
// Default is 10, 0 is the highest
|
|
Priority int `json:"priority"`
|
|
Bypass Bypass `json:"bypass"`
|
|
}
|
|
|
|
Middleware struct {
|
|
commonOptions
|
|
|
|
name string
|
|
construct ImplNewFunc
|
|
impl any
|
|
}
|
|
ByPriority []*Middleware
|
|
|
|
RequestModifier interface {
|
|
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
|
}
|
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
|
BodyResponseModifier interface {
|
|
ResponseModifier
|
|
isBodyResponseModifier()
|
|
}
|
|
MiddlewareWithSetup interface{ setup() }
|
|
MiddlewareFinalizer interface{ finalize() }
|
|
MiddlewareFinalizerWithError interface {
|
|
finalize() error
|
|
}
|
|
)
|
|
|
|
// DefaultPriority is the priority value documented as the default for
// middlewares; lower values sort first.
//
// NOTE(review): not referenced in this part of the file — confirm where it is
// applied to newly created middlewares.
const DefaultPriority = 10
|
|
|
|
func (m ByPriority) Len() int { return len(m) }
|
|
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
|
func (m ByPriority) Less(i, j int) bool { return m[i].Priority < m[j].Priority }
|
|
|
|
func NewMiddleware[ImplType any]() *Middleware {
|
|
// type check
|
|
t := any(new(ImplType))
|
|
switch t.(type) {
|
|
case RequestModifier:
|
|
case ResponseModifier:
|
|
default:
|
|
panic("must implement RequestModifier or ResponseModifier")
|
|
}
|
|
_, hasFinializer := t.(MiddlewareFinalizer)
|
|
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
|
|
if hasFinializer && hasFinializerWithError {
|
|
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
|
|
}
|
|
return &Middleware{
|
|
name: reflect.TypeFor[ImplType]().Name(),
|
|
construct: func() any { return new(ImplType) },
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) setup() {
|
|
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
|
|
setup.setup()
|
|
}
|
|
}
|
|
|
|
func (m *Middleware) apply(optsRaw OptionsRaw) error {
|
|
if len(optsRaw) == 0 {
|
|
return nil
|
|
}
|
|
commonOpts := map[string]any{}
|
|
if priority, ok := optsRaw["priority"]; ok {
|
|
commonOpts["priority"] = priority
|
|
}
|
|
if bypass, ok := optsRaw["bypass"]; ok {
|
|
commonOpts["bypass"] = bypass
|
|
}
|
|
if len(commonOpts) > 0 {
|
|
if err := serialization.MapUnmarshalValidate(commonOpts, &m.commonOptions); err != nil {
|
|
return err
|
|
}
|
|
optsRaw = maps.Clone(optsRaw)
|
|
for k := range commonOpts {
|
|
delete(optsRaw, k)
|
|
}
|
|
}
|
|
return serialization.MapUnmarshalValidate(optsRaw, m.impl)
|
|
}
|
|
|
|
func (m *Middleware) finalize() error {
|
|
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
|
finalizer.finalize()
|
|
}
|
|
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
|
|
return finalizer.finalize()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, error) {
|
|
if m.construct == nil { // likely a middleware from compose
|
|
if len(optsRaw) != 0 {
|
|
return nil, fmt.Errorf("additional options not allowed for middleware %s", m.name)
|
|
}
|
|
return m, nil
|
|
}
|
|
mid := &Middleware{name: m.name, impl: m.construct()}
|
|
mid.setup()
|
|
if err := mid.apply(optsRaw); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := mid.finalize(); err != nil {
|
|
return nil, err
|
|
}
|
|
mid.impl = mid.withCheckBypass()
|
|
return mid, nil
|
|
}
|
|
|
|
func (m *Middleware) Name() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *Middleware) String() string {
|
|
return m.name
|
|
}
|
|
|
|
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
|
type allOptions struct {
|
|
commonOptions
|
|
any
|
|
}
|
|
return sonic.MarshalIndent(map[string]any{
|
|
"name": m.name,
|
|
"options": allOptions{
|
|
commonOptions: m.commonOptions,
|
|
any: m.impl,
|
|
},
|
|
}, "", " ")
|
|
}
|
|
|
|
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
if proceed := exec.before(w, r); !proceed {
|
|
return
|
|
}
|
|
}
|
|
next(w, r)
|
|
}
|
|
|
|
func (m *Middleware) TryModifyRequest(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
return exec.before(w, r)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (m *Middleware) ModifyResponse(resp *http.Response) error {
|
|
if exec, ok := m.impl.(ResponseModifier); ok {
|
|
return exec.modifyResponse(resp)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
|
if exec, ok := m.impl.(RequestModifier); ok {
|
|
if proceed := exec.before(w, r); !proceed {
|
|
return
|
|
}
|
|
}
|
|
|
|
if httpheaders.IsWebsocket(r.Header) || strings.Contains(strings.ToLower(r.Header.Get("Accept")), mimeEventStream) {
|
|
next(w, r)
|
|
return
|
|
}
|
|
|
|
exec, ok := m.impl.(ResponseModifier)
|
|
if !ok {
|
|
next(w, r)
|
|
return
|
|
}
|
|
isBodyModifier := isBodyResponseModifier(exec)
|
|
|
|
shouldBuffer := canBufferAndModifyResponseBody
|
|
if !isBodyModifier {
|
|
// Header-only response modifiers do not need body rewrite capability checks.
|
|
// We still respect max buffer limits and may fall back to passthrough for large bodies.
|
|
shouldBuffer = func(http.Header) bool { return true }
|
|
}
|
|
lrm := httputils.NewLazyResponseModifier(w, shouldBuffer)
|
|
lrm.SetMaxBufferedBytes(maxModifiableBody)
|
|
defer func() {
|
|
_, err := lrm.FlushRelease()
|
|
if err != nil {
|
|
m.LogError(r).Err(err).Msg("failed to flush response")
|
|
}
|
|
}()
|
|
next(lrm, r)
|
|
|
|
// Skip modification if response wasn't buffered
|
|
if !lrm.IsBuffered() {
|
|
return
|
|
}
|
|
|
|
rm := lrm.ResponseModifier()
|
|
if rm.IsPassthrough() {
|
|
return
|
|
}
|
|
currentBody := rm.BodyReader()
|
|
currentResp := &http.Response{
|
|
StatusCode: rm.StatusCode(),
|
|
Header: rm.Header(),
|
|
ContentLength: int64(rm.ContentLength()),
|
|
Body: currentBody,
|
|
Request: r,
|
|
}
|
|
respToModify := currentResp
|
|
if err := exec.modifyResponse(respToModify); err != nil {
|
|
log.Err(err).Str("middleware", m.Name()).Str("url", fullURL(r)).Msg("failed to modify response")
|
|
return // skip modification if failed
|
|
}
|
|
|
|
// override the response status code
|
|
rm.WriteHeader(respToModify.StatusCode)
|
|
|
|
// overriding the response header
|
|
maps.Copy(rm.Header(), respToModify.Header)
|
|
|
|
// override the body if changed
|
|
if isBodyModifier && respToModify.Body != currentBody {
|
|
err := rm.SetBody(respToModify.Body)
|
|
if err != nil {
|
|
m.LogError(r).Err(err).Msg("failed to set response body")
|
|
return // skip modification if failed
|
|
}
|
|
}
|
|
}
|
|
|
|
// canBufferAndModifyResponseBody checks if the response body can be buffered and modified.
|
|
//
|
|
// A body can be buffered and modified if:
|
|
// - The response is not a websocket and is not an event stream
|
|
// - The response has identity transfer encoding
|
|
// - The response has identity content encoding
|
|
// - The response has a content length
|
|
// - The content length is less than 4MB
|
|
// - The content type is text-like
|
|
func canBufferAndModifyResponseBody(respHeader http.Header) bool {
|
|
if httpheaders.IsWebsocket(respHeader) {
|
|
return false
|
|
}
|
|
contentType := respHeader.Get("Content-Type")
|
|
if contentType == "" { // safe default: skip if no content type
|
|
return false
|
|
}
|
|
contentType = strings.ToLower(contentType)
|
|
if strings.Contains(contentType, mimeEventStream) {
|
|
return false
|
|
}
|
|
// strip charset or any other parameters
|
|
contentType, _, err := mime.ParseMediaType(contentType)
|
|
if err != nil { // skip if invalid content type
|
|
return false
|
|
}
|
|
if hasNonIdentityEncoding(respHeader.Values("Transfer-Encoding")) {
|
|
return false
|
|
}
|
|
if hasNonIdentityEncoding(respHeader.Values("Content-Encoding")) {
|
|
return false
|
|
}
|
|
if contentLengthRaw := respHeader.Get("Content-Length"); contentLengthRaw != "" {
|
|
contentLength, err := strconv.ParseInt(contentLengthRaw, 10, 64)
|
|
if err != nil || contentLength >= maxModifiableBody {
|
|
return false
|
|
}
|
|
}
|
|
if !isTextLikeMediaType(contentType) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func hasNonIdentityEncoding(values []string) bool {
|
|
for _, value := range values {
|
|
for token := range strings.SplitSeq(value, ",") {
|
|
token = strings.TrimSpace(token)
|
|
if token == "" || strings.EqualFold(token, "identity") {
|
|
continue
|
|
}
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isTextLikeMediaType reports whether contentType names a textual format.
//
// The match is case-insensitive and accepts: anything under "text/", JSON and
// XML (exact or via a "+json"/"+xml" suffix), form-urlencoded data, and any
// type whose name contains yaml, toml, javascript, ecmascript or csv. An
// empty string is never text-like.
func isTextLikeMediaType(contentType string) bool {
	if contentType == "" {
		return false
	}
	mediaType := strings.ToLower(contentType)

	switch mediaType {
	case "application/json", "application/xml", "application/x-www-form-urlencoded":
		return true
	}
	if strings.HasPrefix(mediaType, "text/") {
		return true
	}
	for _, suffix := range [...]string{"+json", "+xml"} {
		if strings.HasSuffix(mediaType, suffix) {
			return true
		}
	}
	for _, marker := range [...]string{"yaml", "toml", "javascript", "ecmascript", "csv"} {
		if strings.Contains(mediaType, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
func (m *Middleware) LogWarn(req *http.Request) *zerolog.Event {
|
|
//nolint:zerologlint
|
|
return log.Warn().Str("middleware", m.name).
|
|
Str("host", req.Host).
|
|
Str("path", req.URL.Path)
|
|
}
|
|
|
|
func (m *Middleware) LogError(req *http.Request) *zerolog.Event {
|
|
//nolint:zerologlint
|
|
return log.Error().Str("middleware", m.name).
|
|
Str("host", req.Host).
|
|
Str("path", req.URL.Path)
|
|
}
|
|
|
|
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) error {
|
|
middlewares, err := compileMiddlewares(middlewaresMap)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
patchReverseProxy(rp, middlewares)
|
|
return nil
|
|
}
|
|
|
|
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
|
sort.Sort(ByPriority(middlewares))
|
|
|
|
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
|
|
|
if before, ok := mid.impl.(RequestModifier); ok {
|
|
next := rp.HandlerFunc
|
|
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
|
if proceed := before.before(w, r); proceed {
|
|
next(w, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
if mr, ok := mid.impl.(ResponseModifier); ok {
|
|
if rp.ModifyResponse != nil {
|
|
ori := rp.ModifyResponse
|
|
rp.ModifyResponse = func(res *http.Response) error {
|
|
if err := mr.modifyResponse(res); err != nil {
|
|
return err
|
|
}
|
|
return ori(res)
|
|
}
|
|
} else {
|
|
rp.ModifyResponse = mr.modifyResponse
|
|
}
|
|
}
|
|
}
|