From ed887a5cfc4f0a2646cbf65e4b8d41cb61cd2012 Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 2 Oct 2024 01:04:34 +0800 Subject: [PATCH] fixed serialization and middleware compose --- cmd/main.go | 13 ++-- docs/middlewares.md | 9 +++ internal/common/constants.go | 3 +- internal/docker/label_parser.go | 9 +++ .../middleware/middleware_builder_test.go | 9 +-- internal/net/http/middleware/middlewares.go | 6 +- internal/net/http/middleware/real_ip_test.go | 11 +--- internal/utils/serialization.go | 62 ++++++++++++------- 8 files changed, 74 insertions(+), 48 deletions(-) diff --git a/cmd/main.go b/cmd/main.go index f38a8c66..956e640d 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -24,6 +24,7 @@ import ( "github.com/yusing/go-proxy/internal/docker" "github.com/yusing/go-proxy/internal/docker/idlewatcher" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/middleware" R "github.com/yusing/go-proxy/internal/route" "github.com/yusing/go-proxy/internal/server" F "github.com/yusing/go-proxy/internal/utils/functional" @@ -80,8 +81,9 @@ func main() { prepareDirectory(dir) } - err := config.Load() - if err != nil { + middleware.LoadComposeFiles() + + if err := config.Load(); err != nil { logrus.Warn(err) } cfg := config.GetInstance() @@ -113,11 +115,6 @@ func main() { } cfg.StartProxyProviders() - - if err.HasError() { - l.Warn(err) - } - cfg.WatchChanges() onShutdown.Add(docker.CloseAllClients) @@ -132,7 +129,7 @@ func main() { if autocert != nil { ctx, cancel := context.WithCancel(context.Background()) - if err = autocert.Setup(ctx); err != nil { + if err := autocert.Setup(ctx); err != nil { l.Fatal(err) } else { onShutdown.Add(cancel) diff --git a/docs/middlewares.md b/docs/middlewares.md index 4c98a7eb..4239e761 100644 --- a/docs/middlewares.md +++ b/docs/middlewares.md @@ -20,6 +20,7 @@ - [Hide X-Forwarded-\*](#hide-x-forwarded-) - [Set X-Forwarded-\*](#set-x-forwarded-) - [Forward Authorization header (experimental)](#forward-authorization-header-experimental) + - [Middleware Compose](#middleware-compose) - [Examples](#examples) - [Authentik (untested, experimental)](#authentik-untested-experimental) @@ -356,6 +357,14 @@ http: [🔼Back to top](#table-of-content) +## Middleware Compose + +Middleware compose is a way to create reusable middlewares in file(s), just like docker compose. + +You may use them with `@file` + +See [example](../internal/net/http/middleware/test_data/middleware_compose.yml) + ## Examples ### Authentik (untested, experimental) diff --git a/internal/common/constants.go b/internal/common/constants.go index 961b795a..57a50935 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -18,7 +18,7 @@ const ( ConfigExampleFileName = "config.example.yml" ConfigPath = ConfigBasePath + "/" + ConfigFileName - MiddlewareDefsBasePath = ConfigBasePath + "/middlewares" + MiddlewareComposeBasePath = ConfigBasePath + "/middlewares" ) const ( @@ -41,6 +41,7 @@ var ( ConfigBasePath, SchemaBasePath, ErrorPagesBasePath, + MiddlewareComposeBasePath, } ) diff --git a/internal/docker/label_parser.go b/internal/docker/label_parser.go index e9769c94..b1662166 100644 --- a/internal/docker/label_parser.go +++ b/internal/docker/label_parser.go @@ -1,6 +1,7 @@ package docker import ( + "strconv" "strings" E "github.com/yusing/go-proxy/internal/error" @@ -76,3 +77,11 @@ func BoolParser(value string) (any, E.NestedError) { return nil, E.Invalid("boolean value", value) } } + +func IntParser(value string) (any, E.NestedError) { + i, err := strconv.Atoi(value) + if err != nil { + return 0, E.Invalid("integer value", value) + } + return i, nil +} diff --git a/internal/net/http/middleware/middleware_builder_test.go b/internal/net/http/middleware/middleware_builder_test.go index 997d41ce..d7fca0ca 100644 --- a/internal/net/http/middleware/middleware_builder_test.go +++ b/internal/net/http/middleware/middleware_builder_test.go @@ -13,9 +13,10 @@ import ( var testMiddlewareCompose []byte func TestBuild(t *testing.T) { - // middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) - // ExpectNoError(t, err.Error()) - data, err := E.Check(json.MarshalIndent(middlewares, "", " ")) + middlewares, err := BuildMiddlewaresFromYAML(testMiddlewareCompose) ExpectNoError(t, err.Error()) - t.Log(string(data)) + _, err = E.Check(json.MarshalIndent(middlewares, "", " ")) + ExpectNoError(t, err.Error()) + // t.Log(string(data)) + // TODO: test } diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index f86a8ca9..21255aad 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -33,6 +33,7 @@ func init() { "customerrorpage": CustomErrorPage, "realip": RealIP.m, "cloudflarerealip": CloudflareRealIP.m, + "cidrwhitelist": CIDRWhiteList.m, } names := make(map[*Middleware][]string) for name, m := range middlewares { @@ -50,10 +51,11 @@ func init() { m.name = names[0] } } +} - // TODO: seperate from init() +func LoadComposeFiles() { b := E.NewBuilder("failed to load middlewares") - middlewareDefs, err := U.ListFiles(common.MiddlewareDefsBasePath, 0) + middlewareDefs, err := U.ListFiles(common.MiddlewareComposeBasePath, 0) if err != nil { logrus.Errorf("failed to list middleware definitions: %s", err) return diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index c4b941c8..71324561 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -45,14 +45,5 @@ func TestSetRealIP(t *testing.T) { // ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected) }) - - // t.Run("request_headers", func(t *testing.T) { - // result, err := newMiddlewareTest(ModifyRequest.m, &testArgs{ - // middlewareOpt: opts, - // }) - // ExpectNoError(t, err.Error()) - // ExpectEqual(t, result.RequestHeaders.Get("User-Agent"), "go-proxy/v0.5.0") - // ExpectTrue(t, slices.Contains(result.RequestHeaders.Values("Accept-Encoding"), "test-value")) - // ExpectEqual(t, result.RequestHeaders.Get("Accept"), "") - // }) + // TODO test } diff --git a/internal/utils/serialization.go b/internal/utils/serialization.go index 71a48cdb..fb0f2406 100644 --- a/internal/utils/serialization.go +++ b/internal/utils/serialization.go @@ -13,7 +13,7 @@ import ( ) type SerializedObject = map[string]any -type Convertor interface { +type Converter interface { ConvertFrom(value any) (any, E.NestedError) } @@ -188,7 +188,7 @@ func Deserialize(src SerializedObject, dst any) E.NestedError { // - error: the error occurred during conversion, or nil if no error occurred. func Convert(src reflect.Value, dst reflect.Value) E.NestedError { srcT := src.Type() - dstVT := dst.Type() + dstT := dst.Type() if src.Kind() == reflect.Interface { src = src.Elem() @@ -199,31 +199,36 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError { return E.From(fmt.Errorf("%w type %T is unsettable", E.ErrUnsupported, dst.Interface())) } - switch { - case srcT.AssignableTo(dstVT): - dst.Set(src) - case srcT.ConvertibleTo(dstVT): - dst.Set(src.Convert(dstVT)) - case srcT.Kind() == reflect.Map: - if dstVT.Kind() != reflect.Map { - return E.TypeError("map", srcT, dstVT) + if dst.Kind() == reflect.Pointer { + if dst.IsNil() { + dst.Set(reflect.New(dstT.Elem())) } + dst = dst.Elem() + dstT = dst.Type() + } + + switch { + case srcT.AssignableTo(dstT): + dst.Set(src) + case srcT.ConvertibleTo(dstT): + dst.Set(src.Convert(dstT)) + case srcT.Kind() == reflect.Map: obj, ok := src.Interface().(SerializedObject) if !ok { - return E.TypeError("map", srcT, dstVT) + return E.TypeMismatch[SerializedObject](src.Interface()) } err := Deserialize(obj, dst.Addr().Interface()) if err != nil { return err } case srcT.Kind() == reflect.Slice: - if dstVT.Kind() != reflect.Slice { - return E.TypeError("slice", srcT, dstVT) + if dstT.Kind() != reflect.Slice { + return E.TypeError("slice", srcT, dstT) } - newSlice := reflect.MakeSlice(dstVT, 0, src.Len()) + newSlice := reflect.MakeSlice(dstT, 0, src.Len()) i := 0 for _, v := range src.Seq2() { - tmp := reflect.New(dstVT.Elem()).Elem() + tmp := reflect.New(dstT.Elem()).Elem() err := Convert(v, tmp) if err != nil { return err.Subjectf("[%d]", i) @@ -233,16 +238,27 @@ func Convert(src reflect.Value, dst reflect.Value) E.NestedError { } dst.Set(newSlice) default: - // check if Convertor is implemented - if converter, ok := dst.Interface().(Convertor); ok { - converted, err := converter.ConvertFrom(src.Interface()) - if err != nil { - return err + var converter Converter + var ok bool + // check if (*T).Convertor is implemented + if converter, ok = dst.Addr().Interface().(Converter); !ok { + // check if (T).Convertor is implemented + converter, ok = dst.Interface().(Converter) + if !ok { + return E.TypeError("conversion", srcT, dstT) } - dst.Set(reflect.ValueOf(converted)) - return nil } - return E.TypeError("conversion", srcT, dstVT) + + converted, err := converter.ConvertFrom(src.Interface()) + if err != nil { + return err + } + c := reflect.ValueOf(converted) + if c.Kind() == reflect.Ptr { + c = c.Elem() + } + dst.Set(c) + return nil } return nil