fixed middleware implementation, added middleware tracing for easier debug

This commit is contained in:
yusing
2024-10-02 13:55:41 +08:00
parent d172552fb0
commit ba13b81b0e
31 changed files with 561 additions and 196 deletions

View File

@@ -7,6 +7,7 @@ import (
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
F "github.com/yusing/go-proxy/internal/utils/functional"
)
type cidrWhitelist struct {
@@ -19,7 +20,7 @@ type cidrWhitelistOpts struct {
StatusCode int
Message string
trustedAddr map[string]struct{} // cache for trusted IPs
cachedAddr F.Map[string, bool] // cache for trusted IPs
}
var CIDRWhiteList = &cidrWhitelist{
@@ -28,15 +29,16 @@ var CIDRWhiteList = &cidrWhitelist{
"allow": D.YamlStringListParser,
"statusCode": D.IntParser,
},
withOptions: NewCIDRWhitelist,
},
}
var cidrWhitelistDefaults = func() *cidrWhitelistOpts {
return &cidrWhitelistOpts{
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
trustedAddr: make(map[string]struct{}),
Allow: []*types.CIDR{},
StatusCode: http.StatusForbidden,
Message: "IP not allowed",
cachedAddr: F.NewMapOf[string, bool](),
}
}
@@ -57,23 +59,32 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) {
return wl.m, nil
}
func (wl *cidrWhitelist) checkIP(next http.Handler, w ResponseWriter, r *Request) {
var ok bool
if _, ok = wl.trustedAddr[r.RemoteAddr]; !ok {
ip := net.IP(r.RemoteAddr)
func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) {
var allow, ok bool
if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ipStr = r.RemoteAddr
}
ip := net.ParseIP(ipStr)
for _, cidr := range wl.cidrWhitelistOpts.Allow {
if cidr.Contains(ip) {
wl.trustedAddr[r.RemoteAddr] = struct{}{}
ok = true
wl.cachedAddr.Store(r.RemoteAddr, true)
allow = true
wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr)
break
}
}
if !allow {
wl.cachedAddr.Store(r.RemoteAddr, false)
wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow)
}
}
if !ok {
if !allow {
w.WriteHeader(wl.StatusCode)
w.Write([]byte(wl.Message))
return
}
next.ServeHTTP(w, r)
next(w, r)
}

View File

@@ -0,0 +1,42 @@
package middleware
import (
_ "embed"
"net/http"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
//go:embed test_data/cidr_whitelist_test.yml
var testCIDRWhitelistCompose []byte
var deny, accept *Middleware
func TestCIDRWhitelist(t *testing.T) {
mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose)
if err != nil {
panic(err)
}
deny = mids["deny@file"]
accept = mids["accept@file"]
if deny == nil || accept == nil {
panic("bug occurred")
}
t.Run("deny", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(deny, nil)
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode)
ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message)
}
})
t.Run("accept", func(t *testing.T) {
for range 10 {
result, err := newMiddlewareTest(accept, nil)
ExpectNoError(t, err.Error())
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
}
})
}

View File

