cleanup and simplify middleware implementations, refactor some other code

This commit is contained in:
yusing
2024-12-16 10:19:14 +08:00
parent 8a9cb2527e
commit 59f4eaf3ea
34 changed files with 641 additions and 720 deletions

View File

@@ -3,70 +3,110 @@ package middleware
import (
"encoding/json"
"net/http"
"reflect"
"strings"
"github.com/rs/zerolog"
E "github.com/yusing/go-proxy/internal/error"
gphttp "github.com/yusing/go-proxy/internal/net/http"
U "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils"
)
type (
Error = E.Error
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
Request = http.Request
Response = gphttp.ProxyResponse
ResponseWriter = http.ResponseWriter
Header = http.Header
Cookie = http.Cookie
ReverseProxy = gphttp.ReverseProxy
ProxyRequest = gphttp.ProxyRequest
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(*Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.Error)
OptionsRaw = map[string]any
ImplNewFunc = func() any
OptionsRaw = map[string]any
Middleware struct {
_ U.NoCopy
zerolog.Logger
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
withOptions CloneWithOptFunc
impl any
parent *Middleware
children []*Middleware
trace bool
name string
construct ImplNewFunc
impl any
}
RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool)
}
ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() }
MiddlewareWithTracer *struct{ *Tracer }
)
var Deserialize = U.Deserialize
func Rewrite(r RewriteFunc) BeforeFunc {
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
r(req)
next(w, req)
func NewMiddleware[ImplType any]() *Middleware {
// type check
switch any(new(ImplType)).(type) {
case RequestModifier:
case ResponseModifier:
default:
panic("must implement RequestModifier or ResponseModifier")
}
return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) },
}
}
func (m *Middleware) enableTrace() {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
tracer.Tracer = &Tracer{name: m.name}
}
}
func (m *Middleware) getTracer() *Tracer {
if tracer, ok := m.impl.(MiddlewareWithTracer); ok {
return tracer.Tracer
}
return nil
}
func (m *Middleware) setParent(parent *Middleware) {
if tracer := m.getTracer(); tracer != nil {
tracer.parent = parent.getTracer()
}
}
func (m *Middleware) setup() {
if setup, ok := m.impl.(MiddlewareWithSetup); ok {
setup.setup()
}
}
func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
if len(optsRaw) == 0 {
return nil
}
return utils.Deserialize(optsRaw, m.impl)
}
func (m *Middleware) finalize() {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize()
}
}
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.construct == nil {
if optsRaw != nil {
panic("bug: middleware already constructed")
}
return m, nil
}
mid := &Middleware{name: m.name, impl: m.construct()}
mid.setup()
if err := mid.apply(optsRaw); err != nil {
return nil, err
}
mid.finalize()
return mid, nil
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) Fullname() string {
if m.parent != nil {
return m.parent.Fullname() + "." + m.name
}
return m.name
}
func (m *Middleware) String() string {
return m.name
}
@@ -78,57 +118,38 @@ func (m *Middleware) MarshalJSON() ([]byte, error) {
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Error) {
if m.withOptions != nil {
m, err := m.withOptions(optsRaw)
if err != nil {
return nil, err
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
m.Logger = logger.With().Str("name", m.name).Logger()
return m, nil
}
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{
Logger: logger.With().Str("name", m.name).Logger(),
name: m.name,
before: m.before,
modifyResponse: m.modifyResponse,
impl: m.impl,
parent: m.parent,
children: m.children,
}, nil
next(w, r)
}
func (m *Middleware) ModifyRequest(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.before != nil {
m.before(next, w, r)
}
}
func (m *Middleware) ModifyResponse(resp *Response) error {
if m.modifyResponse != nil {
return m.modifyResponse(resp)
func (m *Middleware) ModifyResponse(resp *http.Response) error {
if exec, ok := m.impl.(ResponseModifier); ok {
return exec.modifyResponse(resp)
}
return nil
}
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w ResponseWriter, r *Request) {
if m.modifyResponse != nil {
func (m *Middleware) ServeHTTP(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
if exec, ok := m.impl.(ResponseModifier); ok {
w = gphttp.NewModifyResponseWriter(w, r, func(resp *http.Response) error {
return m.modifyResponse(&Response{Response: resp, OriginalRequest: r})
return exec.modifyResponse(resp)
})
}
if m.before != nil {
m.before(next, w, r)
} else {
next(w, r)
if exec, ok := m.impl.(RequestModifier); ok {
if proceed := exec.before(w, r); !proceed {
return
}
}
next(w, r)
}
// TODO: check conflict or duplicates.
func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
func compileMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.Error) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
errs := E.NewBuilder("middlewares compile error")
@@ -141,7 +162,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
continue
}
m, err = m.WithOptionsClone(opts)
m, err = m.New(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue
@@ -157,7 +178,7 @@ func createMiddlewares(middlewaresMap map[string]OptionsRaw) ([]*Middleware, E.E
func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (err E.Error) {
var middlewares []*Middleware
middlewares, err = createMiddlewares(middlewaresMap)
middlewares, err = compileMiddlewares(middlewaresMap)
if err != nil {
return
}
@@ -166,34 +187,30 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
}
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rp.TargetName, append([]*Middleware{{
name: "set_upstream_headers",
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
r.Header.Set(gphttp.HeaderUpstreamScheme, rp.TargetURL.Scheme)
r.Header.Set(gphttp.HeaderUpstreamHost, rp.TargetURL.Hostname())
r.Header.Set(gphttp.HeaderUpstreamPort, rp.TargetURL.Port())
next(w, r)
},
}}, middlewares...))
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
if mid.before != nil {
ori := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *Request) {
mid.before(ori, w, r)
mid := NewMiddlewareChain(rp.TargetName, middlewares)
if before, ok := mid.impl.(RequestModifier); ok {
next := rp.HandlerFunc
rp.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
if proceed := before.before(w, r); proceed {
next(w, r)
}
}
}
if mid.modifyResponse != nil {
if mr, ok := mid.impl.(ResponseModifier); ok {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *Response) error {
if err := mid.modifyResponse(res); err != nil {
rp.ModifyResponse = func(res *http.Response) error {
if err := mr.modifyResponse(res); err != nil {
return err
}
return ori(res)
}
} else {
rp.ModifyResponse = mid.modifyResponse
rp.ModifyResponse = mr.modifyResponse
}
}
}