fixed middleware implementation, added middleware tracing for easier debug

This commit is contained in:
yusing
2024-10-02 13:55:41 +08:00
parent d172552fb0
commit ba13b81b0e
31 changed files with 561 additions and 196 deletions

View File

@@ -1,9 +1,11 @@
package middleware
import (
"fmt"
"net/http"
"os"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
@@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("toml unmarshal", err))
b.Add(E.FailWith("yaml unmarshal", err))
return
}
middlewares = make(map[string]*Middleware)
@@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" {
chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i))
if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
continue
}
baseName := def["use"].(string)
base, ok := Get(baseName)
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i))
continue
base, ok = middlewares[baseName]
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m.name = fmt.Sprintf("%s[%d]", name, i)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
continue
@@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
if chainErr.HasError() {
b.Add(chainErr.Build())
} else {
name = name + "@file"
middlewares[name] = BuildMiddlewareFromChain(name, chain)
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
}
}
return
@@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
// TODO: check conflict or duplicates
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
var (
befores []BeforeFunc
rewrites []RewriteFunc
modResps []ModifyResponseFunc
)
for _, m := range chain {
if m.before != nil {
befores = append(befores, m.before)
m := &Middleware{name: name, children: chain}
var befores []*Middleware
var modResps []*Middleware
for _, comp := range chain {
if comp.before != nil {
befores = append(befores, comp)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
if comp.modifyResponse != nil {
modResps = append(modResps, comp)
}
comp.parent = m
}
m := &Middleware{name: name}
if len(befores) > 0 {
m.before = func(next http.Handler, w ResponseWriter, r *Request) {
for _, before := range befores {
before(next, w, r)
}
}
}
if len(rewrites) > 0 {
m.rewrite = func(r *Request) {
for _, rewrite := range rewrites {
rewrite(r)
}
}
m.before = buildBefores(befores)
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware %s", name)
b := E.NewBuilder("errors in middleware")
for _, mr := range modResps {
b.AddE(mr(res))
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
}
return b.Build().Error()
}
}
if common.IsDebug {
m.EnableTrace()
m.AddTracef("middleware created")
}
return m
}
func buildBefores(befores []*Middleware) BeforeFunc {
if len(befores) == 1 {
return befores[0].before
}
nextBefores := buildBefores(befores[1:])
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
befores[0].before(func(w ResponseWriter, r *Request) {
nextBefores(next, w, r)
}, w, r)
}
}