@@ -39,12 +39,13 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
cri := new(realIP)
cri.m = &Middleware{
impl: cri,
rewrite: func(r *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
cidrs := tryFetchCFCIDR()
if cidrs != nil {
cri.From = cidrs
}
cri.setRealIP(r)
next(w, r)
},
}
cri.realIPOpts = &realIPOpts{

View File

@@ -15,9 +15,9 @@ import (
)
var CustomErrorPage = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if !ServeStaticErrorPageFile(w, r) {
next.ServeHTTP(w, r)
next(w, r)
}
},
modifyResponse: func(resp *Response) error {

View File

@@ -13,7 +13,6 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
@@ -45,7 +44,6 @@ var ForwardAuth = func() *forwardAuth {
fa.m.withOptions = NewForwardAuthfunc
return fa
}()
var faLogger = logrus.WithField("middleware", "ForwardAuth")
func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
faWithOpts := new(forwardAuth)
@@ -80,7 +78,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
return faWithOpts.m, nil
}
func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) {
func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) {
gpHTTP.RemoveHop(req.Header)
faReq, err := http.NewRequestWithContext(
@@ -90,7 +88,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
nil,
)
if err != nil {
faLogger.Debugf("new request err to %s: %s", fa.Address, err)
fa.m.AddTracef("new request err to %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -103,7 +101,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
faResp, err := fa.client.Do(faReq)
if err != nil {
faLogger.Debugf("failed to call %s: %s", fa.Address, err)
fa.m.AddTracef("failed to call %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -111,7 +109,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
body, err := io.ReadAll(faResp.Body)
if err != nil {
faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err)
fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
@@ -122,7 +120,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
redirectURL, err := faResp.Location()
if err != nil {
faLogger.Debugf("failed to get location from %s: %s", fa.Address, err)
fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err)
w.WriteHeader(http.StatusInternalServerError)
return
} else if redirectURL.String() != "" {
@@ -132,7 +130,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request
w.WriteHeader(faResp.StatusCode)
if _, err = w.Write(body); err != nil {
faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err)
fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err)
}
return
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"encoding/json"
"errors"
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
@@ -21,7 +22,7 @@ type (
Header = http.Header
Cookie = http.Cookie
BeforeFunc func(next http.Handler, w ResponseWriter, r *Request)
BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request)
RewriteFunc func(req *Request)
ModifyResponseFunc func(resp *Response) error
CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError)
@@ -33,23 +34,38 @@ type (
name string
before BeforeFunc // runs before ReverseProxy.ServeHTTP
rewrite RewriteFunc // runs after ReverseProxy.Rewrite
modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse
transport http.RoundTripper
withOptions CloneWithOptFunc
labelParserMap D.ValueParserMap
impl any
parent *Middleware
children []*Middleware
trace bool
}
)
var Deserialize = U.Deserialize
func Rewrite(r RewriteFunc) BeforeFunc {
return func(next http.HandlerFunc, w ResponseWriter, req *Request) {
r(req)
next(w, req)
}
}
func (m *Middleware) Name() string {
return m.name
}
func (m *Middleware) Fullname() string {
if m.parent != nil {
return m.parent.Fullname() + "." + m.name
}
return m.name
}
func (m *Middleware) String() string {
return m.name
}
@@ -72,14 +88,21 @@ func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Nested
// WithOptionsClone is called only once
// set withOptions and labelParser will not be used after that
return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil
return &Middleware{
m.name,
m.before,
m.modifyResponse,
nil, nil,
m.impl,
m.parent,
m.children,
false,
}, nil
}
// TODO: check conflict or duplicates
func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) {
befores := make([]BeforeFunc, 0, len(middlewares))
rewrites := make([]RewriteFunc, 0, len(middlewares))
modResps := make([]ModifyResponseFunc, 0, len(middlewares))
func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) {
middlewares := make([]*Middleware, 0, len(middlewaresMap))
invalidM := E.NewBuilder("invalid middlewares")
invalidOpts := E.NewBuilder("invalid options")
@@ -88,7 +111,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
invalidM.To(&res)
}()
for name, opts := range middlewares {
for name, opts := range middlewaresMap {
m, ok := Get(name)
if !ok {
invalidM.Add(E.NotExist("middleware", name))
@@ -100,56 +123,35 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
invalidOpts.Add(err.Subject(name))
continue
}
if m.before != nil {
befores = append(befores, m.before)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
}
middlewares = append(middlewares, m)
}
if invalidM.HasError() {
return
}
origServeHTTP := rp.ServeHTTP
for i, before := range befores {
if i < len(befores)-1 {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(rp.ServeHTTP, w, r)
}
} else {
rp.ServeHTTP = func(w ResponseWriter, r *Request) {
before(origServeHTTP, w, r)
}
}
}
if len(rewrites) > 0 {
origServeHTTP = rp.ServeHTTP
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
for _, rewrite := range rewrites {
rewrite(r)
}
origServeHTTP(w, r)
}
}
if len(modResps) > 0 {
if rp.ModifyResponse != nil {
modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...)
}
rp.ModifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware ModifyResponse")
for _, mr := range modResps {
b.AddE(mr(res))
}
return b.Build().Error()
}
}
patchReverseProxy(rpName, rp, middlewares)
return
}
func patchReverseProxy(rpName string, rp *ReverseProxy, middlewares []*Middleware) {
mid := BuildMiddlewareFromChain(rpName, middlewares)
if mid.before != nil {
ori := rp.ServeHTTP
rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) {
mid.before(ori, w, r)
}
}
if mid.modifyResponse != nil {
if rp.ModifyResponse != nil {
ori := rp.ModifyResponse
rp.ModifyResponse = func(res *http.Response) error {
return errors.Join(mid.modifyResponse(res), ori(res))
}
} else {
rp.ModifyResponse = mid.modifyResponse
}
}
}

