implement middleware compose

This commit is contained in:
yusing
2024-10-01 16:38:07 +08:00
parent f5a36f94bb
commit 44cfd65f6c
17 changed files with 392 additions and 152 deletions

View File

@@ -13,6 +13,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
)
const (
@@ -53,7 +54,7 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) {
return cri.m, nil
}
func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
func tryFetchCFCIDR() (cfCIDRs []*types.CIDR) {
if time.Since(cfCIDRsLastUpdate) < cfCIDRsUpdateInterval {
return
}
@@ -66,14 +67,14 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
}
if common.IsTest {
cfCIDRs = []*net.IPNet{
cfCIDRs = []*types.CIDR{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 0, 0, 0)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
} else {
cfCIDRs = make([]*net.IPNet, 0, 30)
cfCIDRs = make([]*types.CIDR, 0, 30)
err := errors.Join(
fetchUpdateCFIPRange(cfIPv4CIDRsEndpoint, cfCIDRs),
fetchUpdateCFIPRange(cfIPv6CIDRsEndpoint, cfCIDRs),
@@ -90,7 +91,7 @@ func tryFetchCFCIDR() (cfCIDRs []*net.IPNet) {
return
}
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*types.CIDR) error {
resp, err := http.Get(endpoint)
if err != nil {
return err
@@ -110,7 +111,7 @@ func fetchUpdateCFIPRange(endpoint string, cfCIDRs []*net.IPNet) error {
if err != nil {
return fmt.Errorf("cloudflare responeded an invalid CIDR: %s", line)
} else {
cfCIDRs = append(cfCIDRs, cidr)
cfCIDRs = append(cfCIDRs, (*types.CIDR)(cidr))
}
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"encoding/json"
"net/http"
D "github.com/yusing/go-proxy/internal/docker"
@@ -53,7 +54,14 @@ func (m *Middleware) String() string {
return m.name
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw, rp *ReverseProxy) (*Middleware, E.NestedError) {
func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{
"name": m.name,
"options": m.impl,
}, "", " ")
}
func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.NestedError) {
if len(optsRaw) != 0 && m.withOptions != nil {
if mWithOpt, err := m.withOptions(optsRaw); err != nil {
return nil, err
@@ -87,7 +95,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res
continue
}
m, err := m.WithOptionsClone(opts, rp)
m, err := m.WithOptionsClone(opts)
if err != nil {
invalidOpts.Add(err.Subject(name))
continue

View File

@@ -8,24 +8,27 @@ import (
"gopkg.in/yaml.v3"
)
func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middleware, outErr E.NestedError) {
func BuildMiddlewaresFromComposeFile(filePath string) (map[string]*Middleware, E.NestedError) {
fileContent, err := os.ReadFile(filePath)
if err != nil {
return nil, E.FailWith("read middleware compose file", err)
}
return BuildMiddlewaresFromYAML(fileContent)
}
func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, outErr E.NestedError) {
b := E.NewBuilder("middlewares compile errors")
defer b.To(&outErr)
var data map[string][]map[string]any
fileContent, err := os.ReadFile(filePath)
if err != nil {
b.Add(E.FailWith("read file", err))
return
}
err = yaml.Unmarshal(fileContent, &data)
var rawMap map[string][]map[string]any
err := yaml.Unmarshal(data, &rawMap)
if err != nil {
b.Add(E.FailWith("toml unmarshal", err))
return
}
middlewares = make(map[string]*Middleware)
for name, defs := range data {
chainErr := E.NewBuilder("errors in middleware chain %s", name)
for name, defs := range rawMap {
chainErr := E.NewBuilder(name)
chain := make([]*Middleware, 0, len(defs))
for i, def := range defs {
if def["use"] == nil || def["use"].(string) == "" {
@@ -39,9 +42,9 @@ func BuildMiddlewaresFromYAML(filePath string) (middlewares map[string]*Middlewa
continue
}
delete(def, "use")
m, err := base.withOptions(def)
m, err := base.WithOptionsClone(def)
if err != nil {
chainErr.Add(err.Subjectf("%s.%d", name, i))
chainErr.Add(err.Subjectf("item%d", i))
continue
}
chain = append(chain, m)

View File

@@ -1,9 +1,21 @@
package middleware
import (
_ "embed"
"encoding/json"
"testing"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestBuild(t *testing.T) {
//go:embed test_data/middleware_compose.yml
var testMiddlewareCompose []byte
func TestBuild(t *testing.T) {
// middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose)
// ExpectNoError(t, err.Error())
data, err := E.Check(json.MarshalIndent(middlewares, "", " "))
ExpectNoError(t, err.Error())
t.Log(string(data))
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"fmt"
"net/http"
"path"
"strings"
@@ -15,27 +16,27 @@ import (
var middlewares map[string]*Middleware
func Get(name string) (middleware *Middleware, ok bool) {
middleware, ok = middlewares[name]
middleware, ok = middlewares[strings.ToLower(name)]
return
}
// initialize middleware names and label parsers
func init() {
middlewares = map[string]*Middleware{
"set_x_forwarded": SetXForwarded,
"hide_x_forwarded": HideXForwarded,
"redirect_http": RedirectHTTP,
"forward_auth": ForwardAuth.m,
"modify_response": ModifyResponse.m,
"modify_request": ModifyRequest.m,
"error_page": CustomErrorPage,
"custom_error_page": CustomErrorPage,
"real_ip": RealIP.m,
"cloudflare_real_ip": CloudflareRealIP.m,
"setxforwarded": SetXForwarded,
"hidexforwarded": HideXForwarded,
"redirecthttp": RedirectHTTP,
"forwardauth": ForwardAuth.m,
"modifyresponse": ModifyResponse.m,
"modifyrequest": ModifyRequest.m,
"errorpage": CustomErrorPage,
"customerrorpage": CustomErrorPage,
"realip": RealIP.m,
"cloudflarerealip": CloudflareRealIP.m,
}
names := make(map[*Middleware][]string)
for name, m := range middlewares {
names[m] = append(names[m], name)
names[m] = append(names[m], http.CanonicalHeaderKey(name))
// register middleware name to docker label parsr
// in order to parse middleware_name.option=value into correct type
if m.labelParserMap != nil {
@@ -49,6 +50,7 @@ func init() {
m.name = names[0]
}
}
// TODO: seperate from init()
b := E.NewBuilder("failed to load middlewares")
middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0)
@@ -57,7 +59,7 @@ func init() {
return
}
for _, defFile := range middlewareDefs {
mws, err := BuildMiddlewaresFromYAML(defFile)
mws, err := BuildMiddlewaresFromComposeFile(defFile)
for name, m := range mws {
if _, ok := middlewares[name]; ok {
b.Add(E.Duplicated("middleware", name))

View File

@@ -15,7 +15,7 @@ func TestSetModifyRequest(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyRequest.m.WithOptionsClone(opts, nil)
mr, err := ModifyRequest.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyRequest).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyRequest).AddHeaders, opts["add_headers"].(map[string]string))

View File

@@ -15,7 +15,7 @@ func TestSetModifyResponse(t *testing.T) {
}
t.Run("set_options", func(t *testing.T) {
mr, err := ModifyResponse.m.WithOptionsClone(opts, nil)
mr, err := ModifyResponse.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
ExpectDeepEqual(t, mr.impl.(*modifyResponse).SetHeaders, opts["set_headers"].(map[string]string))
ExpectDeepEqual(t, mr.impl.(*modifyResponse).AddHeaders, opts["add_headers"].(map[string]string))

View File

@@ -2,11 +2,11 @@ package middleware
import (
"net"
"strings"
"github.com/sirupsen/logrus"
D "github.com/yusing/go-proxy/internal/docker"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/types"
)
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
@@ -20,7 +20,7 @@ type realIPOpts struct {
// Header is the name of the header to use for the real client IP
Header string
// From is a list of Address / CIDRs to trust
From []*net.IPNet
From []*types.CIDR
/*
If recursive search is disabled,
the original client address that matches one of the trusted addresses is replaced by
@@ -35,7 +35,7 @@ type realIPOpts struct {
var RealIP = &realIP{
m: &Middleware{
labelParserMap: D.ValueParserMap{
"from": CIDRListParser,
"from": D.YamlStringListParser,
"recursive": D.BoolParser,
},
withOptions: NewRealIP,
@@ -45,14 +45,7 @@ var RealIP = &realIP{
var realIPOptsDefault = func() *realIPOpts {
return &realIPOpts{
Header: "X-Real-IP",
From: []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)},
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)},
{IP: net.ParseIP("fc00::"), Mask: net.CIDRMask(7, 128)},
{IP: net.ParseIP("fe80::"), Mask: net.CIDRMask(10, 128)},
},
From: []*types.CIDR{},
}
}
@@ -72,31 +65,6 @@ func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) {
return riWithOpts.m, nil
}
func CIDRListParser(s string) (any, E.NestedError) {
sl, err := D.YamlStringListParser(s)
if err != nil {
return nil, err
}
b := E.NewBuilder("invalid CIDR(s)")
CIDRs := sl.([]string)
res := make([]*net.IPNet, 0, len(CIDRs))
for _, cidr := range CIDRs {
if !strings.Contains(cidr, "/") {
cidr += "/32" // single IP
}
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
b.Add(E.Invalid("CIDR", cidr))
continue
}
res = append(res, ipnet)
}
return res, b.Build()
}
func (ri *realIP) isInCIDRList(ip net.IP) bool {
for _, CIDR := range ri.From {
if CIDR.Contains(ip) {

View File

@@ -0,0 +1,58 @@
package middleware
import (
"net"
"testing"
"github.com/yusing/go-proxy/internal/types"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestSetRealIP(t *testing.T) {
opts := OptionsRaw{
"header": "X-Real-IP",
"from": []string{
"127.0.0.0/8",
"192.168.0.0/16",
"172.16.0.0/12",
},
"recursive": true,
}
optExpected := &realIPOpts{
Header: "X-Real-IP",
From: []*types.CIDR{
{
IP: net.ParseIP("127.0.0.0"),
Mask: net.IPv4Mask(255, 0, 0, 0),
},
{
IP: net.ParseIP("192.168.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
{
IP: net.ParseIP("172.16.0.0"),
Mask: net.IPv4Mask(255, 240, 0, 0),
},
},
Recursive: true,
}
t.Run("set_options", func(t *testing.T) {
ri, err := RealIP.m.WithOptionsClone(opts)
ExpectNoError(t, err.Error())
// ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header)
// ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From)
// ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive)
ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected)
})
// t.Run("request_headers", func(t *testing.T) {
// result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{
// middlewareOpt: opts,
// })
// ExpectNoError(t, err.Error())
// ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0")
// ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value"))
// ExpectEqual(t, result.RequestHeaders.Get("Accept"), "")
// })
}

View File

@@ -0,0 +1,41 @@
theGreatPretender:
- use: HideXForwarded
- use: ModifyRequest
setHeaders:
X-Real-IP: 6.6.6.6
- use: ModifyResponse
hideHeaders:
- X-Test3
- X-Test4
notAuthenticAuthentik:
- use: RedirectHTTP
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
addAuthCookiesToResponse:
- session_id
- user_id
authResponseHeaders:
- X-Auth-SessionID
- X-Auth-UserID
- use: CustomErrorPage
realIPAuthentik:
- use: RedirectHTTP
- use: RealIP
header: X-Real-IP
from:
- "127.0.0.0/8"
- "192.168.0.0/16"
- "172.16.0.0/12"
recursive: true
- use: ForwardAuth
address: https://authentik.company
trustForwardHeader: true
testFakeRealIP:
- use: ModifyRequest
setHeaders:
CF-Connecting-IP: 127.0.0.1
- use: CloudflareRealIP