mirror of
https://github.com/yusing/godoxy.git
synced 2026-03-29 21:31:48 +02:00
fixed middleware implementation, added middleware tracing for easier debug
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type cidrWhitelist struct {
|
||||
@@ -19,7 +20,7 @@ type cidrWhitelistOpts struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
|
||||
trustedAddr map[string]struct{} // cache for trusted IPs
|
||||
cachedAddr F.Map[string, bool] // cache for trusted IPs
|
||||
}
|
||||
|
||||
var CIDRWhiteList = &cidrWhitelist{
|
||||
@@ -28,15 +29,16 @@ var CIDRWhiteList = &cidrWhitelist{
|
||||
"allow": D.YamlStringListParser,
|
||||
"statusCode": D.IntParser,
|
||||
},
|
||||
withOptions: NewCIDRWhitelist,
|
||||
},
|
||||
}
|
||||
|
||||
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
|
||||
return &cidrWhitelistOpts{
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
trustedAddr: make(map[string]struct{}),
|
||||
Allow: []*types.CIDR{},
|
||||
StatusCode: http.StatusForbidden,
|
||||
Message: "IP not allowed",
|
||||
cachedAddr: F.NewMapOf[string, bool](),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,23 +59,32 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
return wl.m, nil
|
||||
}
|
||||
|
||||
func (wl *cidrWhitelist) checkIP(next http.Handler, w ResponseWriter, r *Request) {
|
||||
var ok bool
|
||||
if _, ok = wl.trustedAddr[r.RemoteAddr]; !ok {
|
||||
ip := net.IP(r.RemoteAddr)
|
||||
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
var allow, ok bool
|
||||
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
|
||||
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ipStr = r.RemoteAddr
|
||||
}
|
||||
ip := net.ParseIP(ipStr)
|
||||
for _, cidr := range wl.cidrWhitelistOpts.Allow {
|
||||
if cidr.Contains(ip) {
|
||||
wl.trustedAddr[r.RemoteAddr] = struct{}{}
|
||||
ok = true
|
||||
wl.cachedAddr.Store(r.RemoteAddr, true)
|
||||
allow = true
|
||||
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allow {
|
||||
wl.cachedAddr.Store(r.RemoteAddr, false)
|
||||
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
if !allow {
|
||||
w.WriteHeader(wl.StatusCode)
|
||||
w.Write([]byte(wl.Message))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
next(w, r)
|
||||
}
|
||||
|
||||
42
internal/net/http/middleware/cidr_whitelist_test.go
Normal file
42
internal/net/http/middleware/cidr_whitelist_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
//go:embed test_data/cidr_whitelist_test.yml
|
||||
var testCIDRWhitelistCompose []byte
|
||||
var deny, accept *Middleware
|
||||
|
||||
func TestCIDRWhitelist(t *testing.T) {
|
||||
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
deny = mids["deny@file"]
|
||||
accept = mids["accept@file"]
|
||||
if deny == nil || accept == nil {
|
||||
panic("bug occurred")
|
||||
}
|
||||
|
||||
t.Run("deny", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(deny, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
|
||||
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("accept", func(t *testing.T) {
|
||||
for range 10 {
|
||||
result, err := newMiddlewareTest(accept, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -39,12 +39,13 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
|
||||
cri := new(realIP)
|
||||
cri.m = &Middleware{
|
||||
impl: cri,
|
||||
rewrite: func(r *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
cidrs := tryFetchCFCIDR()
|
||||
if cidrs != nil {
|
||||
cri.From = cidrs
|
||||
}
|
||||
cri.setRealIP(r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
cri.realIPOpts = &realIPOpts{
|
||||
|
||||
@@ -15,9 +15,9 @@ import (
|
||||
)
|
||||
|
||||
var CustomErrorPage = &Middleware{
|
||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if !ServeStaticErrorPageFile(w, r) {
|
||||
next.ServeHTTP(w, r)
|
||||
next(w, r)
|
||||
}
|
||||
},
|
||||
modifyResponse: func(resp *Response) error {
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
||||
@@ -45,7 +44,6 @@ var ForwardAuth = func() *forwardAuth {
|
||||
fa.m.withOptions = NewForwardAuthfunc
|
||||
return fa
|
||||
}()
|
||||
var faLogger = logrus.WithField("middleware", "ForwardAuth")
|
||||
|
||||
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
faWithOpts := new(forwardAuth)
|
||||
@@ -80,7 +78,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
return faWithOpts.m, nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
|
||||
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
gpHTTP.RemoveHop(req.Header)
|
||||
|
||||
faReq, err := http.NewRequestWithContext(
|
||||
@@ -90,7 +88,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -103,7 +101,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||
|
||||
faResp, err := fa.client.Do(faReq)
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -111,7 +109,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||
|
||||
body, err := io.ReadAll(faResp.Body)
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -122,7 +120,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||
|
||||
redirectURL, err := faResp.Location()
|
||||
if err != nil {
|
||||
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
} else if redirectURL.String() != "" {
|
||||
@@ -132,7 +130,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
|
||||
w.WriteHeader(faResp.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
|
||||
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
@@ -21,7 +22,7 @@ type (
|
||||
Header = http.Header
|
||||
Cookie = http.Cookie
|
||||
|
||||
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
|
||||
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
|
||||
RewriteFunc func(req *Request)
|
||||
ModifyResponseFunc func(resp *Response) error
|
||||
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
|
||||
@@ -33,23 +34,38 @@ type (
|
||||
name string
|
||||
|
||||
before BeforeFunc // runs before ReverseProxy.ServeHTTP
|
||||
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
|
||||
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
|
||||
|
||||
transport http.RoundTripper
|
||||
|
||||
withOptions CloneWithOptFunc
|
||||
labelParserMap D.ValueParserMap
|
||||
impl any
|
||||
|
||||
parent *Middleware
|
||||
children []*Middleware
|
||||
trace bool
|
||||
}
|
||||
)
|
||||
|
||||
var Deserialize = U.Deserialize
|
||||
|
||||
func Rewrite(r RewriteFunc) BeforeFunc {
|
||||
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
r(req)
|
||||
next(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) Fullname() string {
|
||||
if m.parent != nil {
|
||||
return m.parent.Fullname() + "." + m.name
|
||||
}
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *Middleware) String() string {
|
||||
return m.name
|
||||
}
|
||||
@@ -72,14 +88,21 @@ func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Nested
|
||||
|
||||
// WithOptionsClone is called only once
|
||||
// set withOptions and labelParser will not be used after that
|
||||
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
|
||||
return &Middleware{
|
||||
m.name,
|
||||
m.before,
|
||||
m.modifyResponse,
|
||||
nil, nil,
|
||||
m.impl,
|
||||
m.parent,
|
||||
m.children,
|
||||
false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
|
||||
befores := make([]BeforeFunc, 0, len(middlewares))
|
||||
rewrites := make([]RewriteFunc, 0, len(middlewares))
|
||||
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
|
||||
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
|
||||
middlewares := make([]*Middleware, 0, len(middlewaresMap))
|
||||
|
||||
invalidM := E.NewBuilder("invalid middlewares")
|
||||
invalidOpts := E.NewBuilder("invalid options")
|
||||
@@ -88,7 +111,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
||||
invalidM.To(&res)
|
||||
}()
|
||||
|
||||
for name, opts := range middlewares {
|
||||
for name, opts := range middlewaresMap {
|
||||
m, ok := Get(name)
|
||||
if !ok {
|
||||
invalidM.Add(E.NotExist("middleware", name))
|
||||
@@ -100,56 +123,35 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
|
||||
invalidOpts.Add(err.Subject(name))
|
||||
continue
|
||||
}
|
||||
if m.before != nil {
|
||||
befores = append(befores, m.before)
|
||||
}
|
||||
if m.rewrite != nil {
|
||||
rewrites = append(rewrites, m.rewrite)
|
||||
}
|
||||
if m.modifyResponse != nil {
|
||||
modResps = append(modResps, m.modifyResponse)
|
||||
}
|
||||
middlewares = append(middlewares, m)
|
||||
}
|
||||
|
||||
if invalidM.HasError() {
|
||||
return
|
||||
}
|
||||
|
||||
origServeHTTP := rp.ServeHTTP
|
||||
for i, before := range befores {
|
||||
if i < len(befores)-1 {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
before(rp.ServeHTTP, w, r)
|
||||
}
|
||||
} else {
|
||||
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
|
||||
before(origServeHTTP, w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(rewrites) > 0 {
|
||||
origServeHTTP = rp.ServeHTTP
|
||||
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, rewrite := range rewrites {
|
||||
rewrite(r)
|
||||
}
|
||||
origServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if len(modResps) > 0 {
|
||||
if rp.ModifyResponse != nil {
|
||||
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
|
||||
}
|
||||
rp.ModifyResponse = func(res *Response) error {
|
||||
b := E.NewBuilder("errors in middleware ModifyResponse")
|
||||
for _, mr := range modResps {
|
||||
b.AddE(mr(res))
|
||||
}
|
||||
return b.Build().Error()
|
||||
}
|
||||
}
|
||||
|
||||
patchReverseProxy(rpName, rp, middlewares)
|
||||
return
|
||||
}
|
||||
|
||||
func patchReverseProxy(rpName string, rp *ReverseProxy, middlewares []*Middleware) {
|
||||
mid := BuildMiddlewareFromChain(rpName, middlewares)
|
||||
|
||||
if mid.before != nil {
|
||||
ori := rp.ServeHTTP
|
||||
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
|
||||
mid.before(ori, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if mid.modifyResponse != nil {
|
||||
if rp.ModifyResponse != nil {
|
||||
ori := rp.ModifyResponse
|
||||
rp.ModifyResponse = func(res *http.Response) error {
|
||||
return errors.Join(mid.modifyResponse(res), ori(res))
|
||||
}
|
||||
} else {
|
||||
rp.ModifyResponse = mid.modifyResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||
var rawMap map[string][]map[string]any
|
||||
err := yaml.Unmarshal(data, &rawMap)
|
||||
if err != nil {
|
||||
b.Add(E.FailWith("toml unmarshal", err))
|
||||
b.Add(E.FailWith("yaml unmarshal", err))
|
||||
return
|
||||
}
|
||||
middlewares = make(map[string]*Middleware)
|
||||
@@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||
chainErr := E.NewBuilder(name)
|
||||
chain := make([]*Middleware, 0, len(defs))
|
||||
for i, def := range defs {
|
||||
if def["use"] == nil || def["use"].(string) == "" {
|
||||
chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i))
|
||||
if def["use"] == nil || def["use"] == "" {
|
||||
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
|
||||
continue
|
||||
}
|
||||
baseName := def["use"].(string)
|
||||
base, ok := Get(baseName)
|
||||
if !ok {
|
||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i))
|
||||
continue
|
||||
base, ok = middlewares[baseName]
|
||||
if !ok {
|
||||
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
|
||||
continue
|
||||
}
|
||||
}
|
||||
delete(def, "use")
|
||||
m, err := base.WithOptionsClone(def)
|
||||
m.name = fmt.Sprintf("%s[%d]", name, i)
|
||||
if err != nil {
|
||||
chainErr.Add(err.Subjectf("item%d", i))
|
||||
continue
|
||||
@@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||
if chainErr.HasError() {
|
||||
b.Add(chainErr.Build())
|
||||
} else {
|
||||
name = name + "@file"
|
||||
middlewares[name] = BuildMiddlewareFromChain(name, chain)
|
||||
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
|
||||
|
||||
// TODO: check conflict or duplicates
|
||||
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
|
||||
var (
|
||||
befores []BeforeFunc
|
||||
rewrites []RewriteFunc
|
||||
modResps []ModifyResponseFunc
|
||||
)
|
||||
for _, m := range chain {
|
||||
if m.before != nil {
|
||||
befores = append(befores, m.before)
|
||||
m := &Middleware{name: name, children: chain}
|
||||
|
||||
var befores []*Middleware
|
||||
var modResps []*Middleware
|
||||
|
||||
for _, comp := range chain {
|
||||
if comp.before != nil {
|
||||
befores = append(befores, comp)
|
||||
}
|
||||
if m.rewrite != nil {
|
||||
rewrites = append(rewrites, m.rewrite)
|
||||
}
|
||||
if m.modifyResponse != nil {
|
||||
modResps = append(modResps, m.modifyResponse)
|
||||
if comp.modifyResponse != nil {
|
||||
modResps = append(modResps, comp)
|
||||
}
|
||||
comp.parent = m
|
||||
}
|
||||
|
||||
m := &Middleware{name: name}
|
||||
if len(befores) > 0 {
|
||||
m.before = func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
for _, before := range befores {
|
||||
before(next, w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(rewrites) > 0 {
|
||||
m.rewrite = func(r *Request) {
|
||||
for _, rewrite := range rewrites {
|
||||
rewrite(r)
|
||||
}
|
||||
}
|
||||
m.before = buildBefores(befores)
|
||||
}
|
||||
if len(modResps) > 0 {
|
||||
m.modifyResponse = func(res *Response) error {
|
||||
b := E.NewBuilder("errors in middleware %s", name)
|
||||
b := E.NewBuilder("errors in middleware")
|
||||
for _, mr := range modResps {
|
||||
b.AddE(mr(res))
|
||||
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
|
||||
}
|
||||
return b.Build().Error()
|
||||
}
|
||||
}
|
||||
|
||||
if common.IsDebug {
|
||||
m.EnableTrace()
|
||||
m.AddTracef("middleware created")
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func buildBefores(befores []*Middleware) BeforeFunc {
|
||||
if len(befores) == 1 {
|
||||
return befores[0].before
|
||||
}
|
||||
nextBefores := buildBefores(befores[1:])
|
||||
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
befores[0].before(func(w ResponseWriter, r *Request) {
|
||||
nextBefores(next, w, r)
|
||||
}, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,10 +67,10 @@ func LoadComposeFiles() {
|
||||
b.Add(E.Duplicated("middleware", name))
|
||||
continue
|
||||
}
|
||||
middlewares[name] = m
|
||||
middlewares[U.ToLowerNoSnake(name)] = m
|
||||
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
|
||||
}
|
||||
b.Add(err.Subject(defFile))
|
||||
b.Add(err.Subject(path.Base(defFile)))
|
||||
}
|
||||
if b.HasError() {
|
||||
logger.Error(b.Build())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
@@ -32,9 +33,15 @@ var ModifyRequest = func() *modifyRequest {
|
||||
|
||||
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
mr := new(modifyRequest)
|
||||
var mrFunc RewriteFunc
|
||||
if common.IsDebug {
|
||||
mrFunc = mr.modifyRequestWithTrace
|
||||
} else {
|
||||
mrFunc = mr.modifyRequest
|
||||
}
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
rewrite: mr.modifyRequest,
|
||||
impl: mr,
|
||||
before: Rewrite(mrFunc),
|
||||
}
|
||||
mr.modifyRequestOpts = new(modifyRequestOpts)
|
||||
err := Deserialize(optsRaw, mr.modifyRequestOpts)
|
||||
@@ -55,3 +62,9 @@ func (mr *modifyRequest) modifyRequest(req *Request) {
|
||||
req.Header.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
|
||||
mr.m.AddTraceRequest("before modify request", req)
|
||||
mr.modifyRequest(req)
|
||||
mr.m.AddTraceRequest("after modify request", req)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
@@ -34,9 +35,11 @@ var ModifyResponse = func() (mr *modifyResponse) {
|
||||
|
||||
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
|
||||
mr := new(modifyResponse)
|
||||
mr.m = &Middleware{
|
||||
impl: mr,
|
||||
modifyResponse: mr.modifyResponse,
|
||||
mr.m = &Middleware{impl: mr}
|
||||
if common.IsDebug {
|
||||
mr.m.modifyResponse = mr.modifyResponseWithTrace
|
||||
} else {
|
||||
mr.m.modifyResponse = mr.modifyResponse
|
||||
}
|
||||
mr.modifyResponseOpts = new(modifyResponseOpts)
|
||||
err := Deserialize(optsRaw, mr.modifyResponseOpts)
|
||||
@@ -58,3 +61,10 @@ func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
|
||||
mr.m.AddTraceResponse("before modify response", resp)
|
||||
err := mr.modifyResponse(resp)
|
||||
mr.m.AddTraceResponse("after modify response", resp)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
D "github.com/yusing/go-proxy/internal/docker"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
@@ -49,13 +49,14 @@ var realIPOptsDefault = func() *realIPOpts {
|
||||
}
|
||||
}
|
||||
|
||||
var realIPLogger = logrus.WithField("middleware", "RealIP")
|
||||
|
||||
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
riWithOpts := new(realIP)
|
||||
riWithOpts.m = &Middleware{
|
||||
impl: riWithOpts,
|
||||
rewrite: riWithOpts.setRealIP,
|
||||
impl: riWithOpts,
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
riWithOpts.setRealIP(r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
riWithOpts.realIPOpts = realIPOptsDefault()
|
||||
err := Deserialize(opts, riWithOpts.realIPOpts)
|
||||
@@ -78,7 +79,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
|
||||
func (ri *realIP) setRealIP(req *Request) {
|
||||
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
realIPLogger.Debugf("failed to split host port %s", err)
|
||||
clientIPStr = req.RemoteAddr
|
||||
}
|
||||
clientIP := net.ParseIP(clientIPStr)
|
||||
|
||||
@@ -90,7 +91,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||
}
|
||||
}
|
||||
if !isTrusted {
|
||||
realIPLogger.Debugf("client ip %s is not trusted", clientIP)
|
||||
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -98,7 +99,7 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||
var lastNonTrustedIP string
|
||||
|
||||
if len(realIPs) == 0 {
|
||||
realIPLogger.Debugf("no real ip found in header %q", ri.Header)
|
||||
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -110,14 +111,16 @@ func (ri *realIP) setRealIP(req *Request) {
|
||||
lastNonTrustedIP = r
|
||||
}
|
||||
}
|
||||
if lastNonTrustedIP == "" {
|
||||
realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if lastNonTrustedIP == "" {
|
||||
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
|
||||
return
|
||||
}
|
||||
|
||||
req.RemoteAddr = lastNonTrustedIP
|
||||
req.Header.Set(ri.Header, lastNonTrustedIP)
|
||||
req.Header.Set("X-Real-IP", lastNonTrustedIP)
|
||||
req.Header.Set("X-Forwarded-For", lastNonTrustedIP)
|
||||
req.Header.Set(xForwardedFor, lastNonTrustedIP)
|
||||
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
|
||||
}
|
||||
|
||||
@@ -2,13 +2,15 @@ package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
func TestSetRealIPOpts(t *testing.T) {
|
||||
opts := OptionsRaw{
|
||||
"header": "X-Real-IP",
|
||||
"from": []string{
|
||||
@@ -37,13 +39,39 @@ func TestSetRealIP(t *testing.T) {
|
||||
Recursive: true,
|
||||
}
|
||||
|
||||
t.Run("set_options", func(t *testing.T) {
|
||||
ri, err := RealIP.m.WithOptionsClone(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
|
||||
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
|
||||
})
|
||||
// TODO test
|
||||
ri, err := NewRealIP(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
|
||||
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
|
||||
for i, CIDR := range ri.impl.(*realIP).From {
|
||||
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRealIP(t *testing.T) {
|
||||
const (
|
||||
testHeader = "X-Real-IP"
|
||||
testRealIP = "192.168.1.1"
|
||||
)
|
||||
opts := OptionsRaw{
|
||||
"header": testHeader,
|
||||
"from": []string{"0.0.0.0/0"},
|
||||
}
|
||||
optsMr := OptionsRaw{
|
||||
"set_headers": map[string]string{testHeader: testRealIP},
|
||||
}
|
||||
realip, err := NewRealIP(opts)
|
||||
ExpectNoError(t, err.Error())
|
||||
|
||||
mr, err := NewModifyRequest(optsMr)
|
||||
ExpectNoError(t, err.Error())
|
||||
|
||||
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
|
||||
|
||||
result, err := newMiddlewareTest(mid, nil)
|
||||
ExpectNoError(t, err.Error())
|
||||
t.Log(traces)
|
||||
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
|
||||
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
|
||||
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP)
|
||||
}
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
)
|
||||
|
||||
var RedirectHTTP = &Middleware{
|
||||
before: func(next http.Handler, w ResponseWriter, r *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
|
||||
if r.TLS == nil {
|
||||
r.URL.Scheme = "https"
|
||||
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
|
||||
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
next(w, r)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
deny:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.1.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
accept:
|
||||
- use: ModifyRequest
|
||||
setHeaders:
|
||||
X-Real-IP: 192.168.0.1:1234
|
||||
- use: RealIP
|
||||
header: X-Real-IP
|
||||
from:
|
||||
- 0.0.0.0/0
|
||||
- use: CIDRWhitelist
|
||||
allow:
|
||||
- 192.168.0.0/24
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
|
||||
)
|
||||
@@ -20,6 +21,9 @@ var testHeaders http.Header
|
||||
const testHost = "example.com"
|
||||
|
||||
func init() {
|
||||
if !common.IsTest {
|
||||
return
|
||||
}
|
||||
tmp := map[string]string{}
|
||||
err := json.Unmarshal(testHeadersRaw, &tmp)
|
||||
if err != nil {
|
||||
@@ -31,13 +35,15 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
type requestHeaderRecorder struct {
|
||||
type requestRecorder struct {
|
||||
parent http.RoundTripper
|
||||
reqHeaders http.Header
|
||||
headers http.Header
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.reqHeaders = req.Header
|
||||
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.headers = req.Header
|
||||
rt.remoteAddr = req.RemoteAddr
|
||||
if rt.parent != nil {
|
||||
return rt.parent.RoundTrip(req)
|
||||
}
|
||||
@@ -46,6 +52,7 @@ func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, e
|
||||
Header: testHeaders,
|
||||
Body: io.NopCloser(bytes.NewBufferString("OK")),
|
||||
Request: req,
|
||||
TLS: req.TLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -53,6 +60,7 @@ type TestResult struct {
|
||||
RequestHeaders http.Header
|
||||
ResponseHeaders http.Header
|
||||
ResponseStatus int
|
||||
RemoteAddr string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
@@ -65,7 +73,7 @@ type testArgs struct {
|
||||
|
||||
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
|
||||
var body io.Reader
|
||||
var rt = new(requestHeaderRecorder)
|
||||
var rr = new(requestRecorder)
|
||||
var proxyURL *url.URL
|
||||
var requestTarget string
|
||||
var err error
|
||||
@@ -98,17 +106,16 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||
if err != nil {
|
||||
return nil, E.From(err)
|
||||
}
|
||||
rt.parent = http.DefaultTransport
|
||||
rr.parent = http.DefaultTransport
|
||||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
|
||||
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
|
||||
middleware.name: args.middlewareOpt,
|
||||
})
|
||||
rp := gpHTTP.NewReverseProxy(proxyURL, rr)
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
}
|
||||
patchReverseProxy(middleware.name, rp, []*Middleware{mid})
|
||||
rp.ServeHTTP(w, req)
|
||||
resp := w.Result()
|
||||
defer resp.Body.Close()
|
||||
@@ -117,9 +124,10 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||
return nil, E.From(err)
|
||||
}
|
||||
return &TestResult{
|
||||
RequestHeaders: rt.reqHeaders,
|
||||
RequestHeaders: rr.headers,
|
||||
ResponseHeaders: resp.Header,
|
||||
ResponseStatus: resp.StatusCode,
|
||||
RemoteAddr: rr.remoteAddr,
|
||||
Data: data,
|
||||
}, nil
|
||||
}
|
||||
|
||||
99
internal/net/http/middleware/trace.go
Normal file
99
internal/net/http/middleware/trace.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type Trace struct {
|
||||
Time string `json:"time,omitempty"`
|
||||
Caller string `json:"caller,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Message string `json:"msg"`
|
||||
ReqHeaders http.Header `json:"req_headers,omitempty"`
|
||||
RespHeaders http.Header `json:"resp_headers,omitempty"`
|
||||
Additional map[string]any `json:"additional,omitempty"`
|
||||
}
|
||||
|
||||
type Traces []*Trace
|
||||
|
||||
var traces = Traces{}
|
||||
var tracesMu sync.Mutex
|
||||
|
||||
const MaxTraceNum = 1000
|
||||
|
||||
func GetAllTrace() []*Trace {
|
||||
return traces
|
||||
}
|
||||
|
||||
func (tr *Trace) WithRequest(req *Request) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = req.RequestURI
|
||||
tr.ReqHeaders = req.Header.Clone()
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) WithResponse(resp *Response) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
tr.URL = resp.Request.RequestURI
|
||||
tr.ReqHeaders = resp.Request.Header.Clone()
|
||||
tr.RespHeaders = resp.Header.Clone()
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *Trace) With(what string, additional any) *Trace {
|
||||
if tr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tr.Additional == nil {
|
||||
tr.Additional = map[string]any{}
|
||||
}
|
||||
tr.Additional[what] = additional
|
||||
return tr
|
||||
}
|
||||
|
||||
func (m *Middleware) EnableTrace() {
|
||||
m.trace = true
|
||||
for _, child := range m.children {
|
||||
child.parent = m
|
||||
child.EnableTrace()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
|
||||
if !m.trace {
|
||||
return nil
|
||||
}
|
||||
return addTrace(&Trace{
|
||||
Time: U.FormatTime(time.Now()),
|
||||
Caller: m.Fullname(),
|
||||
Message: fmt.Sprintf(msg, args...),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
|
||||
return m.AddTracef("%s", msg).WithRequest(req)
|
||||
}
|
||||
|
||||
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
|
||||
return m.AddTracef("%s", msg).WithResponse(resp)
|
||||
}
|
||||
|
||||
func addTrace(t *Trace) *Trace {
|
||||
tracesMu.Lock()
|
||||
defer tracesMu.Unlock()
|
||||
if len(traces) > MaxTraceNum {
|
||||
traces = traces[1:]
|
||||
}
|
||||
traces = append(traces, t)
|
||||
return t
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -14,7 +15,7 @@ const (
|
||||
)
|
||||
|
||||
var SetXForwarded = &Middleware{
|
||||
rewrite: func(req *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
req.Header.Del("Forwarded")
|
||||
req.Header.Del(xForwardedFor)
|
||||
req.Header.Del(xForwardedHost)
|
||||
@@ -23,7 +24,7 @@ var SetXForwarded = &Middleware{
|
||||
if err == nil {
|
||||
req.Header.Set(xForwardedFor, clientIP)
|
||||
} else {
|
||||
req.Header.Del(xForwardedFor)
|
||||
req.Header.Set(xForwardedFor, req.RemoteAddr)
|
||||
}
|
||||
req.Header.Set(xForwardedHost, req.Host)
|
||||
if req.TLS == nil {
|
||||
@@ -31,14 +32,16 @@ var SetXForwarded = &Middleware{
|
||||
} else {
|
||||
req.Header.Set(xForwardedProto, "https")
|
||||
}
|
||||
next(w, req)
|
||||
},
|
||||
}
|
||||
|
||||
var HideXForwarded = &Middleware{
|
||||
rewrite: func(req *Request) {
|
||||
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
|
||||
req.Header.Del("Forwarded")
|
||||
req.Header.Del(xForwardedFor)
|
||||
req.Header.Del(xForwardedHost)
|
||||
req.Header.Del(xForwardedProto)
|
||||
next(w, req)
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user