View File

@@ -1,9 +1,11 @@
package middleware
import (
"fmt"
"net/http"
"os"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"gopkg.in/yaml.v3"
)
@@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("toml unmarshal", err))
b.Add(E.FailWith("yaml unmarshal", err))
return
}
middlewares = make(map[string]*Middleware)
@@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" {
chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i))
if def["use"] == nil || def["use"] == "" {
chainErr.Add(E.Missing("use").Subjectf(".%d", i))
continue
}
baseName := def["use"].(string)
base, ok := Get(baseName)
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i))
continue
base, ok = middlewares[baseName]
if !ok {
chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i))
continue
}
}
delete(def, "use")
m, err := base.WithOptionsClone(def)
m.name = fmt.Sprintf("%s[%d]", name, i)
if err != nil {
chainErr.Add(err.Subjectf("item%d", i))
continue
@@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
if chainErr.HasError() {
b.Add(chainErr.Build())
} else {
name = name + "@file"
middlewares[name] = BuildMiddlewareFromChain(name, chain)
middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain)
}
}
return
@@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware,
// TODO: check conflict or duplicates
func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware {
var (
befores []BeforeFunc
rewrites []RewriteFunc
modResps []ModifyResponseFunc
)
for _, m := range chain {
if m.before != nil {
befores = append(befores, m.before)
m := &Middleware{name: name, children: chain}
var befores []*Middleware
var modResps []*Middleware
for _, comp := range chain {
if comp.before != nil {
befores = append(befores, comp)
}
if m.rewrite != nil {
rewrites = append(rewrites, m.rewrite)
}
if m.modifyResponse != nil {
modResps = append(modResps, m.modifyResponse)
if comp.modifyResponse != nil {
modResps = append(modResps, comp)
}
comp.parent = m
}
m := &Middleware{name: name}
if len(befores) > 0 {
m.before = func(next http.Handler, w ResponseWriter, r *Request) {
for _, before := range befores {
before(next, w, r)
}
}
}
if len(rewrites) > 0 {
m.rewrite = func(r *Request) {
for _, rewrite := range rewrites {
rewrite(r)
}
}
m.before = buildBefores(befores)
}
if len(modResps) > 0 {
m.modifyResponse = func(res *Response) error {
b := E.NewBuilder("errors in middleware %s", name)
b := E.NewBuilder("errors in middleware")
for _, mr := range modResps {
b.AddE(mr(res))
b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name))
}
return b.Build().Error()
}
}
if common.IsDebug {
m.EnableTrace()
m.AddTracef("middleware created")
}
return m
}
func buildBefores(befores []*Middleware) BeforeFunc {
if len(befores) == 1 {
return befores[0].before
}
nextBefores := buildBefores(befores[1:])
return func(next http.HandlerFunc, w ResponseWriter, r *Request) {
befores[0].before(func(w ResponseWriter, r *Request) {
nextBefores(next, w, r)
}, w, r)
}
}

View File

@@ -67,10 +67,10 @@ func LoadComposeFiles() {
b.Add(E.Duplicated("middleware", name))
continue
}
middlewares[name] = m
middlewares[U.ToLowerNoSnake(name)] = m
logger.Infof("middleware %s loaded from %s", name, path.Base(defFile))
}
b.Add(err.Subject(defFile))
b.Add(err.Subject(path.Base(defFile)))
}
if b.HasError() {
logger.Error(b.Build())

View File

@@ -1,6 +1,7 @@
package middleware
import (
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
@@ -32,9 +33,15 @@ var ModifyRequest = func() *modifyRequest {
func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
mr := new(modifyRequest)
var mrFunc RewriteFunc
if common.IsDebug {
mrFunc = mr.modifyRequestWithTrace
} else {
mrFunc = mr.modifyRequest
}
mr.m = &Middleware{
impl: mr,
rewrite: mr.modifyRequest,
impl: mr,
before: Rewrite(mrFunc),
}
mr.modifyRequestOpts = new(modifyRequestOpts)
err := Deserialize(optsRaw, mr.modifyRequestOpts)
@@ -55,3 +62,9 @@ func (mr *modifyRequest) modifyRequest(req *Request) {
req.Header.Del(k)
}
}
func (mr *modifyRequest) modifyRequestWithTrace(req *Request) {
mr.m.AddTraceRequest("before modify request", req)
mr.modifyRequest(req)
mr.m.AddTraceRequest("after modify request", req)
}

View File

@@ -3,6 +3,7 @@ package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/common"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
)
@@ -34,9 +35,11 @@ var ModifyResponse = func() (mr *modifyResponse) {
func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
mr := new(modifyResponse)
mr.m = &Middleware{
impl: mr,
modifyResponse: mr.modifyResponse,
mr.m = &Middleware{impl: mr}
if common.IsDebug {
mr.m.modifyResponse = mr.modifyResponseWithTrace
} else {
mr.m.modifyResponse = mr.modifyResponse
}
mr.modifyResponseOpts = new(modifyResponseOpts)
err := Deserialize(optsRaw, mr.modifyResponseOpts)
@@ -58,3 +61,10 @@ func (mr *modifyResponse) modifyResponse(resp *http.Response) error {
}
return nil
}
func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error {
mr.m.AddTraceResponse("before modify response", resp)
err := mr.modifyResponse(resp)
mr.m.AddTraceResponse("after modify response", resp)
return err
}

View File

@@ -2,8 +2,8 @@ package middleware
import (
"net"
"net/http"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
@@ -49,13 +49,14 @@ var realIPOptsDefault = func() *realIPOpts {
}
}
var realIPLogger = logrus.WithField("middleware", "RealIP")
func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
riWithOpts := new(realIP)
riWithOpts.m = &Middleware{
impl: riWithOpts,
rewrite: riWithOpts.setRealIP,
impl: riWithOpts,
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
riWithOpts.setRealIP(r)
next(w, r)
},
}
riWithOpts.realIPOpts = realIPOptsDefault()
err := Deserialize(opts, riWithOpts.realIPOpts)
@@ -78,7 +79,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool {
func (ri *realIP) setRealIP(req *Request) {
clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
realIPLogger.Debugf("failed to split host port %s", err)
clientIPStr = req.RemoteAddr
}
clientIP := net.ParseIP(clientIPStr)
@@ -90,7 +91,7 @@ func (ri *realIP) setRealIP(req *Request) {
}
}
if !isTrusted {
realIPLogger.Debugf("client ip %s is not trusted", clientIP)
ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From)
return
}
@@ -98,7 +99,7 @@ func (ri *realIP) setRealIP(req *Request) {
var lastNonTrustedIP string
if len(realIPs) == 0 {
realIPLogger.Debugf("no real ip found in header %q", ri.Header)
ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req)
return
}
@@ -110,14 +111,16 @@ func (ri *realIP) setRealIP(req *Request) {
lastNonTrustedIP = r
}
}
if lastNonTrustedIP == "" {
realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header)
return
}
}
if lastNonTrustedIP == "" {
ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs)
return
}
req.RemoteAddr = lastNonTrustedIP
req.Header.Set(ri.Header, lastNonTrustedIP)
req.Header.Set("X-Real-IP", lastNonTrustedIP)
req.Header.Set("X-Forwarded-For", lastNonTrustedIP)
req.Header.Set(xForwardedFor, lastNonTrustedIP)
ri.m.AddTracef("set real ip %s", lastNonTrustedIP)
}

View File

@@ -2,13 +2,15 @@ package middleware
import (
"net"
"net/http"
"strings"
"testing"
"github.com/yusing/go-proxy/internal/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIP(t *testing.T) {
func TestSetRealIPOpts(t *testing.T) {
opts := OptionsRaw{
"header": "X-Real-IP",
"from": []string{
@@ -37,13 +39,39 @@ func TestSetRealIP(t *testing.T) {
Recursive: true,
}
t.Run("set_options", func(t *testing.T) {
ri, err := RealIP.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
})
// TODO test
ri, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
for i, CIDR := range ri.impl.(*realIP).From {
ExpectEqual(t, CIDR.String(), optExpected.From[i].String())
}
}
func TestSetRealIP(t *testing.T) {
const (
testHeader = "X-Real-IP"
testRealIP = "192.168.1.1"
)
opts := OptionsRaw{
"header": testHeader,
"from": []string{"0.0.0.0/0"},
}
optsMr := OptionsRaw{
"set_headers": map[string]string{testHeader: testRealIP},
}
realip, err := NewRealIP(opts)
ExpectNoError(t, err.Error())
mr, err := NewModifyRequest(optsMr)
ExpectNoError(t, err.Error())
mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip})
result, err := newMiddlewareTest(mid, nil)
ExpectNoError(t, err.Error())
t.Log(traces)
ExpectEqual(t, result.ResponseStatus, http.StatusOK)
ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP)
ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP)
}

View File

@@ -7,13 +7,13 @@ import (
)
var RedirectHTTP = &Middleware{
before: func(next http.Handler, w ResponseWriter, r *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, r *Request) {
if r.TLS == nil {
r.URL.Scheme = "https"
r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
return
}
next.ServeHTTP(w, r)
next(w, r)
},
}

View File

@@ -0,0 +1,22 @@
deny:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.1.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24
accept:
- use: ModifyRequest
setHeaders:
X-Real-IP: 192.168.0.1:1234
- use: RealIP
header: X-Real-IP
from:
- 0.0.0.0/0
- use: CIDRWhitelist
allow:
- 192.168.0.0/24

View File

@@ -9,6 +9,7 @@ import (
"net/http/httptest"
"net/url"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
gpHTTP "github.com/yusing/go-proxy/internal/net/http"
)
@@ -20,6 +21,9 @@ var testHeaders http.Header
const testHost = "example.com"
func init() {
if !common.IsTest {
return
}
tmp := map[string]string{}
err := json.Unmarshal(testHeadersRaw, &tmp)
if err != nil {
@@ -31,13 +35,15 @@ func init() {
}
}
type requestHeaderRecorder struct {
type requestRecorder struct {
parent http.RoundTripper
reqHeaders http.Header
headers http.Header
remoteAddr string
}
func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.reqHeaders = req.Header
func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
rt.headers = req.Header
rt.remoteAddr = req.RemoteAddr
if rt.parent != nil {
return rt.parent.RoundTrip(req)
}
@@ -46,6 +52,7 @@ func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, e
Header: testHeaders,
Body: io.NopCloser(bytes.NewBufferString("OK")),
Request: req,
TLS: req.TLS,
}, nil
}
@@ -53,6 +60,7 @@ type TestResult struct {
RequestHeaders http.Header
ResponseHeaders http.Header
ResponseStatus int
RemoteAddr string
Data []byte
}
@@ -65,7 +73,7 @@ type testArgs struct {
func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) {
var body io.Reader
var rt = new(requestHeaderRecorder)
var rr = new(requestRecorder)
var proxyURL *url.URL
var requestTarget string
var err error
@@ -98,17 +106,16 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
if err != nil {
return nil, E.From(err)
}
rt.parent = http.DefaultTransport
rr.parent = http.DefaultTransport
} else {
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
}
rp := gpHTTP.NewReverseProxy(proxyURL, rt)
setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{
middleware.name: args.middlewareOpt,
})
rp := gpHTTP.NewReverseProxy(proxyURL, rr)
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
patchReverseProxy(middleware.name, rp, []*Middleware{mid})
rp.ServeHTTP(w, req)
resp := w.Result()
defer resp.Body.Close()
@@ -117,9 +124,10 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
return nil, E.From(err)
}
return &TestResult{
RequestHeaders: rt.reqHeaders,
RequestHeaders: rr.headers,
ResponseHeaders: resp.Header,
ResponseStatus: resp.StatusCode,
RemoteAddr: rr.remoteAddr,
Data: data,
}, nil
}

View File

@@ -0,0 +1,99 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
U "github.com/yusing/go-proxy/internal/utils"
)
type Trace struct {
Time string `json:"time,omitempty"`
Caller string `json:"caller,omitempty"`
URL string `json:"url,omitempty"`
Message string `json:"msg"`
ReqHeaders http.Header `json:"req_headers,omitempty"`
RespHeaders http.Header `json:"resp_headers,omitempty"`
Additional map[string]any `json:"additional,omitempty"`
}
type Traces []*Trace
var traces = Traces{}
var tracesMu sync.Mutex
const MaxTraceNum = 1000
func GetAllTrace() []*Trace {
return traces
}
func (tr *Trace) WithRequest(req *Request) *Trace {
if tr == nil {
return nil
}
tr.URL = req.RequestURI
tr.ReqHeaders = req.Header.Clone()
return tr
}
func (tr *Trace) WithResponse(resp *Response) *Trace {
if tr == nil {
return nil
}
tr.URL = resp.Request.RequestURI
tr.ReqHeaders = resp.Request.Header.Clone()
tr.RespHeaders = resp.Header.Clone()
return tr
}
func (tr *Trace) With(what string, additional any) *Trace {
if tr == nil {
return nil
}
if tr.Additional == nil {
tr.Additional = map[string]any{}
}
tr.Additional[what] = additional
return tr
}
func (m *Middleware) EnableTrace() {
m.trace = true
for _, child := range m.children {
child.parent = m
child.EnableTrace()
}
}
func (m *Middleware) AddTracef(msg string, args ...any) *Trace {
if !m.trace {
return nil
}
return addTrace(&Trace{
Time: U.FormatTime(time.Now()),
Caller: m.Fullname(),
Message: fmt.Sprintf(msg, args...),
})
}
func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace {
return m.AddTracef("%s", msg).WithRequest(req)
}
func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace {
return m.AddTracef("%s", msg).WithResponse(resp)
}
func addTrace(t *Trace) *Trace {
tracesMu.Lock()
defer tracesMu.Unlock()
if len(traces) > MaxTraceNum {
traces = traces[1:]
}
traces = append(traces, t)
return t
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"net"
"net/http"
)
const (
@@ -14,7 +15,7 @@ const (
)
var SetXForwarded = &Middleware{
rewrite: func(req *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
req.Header.Del("Forwarded")
req.Header.Del(xForwardedFor)
req.Header.Del(xForwardedHost)
@@ -23,7 +24,7 @@ var SetXForwarded = &Middleware{
if err == nil {
req.Header.Set(xForwardedFor, clientIP)
} else {
req.Header.Del(xForwardedFor)
req.Header.Set(xForwardedFor, req.RemoteAddr)
}
req.Header.Set(xForwardedHost, req.Host)
if req.TLS == nil {
@@ -31,14 +32,16 @@ var SetXForwarded = &Middleware{
} else {
req.Header.Set(xForwardedProto, "https")
}
next(w, req)
},
}
var HideXForwarded = &Middleware{
rewrite: func(req *Request) {
before: func(next http.HandlerFunc, w ResponseWriter, req *Request) {
req.Header.Del("Forwarded")
req.Header.Del(xForwardedFor)
req.Header.Del(xForwardedHost)
req.Header.Del(xForwardedProto)
next(w, req)
},
}