diff --git a/README.md b/README.md index a8f9aafb..2b203050 100755 --- a/README.md +++ b/README.md @@ -83,13 +83,14 @@ _Join our [Discord](https://discord.gg/umReR62nRd) for help and discussions_ ### Commands line arguments -| Argument | Description | Example | -| ----------- | -------------------------------- | -------------------------- | -| empty | start proxy server | | -| `validate` | validate config and exit | | -| `reload` | trigger a force reload of config | | -| `ls-config` | list config and exit | `go-proxy ls-config \| jq` | -| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` | +| Argument | Description | Example | +| ----------------- | ---------------------------------------------------- | -------------------------------- | +| empty | start proxy server | | +| `validate` | validate config and exit | | +| `reload` | trigger a force reload of config | | +| `ls-config` | list config and exit | `go-proxy ls-config \| jq` | +| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` | +| `debug-ls-mtrace` | list middleware trace **(works only in debug mode)** | `go-proxy debug-ls-mtrace \| jq` | **run with `docker exec go-proxy /app/go-proxy `** diff --git a/cmd/main.go b/cmd/main.go index 956e640d..320212d0 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -18,7 +18,7 @@ import ( "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal/api" - apiUtils "github.com/yusing/go-proxy/internal/api/v1/utils" + "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/docker" @@ -57,7 +57,7 @@ func main() { } if args.Command == common.CommandReload { - if err := apiUtils.ReloadServer(); err.HasError() { + if err := query.ReloadServer(); err.HasError() { log.Fatal(err) } log.Print("ok") @@ -93,7 +93,7 @@ func main() { printJSON(cfg.Value()) return case common.CommandListRoutes: - routes, err := apiUtils.ListRoutes() + routes, err := query.ListRoutes() if err.HasError() { log.Printf("failed to connect to api server: %s", err) log.Printf("falling back to config file") @@ -108,6 +108,12 @@ func main() { case common.CommandDebugListProviders: printJSON(cfg.DumpProviders()) return + case common.CommandDebugListMTrace: + trace, err := query.ListMiddlewareTraces() + if err.HasError() { + log.Fatal(err) + } + printJSON(trace) } if common.IsDebug { diff --git a/internal/api/v1/checkhealth.go b/internal/api/v1/checkhealth.go index 6a2b483f..bba30cd5 100644 --- a/internal/api/v1/checkhealth.go +++ b/internal/api/v1/checkhealth.go @@ -25,10 +25,10 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound) return case route.Type() == R.RouteTypeReverseProxy: - ok = U.IsSiteHealthy(route.URL().String()) + ok = IsSiteHealthy(route.URL().String()) case route.Type() == R.RouteTypeStream: entry := route.Entry() - ok = U.IsStreamHealthy( + ok = IsStreamHealthy( strings.Split(entry.Scheme, ":")[1], // target scheme fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]), ) diff --git a/internal/api/v1/utils/health_check.go b/internal/api/v1/health_check.go similarity index 82% rename from internal/api/v1/utils/health_check.go rename to internal/api/v1/health_check.go index a03f9c89..20825d65 100644 --- a/internal/api/v1/utils/health_check.go +++ b/internal/api/v1/health_check.go @@ -1,21 +1,22 @@ -package utils +package v1 import ( "net" "net/http" + U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" ) func IsSiteHealthy(url string) bool { // try HEAD first // if HEAD is not allowed, try GET - resp, err := httpClient.Head(url) + resp, err := U.Head(url) if resp != nil { resp.Body.Close() } if err != nil && resp != nil && resp.StatusCode == http.StatusMethodNotAllowed { - _, err = httpClient.Get(url) + _, err = U.Get(url) } if resp != nil { resp.Body.Close() diff --git a/internal/api/v1/list.go b/internal/api/v1/list.go index 48a8744b..1550d505 100644 --- a/internal/api/v1/list.go +++ b/internal/api/v1/list.go @@ -8,19 +8,28 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" + "github.com/yusing/go-proxy/internal/net/http/middleware" +) + +const ( + ListRoutes = "routes" + ListConfigFiles = "config_files" + ListMiddlewareTrace = "middleware_trace" ) func List(cfg *config.Config, w http.ResponseWriter, r *http.Request) { what := r.PathValue("what") if what == "" { - what = "routes" + what = ListRoutes } switch what { - case "routes": + case ListRoutes: listRoutes(cfg, w, r) - case "config_files": + case ListConfigFiles: listConfigFiles(w, r) + case ListMiddlewareTrace: + listMiddlewareTrace(w, r) default: U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) } @@ -59,3 +68,12 @@ func listConfigFiles(w http.ResponseWriter, r *http.Request) { } w.Write(resp) } + +func listMiddlewareTrace(w http.ResponseWriter, r *http.Request) { + resp, err := json.Marshal(middleware.GetAllTrace()) + if err != nil { + U.HandleErr(w, r, err) + return + } + w.Write(resp) +} diff --git a/internal/api/v1/utils/localhost.go b/internal/api/v1/query/query.go similarity index 56% rename from internal/api/v1/utils/localhost.go rename to internal/api/v1/query/query.go index 66e9070a..b588f83a 100644 --- a/internal/api/v1/utils/localhost.go +++ b/internal/api/v1/query/query.go @@ -1,4 +1,4 @@ -package utils +package query import ( "encoding/json" @@ -6,12 +6,15 @@ import ( "io" "net/http" + v1 "github.com/yusing/go-proxy/internal/api/v1" + U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/net/http/middleware" ) func ReloadServer() E.NestedError { - resp, err := httpClient.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil) + resp, err := U.Post(fmt.Sprintf("%s/v1/reload", common.APIHTTPURL), "", nil) if err != nil { return E.From(err) } @@ -32,7 +35,7 @@ func ReloadServer() E.NestedError { } func ListRoutes() (map[string]map[string]any, E.NestedError) { - resp, err := httpClient.Get(fmt.Sprintf("%s/v1/list/routes", common.APIHTTPURL)) + resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListRoutes)) if err != nil { return nil, E.From(err) } @@ -47,3 +50,20 @@ func ListRoutes() (map[string]map[string]any, E.NestedError) { } return routes, nil } + +func ListMiddlewareTraces() (middleware.Traces, E.NestedError) { + resp, err := U.Get(fmt.Sprintf("%s/v1/list/%s", common.APIHTTPURL, v1.ListMiddlewareTrace)) + if err != nil { + return nil, E.From(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, E.Failure("list middleware trace").Extraf("status code: %v", resp.StatusCode) + } + var traces middleware.Traces + err = json.NewDecoder(resp.Body).Decode(&traces) + if err != nil { + return nil, E.From(err) + } + return traces, nil +} diff --git a/internal/api/v1/utils/http_client.go b/internal/api/v1/utils/http_client.go index bcacbf37..a3afb7bb 100644 --- a/internal/api/v1/utils/http_client.go +++ b/internal/api/v1/utils/http_client.go @@ -8,7 +8,7 @@ import ( "github.com/yusing/go-proxy/internal/common" ) -var httpClient = &http.Client{ +var HTTPClient = &http.Client{ Timeout: common.ConnectionTimeout, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -21,3 +21,7 @@ var httpClient = &http.Client{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } + +var Get = HTTPClient.Get +var Post = HTTPClient.Post +var Head = HTTPClient.Head diff --git a/internal/common/args.go b/internal/common/args.go index cdc48f4e..55059912 100644 --- a/internal/common/args.go +++ b/internal/common/args.go @@ -2,9 +2,9 @@ package common import ( "flag" + "fmt" "github.com/sirupsen/logrus" - E "github.com/yusing/go-proxy/internal/error" ) type Args struct { @@ -20,6 +20,7 @@ const ( CommandReload = "reload" CommandDebugListEntries = "debug-ls-entries" CommandDebugListProviders = "debug-ls-providers" + CommandDebugListMTrace = "debug-ls-mtrace" ) var ValidCommands = []string{ @@ -31,23 +32,24 @@ var ValidCommands = []string{ CommandReload, CommandDebugListEntries, CommandDebugListProviders, + CommandDebugListMTrace, } func GetArgs() Args { var args Args flag.Parse() args.Command = flag.Arg(0) - if err := validateArg(args.Command); err.HasError() { + if err := validateArg(args.Command); err != nil { logrus.Fatal(err) } return args } -func validateArg(arg string) E.NestedError { +func validateArg(arg string) error { for _, v := range ValidCommands { if arg == v { return nil } } - return E.Invalid("argument", arg) + return fmt.Errorf("invalid command: %s", arg) } diff --git a/internal/common/env.go b/internal/common/env.go index e1db32d2..ddfe9359 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -4,14 +4,14 @@ import ( "fmt" "net" "os" + "strings" "github.com/sirupsen/logrus" - U "github.com/yusing/go-proxy/internal/utils" ) var ( NoSchemaValidation = GetEnvBool("GOPROXY_NO_SCHEMA_VALIDATION", false) - IsTest = GetEnvBool("GOPROXY_TEST", false) + IsTest = GetEnvBool("GOPROXY_TEST", false) || strings.HasSuffix(os.Args[0], ".test") IsDebug = GetEnvBool("GOPROXY_DEBUG", IsTest) ProxyHTTPAddr, @@ -35,7 +35,14 @@ func GetEnvBool(key string, defaultValue bool) bool { if !ok || value == "" { return defaultValue } - return U.ParseBool(value) + switch strings.ToLower(value) { + case "true", "yes", "1": + return true + case "false", "no", "0": + return false + default: + return defaultValue + } } func GetEnv(key, defaultValue string) string { diff --git a/internal/net/http/middleware/cidr_whitelist.go b/internal/net/http/middleware/cidr_whitelist.go index cf2c62bf..2d6e3249 100644 --- a/internal/net/http/middleware/cidr_whitelist.go +++ b/internal/net/http/middleware/cidr_whitelist.go @@ -7,6 +7,7 @@ import ( D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/types" + F "github.com/yusing/go-proxy/internal/utils/functional" ) type cidrWhitelist struct { @@ -19,7 +20,7 @@ type cidrWhitelistOpts struct { StatusCode int Message string - trustedAddr map[string]struct{} // cache for trusted IPs + cachedAddr F.Map[string, bool] // cache for trusted IPs } var CIDRWhiteList = &cidrWhitelist{ @@ -28,15 +29,16 @@ var CIDRWhiteList = &cidrWhitelist{ "allow": D.YamlStringListParser, "statusCode": D.IntParser, }, + withOptions: NewCIDRWhitelist, }, } var cidrWhitelistDefaults = func() *cidrWhitelistOpts { return &cidrWhitelistOpts{ - Allow: []*types.CIDR{}, - StatusCode: http.StatusForbidden, - Message: "IP not allowed", - trustedAddr: make(map[string]struct{}), + Allow: []*types.CIDR{}, + StatusCode: http.StatusForbidden, + Message: "IP not allowed", + cachedAddr: F.NewMapOf[string, bool](), } } @@ -57,23 +59,32 @@ func NewCIDRWhitelist(opts OptionsRaw) (*Middleware, E.NestedError) { return wl.m, nil } -func (wl *cidrWhitelist) checkIP(next http.Handler, w ResponseWriter, r *Request) { - var ok bool - if _, ok = wl.trustedAddr[r.RemoteAddr]; !ok { - ip := net.IP(r.RemoteAddr) +func (wl *cidrWhitelist) checkIP(next http.HandlerFunc, w ResponseWriter, r *Request) { + var allow, ok bool + if allow, ok = wl.cachedAddr.Load(r.RemoteAddr); !ok { + ipStr, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ipStr = r.RemoteAddr + } + ip := net.ParseIP(ipStr) for _, cidr := range wl.cidrWhitelistOpts.Allow { if cidr.Contains(ip) { - wl.trustedAddr[r.RemoteAddr] = struct{}{} - ok = true + wl.cachedAddr.Store(r.RemoteAddr, true) + allow = true + wl.m.AddTracef("client %s is allowed", ipStr).With("allowed CIDR", cidr) break } } + if !allow { + wl.cachedAddr.Store(r.RemoteAddr, false) + wl.m.AddTracef("client %s is forbidden", ipStr).With("allowed CIDRs", wl.cidrWhitelistOpts.Allow) + } } - if !ok { + if !allow { w.WriteHeader(wl.StatusCode) w.Write([]byte(wl.Message)) return } - next.ServeHTTP(w, r) + next(w, r) } diff --git a/internal/net/http/middleware/cidr_whitelist_test.go b/internal/net/http/middleware/cidr_whitelist_test.go new file mode 100644 index 00000000..0daeb9d8 --- /dev/null +++ b/internal/net/http/middleware/cidr_whitelist_test.go @@ -0,0 +1,42 @@ +package middleware + +import ( + _ "embed" + "net/http" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +//go:embed test_data/cidr_whitelist_test.yml +var testCIDRWhitelistCompose []byte +var deny, accept *Middleware + +func TestCIDRWhitelist(t *testing.T) { + mids, err := BuildMiddlewaresFromYAML(testCIDRWhitelistCompose) + if err != nil { + panic(err) + } + deny = mids["deny@file"] + accept = mids["accept@file"] + if deny == nil || accept == nil { + panic("bug occurred") + } + + t.Run("deny", func(t *testing.T) { + for range 10 { + result, err := newMiddlewareTest(deny, nil) + ExpectNoError(t, err.Error()) + ExpectEqual(t, result.ResponseStatus, cidrWhitelistDefaults().StatusCode) + ExpectEqual(t, string(result.Data), cidrWhitelistDefaults().Message) + } + }) + + t.Run("accept", func(t *testing.T) { + for range 10 { + result, err := newMiddlewareTest(accept, nil) + ExpectNoError(t, err.Error()) + ExpectEqual(t, result.ResponseStatus, http.StatusOK) + } + }) +} diff --git a/internal/net/http/middleware/cloudflare_real_ip.go b/internal/net/http/middleware/cloudflare_real_ip.go index 05dbe448..96101f99 100644 --- a/internal/net/http/middleware/cloudflare_real_ip.go +++ b/internal/net/http/middleware/cloudflare_real_ip.go @@ -39,12 +39,13 @@ func NewCloudflareRealIP(_ OptionsRaw) (*Middleware, E.NestedError) { cri := new(realIP) cri.m = &Middleware{ impl: cri, - rewrite: func(r *Request) { + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { cidrs := tryFetchCFCIDR() if cidrs != nil { cri.From = cidrs } cri.setRealIP(r) + next(w, r) }, } cri.realIPOpts = &realIPOpts{ diff --git a/internal/net/http/middleware/custom_error_page.go b/internal/net/http/middleware/custom_error_page.go index 41e682f3..95db5daa 100644 --- a/internal/net/http/middleware/custom_error_page.go +++ b/internal/net/http/middleware/custom_error_page.go @@ -15,9 +15,9 @@ import ( ) var CustomErrorPage = &Middleware{ - before: func(next http.Handler, w ResponseWriter, r *Request) { + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { if !ServeStaticErrorPageFile(w, r) { - next.ServeHTTP(w, r) + next(w, r) } }, modifyResponse: func(resp *Response) error { diff --git a/internal/net/http/middleware/forward_auth.go b/internal/net/http/middleware/forward_auth.go index 6853e578..ca9e9068 100644 --- a/internal/net/http/middleware/forward_auth.go +++ b/internal/net/http/middleware/forward_auth.go @@ -13,7 +13,6 @@ import ( "strings" "time" - "github.com/sirupsen/logrus" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" gpHTTP "github.com/yusing/go-proxy/internal/net/http" @@ -45,7 +44,6 @@ var ForwardAuth = func() *forwardAuth { fa.m.withOptions = NewForwardAuthfunc return fa }() -var faLogger = logrus.WithField("middleware", "ForwardAuth") func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) { faWithOpts := new(forwardAuth) @@ -80,7 +78,7 @@ func NewForwardAuthfunc(optsRaw OptionsRaw) (*Middleware, E.NestedError) { return faWithOpts.m, nil } -func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request) { +func (fa *forwardAuth) forward(next http.HandlerFunc, w ResponseWriter, req *Request) { gpHTTP.RemoveHop(req.Header) faReq, err := http.NewRequestWithContext( @@ -90,7 +88,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request nil, ) if err != nil { - faLogger.Debugf("new request err to %s: %s", fa.Address, err) + fa.m.AddTracef("new request err to %s", fa.Address).With("error", err) w.WriteHeader(http.StatusInternalServerError) return } @@ -103,7 +101,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request faResp, err := fa.client.Do(faReq) if err != nil { - faLogger.Debugf("failed to call %s: %s", fa.Address, err) + fa.m.AddTracef("failed to call %s", fa.Address).With("error", err) w.WriteHeader(http.StatusInternalServerError) return } @@ -111,7 +109,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request body, err := io.ReadAll(faResp.Body) if err != nil { - faLogger.Debugf("failed to read response body from %s: %s", fa.Address, err) + fa.m.AddTracef("failed to read response body from %s", fa.Address).With("error", err) w.WriteHeader(http.StatusInternalServerError) return } @@ -122,7 +120,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request redirectURL, err := faResp.Location() if err != nil { - faLogger.Debugf("failed to get location from %s: %s", fa.Address, err) + fa.m.AddTracef("failed to get location from %s", fa.Address).With("error", err) w.WriteHeader(http.StatusInternalServerError) return } else if redirectURL.String() != "" { @@ -132,7 +130,7 @@ func (fa *forwardAuth) forward(next http.Handler, w ResponseWriter, req *Request w.WriteHeader(faResp.StatusCode) if _, err = w.Write(body); err != nil { - faLogger.Debugf("failed to write response body from %s: %s", fa.Address, err) + fa.m.AddTracef("failed to write response body from %s", fa.Address).With("error", err) } return } diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index ec18fb38..3908c964 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -2,6 +2,7 @@ package middleware import ( "encoding/json" + "errors" "net/http" D "github.com/yusing/go-proxy/internal/docker" @@ -21,7 +22,7 @@ type ( Header = http.Header Cookie = http.Cookie - BeforeFunc func(next http.Handler, w ResponseWriter, r *Request) + BeforeFunc func(next http.HandlerFunc, w ResponseWriter, r *Request) RewriteFunc func(req *Request) ModifyResponseFunc func(resp *Response) error CloneWithOptFunc func(opts OptionsRaw) (*Middleware, E.NestedError) @@ -33,23 +34,38 @@ type ( name string before BeforeFunc // runs before ReverseProxy.ServeHTTP - rewrite RewriteFunc // runs after ReverseProxy.Rewrite modifyResponse ModifyResponseFunc // runs after ReverseProxy.ModifyResponse - transport http.RoundTripper - withOptions CloneWithOptFunc labelParserMap D.ValueParserMap impl any + + parent *Middleware + children []*Middleware + trace bool } ) var Deserialize = U.Deserialize +func Rewrite(r RewriteFunc) BeforeFunc { + return func(next http.HandlerFunc, w ResponseWriter, req *Request) { + r(req) + next(w, req) + } +} + func (m *Middleware) Name() string { return m.name } +func (m *Middleware) Fullname() string { + if m.parent != nil { + return m.parent.Fullname() + "." + m.name + } + return m.name +} + func (m *Middleware) String() string { return m.name } @@ -72,14 +88,21 @@ func (m *Middleware) WithOptionsClone(optsRaw OptionsRaw) (*Middleware, E.Nested // WithOptionsClone is called only once // set withOptions and labelParser will not be used after that - return &Middleware{m.name, m.before, m.rewrite, m.modifyResponse, m.transport, nil, nil, m.impl}, nil + return &Middleware{ + m.name, + m.before, + m.modifyResponse, + nil, nil, + m.impl, + m.parent, + m.children, + false, + }, nil } // TODO: check conflict or duplicates -func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res E.NestedError) { - befores := make([]BeforeFunc, 0, len(middlewares)) - rewrites := make([]RewriteFunc, 0, len(middlewares)) - modResps := make([]ModifyResponseFunc, 0, len(middlewares)) +func PatchReverseProxy(rpName string, rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (res E.NestedError) { + middlewares := make([]*Middleware, 0, len(middlewaresMap)) invalidM := E.NewBuilder("invalid middlewares") invalidOpts := E.NewBuilder("invalid options") @@ -88,7 +111,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res invalidM.To(&res) }() - for name, opts := range middlewares { + for name, opts := range middlewaresMap { m, ok := Get(name) if !ok { invalidM.Add(E.NotExist("middleware", name)) @@ -100,56 +123,35 @@ func PatchReverseProxy(rp *ReverseProxy, middlewares map[string]OptionsRaw) (res invalidOpts.Add(err.Subject(name)) continue } - if m.before != nil { - befores = append(befores, m.before) - } - if m.rewrite != nil { - rewrites = append(rewrites, m.rewrite) - } - if m.modifyResponse != nil { - modResps = append(modResps, m.modifyResponse) - } + middlewares = append(middlewares, m) } if invalidM.HasError() { return } - origServeHTTP := rp.ServeHTTP - for i, before := range befores { - if i < len(befores)-1 { - rp.ServeHTTP = func(w ResponseWriter, r *Request) { - before(rp.ServeHTTP, w, r) - } - } else { - rp.ServeHTTP = func(w ResponseWriter, r *Request) { - before(origServeHTTP, w, r) - } - } - } - - if len(rewrites) > 0 { - origServeHTTP = rp.ServeHTTP - rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) { - for _, rewrite := range rewrites { - rewrite(r) - } - origServeHTTP(w, r) - } - } - - if len(modResps) > 0 { - if rp.ModifyResponse != nil { - modResps = append([]ModifyResponseFunc{rp.ModifyResponse}, modResps...) - } - rp.ModifyResponse = func(res *Response) error { - b := E.NewBuilder("errors in middleware ModifyResponse") - for _, mr := range modResps { - b.AddE(mr(res)) - } - return b.Build().Error() - } - } - + patchReverseProxy(rpName, rp, middlewares) return } + +func patchReverseProxy(rpName string, rp *ReverseProxy, middlewares []*Middleware) { + mid := BuildMiddlewareFromChain(rpName, middlewares) + + if mid.before != nil { + ori := rp.ServeHTTP + rp.ServeHTTP = func(w http.ResponseWriter, r *http.Request) { + mid.before(ori, w, r) + } + } + + if mid.modifyResponse != nil { + if rp.ModifyResponse != nil { + ori := rp.ModifyResponse + rp.ModifyResponse = func(res *http.Response) error { + return errors.Join(mid.modifyResponse(res), ori(res)) + } + } else { + rp.ModifyResponse = mid.modifyResponse + } + } +} diff --git a/internal/net/http/middleware/middleware_builder.go b/internal/net/http/middleware/middleware_builder.go index 78e7c967..8b23ca87 100644 --- a/internal/net/http/middleware/middleware_builder.go +++ b/internal/net/http/middleware/middleware_builder.go @@ -1,9 +1,11 @@ package middleware import ( + "fmt" "net/http" "os" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" "gopkg.in/yaml.v3" ) @@ -23,7 +25,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, var rawMap map[string][]map[string]any err := yaml.Unmarshal(data, &rawMap) if err != nil { - b.Add(E.FailWith("toml unmarshal", err)) + b.Add(E.FailWith("yaml unmarshal", err)) return } middlewares = make(map[string]*Middleware) @@ -31,18 +33,22 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, chainErr := E.NewBuilder(name) chain := make([]*Middleware, 0, len(defs)) for i, def := range defs { - if def["use"] == nil || def["use"].(string) == "" { - chainErr.Add(E.Missing("use").Subjectf("%s.%d", name, i)) + if def["use"] == nil || def["use"] == "" { + chainErr.Add(E.Missing("use").Subjectf(".%d", i)) continue } baseName := def["use"].(string) base, ok := Get(baseName) if !ok { - chainErr.Add(E.NotExist("middleware", baseName).Subjectf("%s.%d", name, i)) - continue + base, ok = middlewares[baseName] + if !ok { + chainErr.Add(E.NotExist("middleware", baseName).Subjectf(".%d", i)) + continue + } } delete(def, "use") m, err := base.WithOptionsClone(def) + m.name = fmt.Sprintf("%s[%d]", name, i) if err != nil { chainErr.Add(err.Subjectf("item%d", i)) continue @@ -52,8 +58,7 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, if chainErr.HasError() { b.Add(chainErr.Build()) } else { - name = name + "@file" - middlewares[name] = BuildMiddlewareFromChain(name, chain) + middlewares[name+"@file"] = BuildMiddlewareFromChain(name, chain) } } return @@ -61,47 +66,49 @@ func BuildMiddlewaresFromYAML(data []byte) (middlewares map[string]*Middleware, // TODO: check conflict or duplicates func BuildMiddlewareFromChain(name string, chain []*Middleware) *Middleware { - var ( - befores []BeforeFunc - rewrites []RewriteFunc - modResps []ModifyResponseFunc - ) - for _, m := range chain { - if m.before != nil { - befores = append(befores, m.before) + m := &Middleware{name: name, children: chain} + + var befores []*Middleware + var modResps []*Middleware + + for _, comp := range chain { + if comp.before != nil { + befores = append(befores, comp) } - if m.rewrite != nil { - rewrites = append(rewrites, m.rewrite) - } - if m.modifyResponse != nil { - modResps = append(modResps, m.modifyResponse) + if comp.modifyResponse != nil { + modResps = append(modResps, comp) } + comp.parent = m } - m := &Middleware{name: name} if len(befores) > 0 { - m.before = func(next http.Handler, w ResponseWriter, r *Request) { - for _, before := range befores { - before(next, w, r) - } - } - } - if len(rewrites) > 0 { - m.rewrite = func(r *Request) { - for _, rewrite := range rewrites { - rewrite(r) - } - } + m.before = buildBefores(befores) } if len(modResps) > 0 { m.modifyResponse = func(res *Response) error { - b := E.NewBuilder("errors in middleware %s", name) + b := E.NewBuilder("errors in middleware") for _, mr := range modResps { - b.AddE(mr(res)) + b.Add(E.From(mr.modifyResponse(res)).Subject(mr.name)) } return b.Build().Error() } } + if common.IsDebug { + m.EnableTrace() + m.AddTracef("middleware created") + } return m } + +func buildBefores(befores []*Middleware) BeforeFunc { + if len(befores) == 1 { + return befores[0].before + } + nextBefores := buildBefores(befores[1:]) + return func(next http.HandlerFunc, w ResponseWriter, r *Request) { + befores[0].before(func(w ResponseWriter, r *Request) { + nextBefores(next, w, r) + }, w, r) + } +} diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index 21255aad..ada30b9e 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -67,10 +67,10 @@ func LoadComposeFiles() { b.Add(E.Duplicated("middleware", name)) continue } - middlewares[name] = m + middlewares[U.ToLowerNoSnake(name)] = m logger.Infof("middleware %s loaded from %s", name, path.Base(defFile)) } - b.Add(err.Subject(defFile)) + b.Add(err.Subject(path.Base(defFile))) } if b.HasError() { logger.Error(b.Build()) diff --git a/internal/net/http/middleware/modify_request.go b/internal/net/http/middleware/modify_request.go index f36fe07c..69febf9c 100644 --- a/internal/net/http/middleware/modify_request.go +++ b/internal/net/http/middleware/modify_request.go @@ -1,6 +1,7 @@ package middleware import ( + "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" ) @@ -32,9 +33,15 @@ var ModifyRequest = func() *modifyRequest { func NewModifyRequest(optsRaw OptionsRaw) (*Middleware, E.NestedError) { mr := new(modifyRequest) + var mrFunc RewriteFunc + if common.IsDebug { + mrFunc = mr.modifyRequestWithTrace + } else { + mrFunc = mr.modifyRequest + } mr.m = &Middleware{ - impl: mr, - rewrite: mr.modifyRequest, + impl: mr, + before: Rewrite(mrFunc), } mr.modifyRequestOpts = new(modifyRequestOpts) err := Deserialize(optsRaw, mr.modifyRequestOpts) @@ -55,3 +62,9 @@ func (mr *modifyRequest) modifyRequest(req *Request) { req.Header.Del(k) } } + +func (mr *modifyRequest) modifyRequestWithTrace(req *Request) { + mr.m.AddTraceRequest("before modify request", req) + mr.modifyRequest(req) + mr.m.AddTraceRequest("after modify request", req) +} diff --git a/internal/net/http/middleware/modify_response.go b/internal/net/http/middleware/modify_response.go index a212f0cf..ecf7b0ae 100644 --- a/internal/net/http/middleware/modify_response.go +++ b/internal/net/http/middleware/modify_response.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" + "github.com/yusing/go-proxy/internal/common" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" ) @@ -34,9 +35,11 @@ var ModifyResponse = func() (mr *modifyResponse) { func NewModifyResponse(optsRaw OptionsRaw) (*Middleware, E.NestedError) { mr := new(modifyResponse) - mr.m = &Middleware{ - impl: mr, - modifyResponse: mr.modifyResponse, + mr.m = &Middleware{impl: mr} + if common.IsDebug { + mr.m.modifyResponse = mr.modifyResponseWithTrace + } else { + mr.m.modifyResponse = mr.modifyResponse } mr.modifyResponseOpts = new(modifyResponseOpts) err := Deserialize(optsRaw, mr.modifyResponseOpts) @@ -58,3 +61,10 @@ func (mr *modifyResponse) modifyResponse(resp *http.Response) error { } return nil } + +func (mr *modifyResponse) modifyResponseWithTrace(resp *http.Response) error { + mr.m.AddTraceResponse("before modify response", resp) + err := mr.modifyResponse(resp) + mr.m.AddTraceResponse("after modify response", resp) + return err +} diff --git a/internal/net/http/middleware/real_ip.go b/internal/net/http/middleware/real_ip.go index 87094765..884ad00d 100644 --- a/internal/net/http/middleware/real_ip.go +++ b/internal/net/http/middleware/real_ip.go @@ -2,8 +2,8 @@ package middleware import ( "net" + "net/http" - "github.com/sirupsen/logrus" D "github.com/yusing/go-proxy/internal/docker" E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/types" @@ -49,13 +49,14 @@ var realIPOptsDefault = func() *realIPOpts { } } -var realIPLogger = logrus.WithField("middleware", "RealIP") - func NewRealIP(opts OptionsRaw) (*Middleware, E.NestedError) { riWithOpts := new(realIP) riWithOpts.m = &Middleware{ - impl: riWithOpts, - rewrite: riWithOpts.setRealIP, + impl: riWithOpts, + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { + riWithOpts.setRealIP(r) + next(w, r) + }, } riWithOpts.realIPOpts = realIPOptsDefault() err := Deserialize(opts, riWithOpts.realIPOpts) @@ -78,7 +79,7 @@ func (ri *realIP) isInCIDRList(ip net.IP) bool { func (ri *realIP) setRealIP(req *Request) { clientIPStr, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { - realIPLogger.Debugf("failed to split host port %s", err) + clientIPStr = req.RemoteAddr } clientIP := net.ParseIP(clientIPStr) @@ -90,7 +91,7 @@ func (ri *realIP) setRealIP(req *Request) { } } if !isTrusted { - realIPLogger.Debugf("client ip %s is not trusted", clientIP) + ri.m.AddTracef("client ip %s is not trusted", clientIP).With("allowed CIDRs", ri.From) return } @@ -98,7 +99,7 @@ func (ri *realIP) setRealIP(req *Request) { var lastNonTrustedIP string if len(realIPs) == 0 { - realIPLogger.Debugf("no real ip found in header %q", ri.Header) + ri.m.AddTracef("no real ip found in header %s", ri.Header).WithRequest(req) return } @@ -110,14 +111,16 @@ func (ri *realIP) setRealIP(req *Request) { lastNonTrustedIP = r } } - if lastNonTrustedIP == "" { - realIPLogger.Debugf("no non-trusted ip found in header %q", ri.Header) - return - } + } + + if lastNonTrustedIP == "" { + ri.m.AddTracef("no non-trusted ip found").With("allowed CIDRs", ri.From).With("ips", realIPs) + return } req.RemoteAddr = lastNonTrustedIP req.Header.Set(ri.Header, lastNonTrustedIP) req.Header.Set("X-Real-IP", lastNonTrustedIP) - req.Header.Set("X-Forwarded-For", lastNonTrustedIP) + req.Header.Set(xForwardedFor, lastNonTrustedIP) + ri.m.AddTracef("set real ip %s", lastNonTrustedIP) } diff --git a/internal/net/http/middleware/real_ip_test.go b/internal/net/http/middleware/real_ip_test.go index 71324561..fa98a1d4 100644 --- a/internal/net/http/middleware/real_ip_test.go +++ b/internal/net/http/middleware/real_ip_test.go @@ -2,13 +2,15 @@ package middleware import ( "net" + "net/http" + "strings" "testing" "github.com/yusing/go-proxy/internal/types" . "github.com/yusing/go-proxy/internal/utils/testing" ) -func TestSetRealIP(t *testing.T) { +func TestSetRealIPOpts(t *testing.T) { opts := OptionsRaw{ "header": "X-Real-IP", "from": []string{ @@ -37,13 +39,39 @@ func TestSetRealIP(t *testing.T) { Recursive: true, } - t.Run("set_options", func(t *testing.T) { - ri, err := RealIP.m.WithOptionsClone(opts) - ExpectNoError(t, err.Error()) - // ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) - // ExpectDeepEqual(t, ri.impl.(*realIP).From, optExpected.From) - // ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) - ExpectDeepEqual(t, ri.impl.(*realIP).realIPOpts, optExpected) - }) - // TODO test + ri, err := NewRealIP(opts) + ExpectNoError(t, err.Error()) + ExpectEqual(t, ri.impl.(*realIP).Header, optExpected.Header) + ExpectEqual(t, ri.impl.(*realIP).Recursive, optExpected.Recursive) + for i, CIDR := range ri.impl.(*realIP).From { + ExpectEqual(t, CIDR.String(), optExpected.From[i].String()) + } +} + +func TestSetRealIP(t *testing.T) { + const ( + testHeader = "X-Real-IP" + testRealIP = "192.168.1.1" + ) + opts := OptionsRaw{ + "header": testHeader, + "from": []string{"0.0.0.0/0"}, + } + optsMr := OptionsRaw{ + "set_headers": map[string]string{testHeader: testRealIP}, + } + realip, err := NewRealIP(opts) + ExpectNoError(t, err.Error()) + + mr, err := NewModifyRequest(optsMr) + ExpectNoError(t, err.Error()) + + mid := BuildMiddlewareFromChain("test", []*Middleware{mr, realip}) + + result, err := newMiddlewareTest(mid, nil) + ExpectNoError(t, err.Error()) + t.Log(traces) + ExpectEqual(t, result.ResponseStatus, http.StatusOK) + ExpectEqual(t, strings.Split(result.RemoteAddr, ":")[0], testRealIP) + ExpectEqual(t, result.RequestHeaders.Get(xForwardedFor), testRealIP) } diff --git a/internal/net/http/middleware/redirect_http.go b/internal/net/http/middleware/redirect_http.go index a0e72e15..595580e2 100644 --- a/internal/net/http/middleware/redirect_http.go +++ b/internal/net/http/middleware/redirect_http.go @@ -7,13 +7,13 @@ import ( ) var RedirectHTTP = &Middleware{ - before: func(next http.Handler, w ResponseWriter, r *Request) { + before: func(next http.HandlerFunc, w ResponseWriter, r *Request) { if r.TLS == nil { r.URL.Scheme = "https" r.URL.Host = r.URL.Hostname() + ":" + common.ProxyHTTPSPort http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) return } - next.ServeHTTP(w, r) + next(w, r) }, } diff --git a/internal/net/http/middleware/test_data/cidr_whitelist_test.yml b/internal/net/http/middleware/test_data/cidr_whitelist_test.yml new file mode 100644 index 00000000..4c414dd2 --- /dev/null +++ b/internal/net/http/middleware/test_data/cidr_whitelist_test.yml @@ -0,0 +1,22 @@ +deny: + - use: ModifyRequest + setHeaders: + X-Real-IP: 192.168.1.1:1234 + - use: RealIP + header: X-Real-IP + from: + - 0.0.0.0/0 + - use: CIDRWhitelist + allow: + - 192.168.0.0/24 +accept: + - use: ModifyRequest + setHeaders: + X-Real-IP: 192.168.0.1:1234 + - use: RealIP + header: X-Real-IP + from: + - 0.0.0.0/0 + - use: CIDRWhitelist + allow: + - 192.168.0.0/24 diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index a1f89dd9..62862680 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "net/url" + "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" gpHTTP "github.com/yusing/go-proxy/internal/net/http" ) @@ -20,6 +21,9 @@ var testHeaders http.Header const testHost = "example.com" func init() { + if !common.IsTest { + return + } tmp := map[string]string{} err := json.Unmarshal(testHeadersRaw, &tmp) if err != nil { @@ -31,13 +35,15 @@ func init() { } } -type requestHeaderRecorder struct { +type requestRecorder struct { parent http.RoundTripper - reqHeaders http.Header + headers http.Header + remoteAddr string } -func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, error) { - rt.reqHeaders = req.Header +func (rt *requestRecorder) RoundTrip(req *http.Request) (*http.Response, error) { + rt.headers = req.Header + rt.remoteAddr = req.RemoteAddr if rt.parent != nil { return rt.parent.RoundTrip(req) } @@ -46,6 +52,7 @@ func (rt *requestHeaderRecorder) RoundTrip(req *http.Request) (*http.Response, e Header: testHeaders, Body: io.NopCloser(bytes.NewBufferString("OK")), Request: req, + TLS: req.TLS, }, nil } @@ -53,6 +60,7 @@ type TestResult struct { RequestHeaders http.Header ResponseHeaders http.Header ResponseStatus int + RemoteAddr string Data []byte } @@ -65,7 +73,7 @@ type testArgs struct { func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.NestedError) { var body io.Reader - var rt = new(requestHeaderRecorder) + var rr = new(requestRecorder) var proxyURL *url.URL var requestTarget string var err error @@ -98,17 +106,16 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N if err != nil { return nil, E.From(err) } - rt.parent = http.DefaultTransport + rr.parent = http.DefaultTransport } else { proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect } - rp := gpHTTP.NewReverseProxy(proxyURL, rt) - setOptErr := PatchReverseProxy(rp, map[string]OptionsRaw{ - middleware.name: args.middlewareOpt, - }) + rp := gpHTTP.NewReverseProxy(proxyURL, rr) + mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt) if setOptErr != nil { return nil, setOptErr } + patchReverseProxy(middleware.name, rp, []*Middleware{mid}) rp.ServeHTTP(w, req) resp := w.Result() defer resp.Body.Close() @@ -117,9 +124,10 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N return nil, E.From(err) } return &TestResult{ - RequestHeaders: rt.reqHeaders, + RequestHeaders: rr.headers, ResponseHeaders: resp.Header, ResponseStatus: resp.StatusCode, + RemoteAddr: rr.remoteAddr, Data: data, }, nil } diff --git a/internal/net/http/middleware/trace.go b/internal/net/http/middleware/trace.go new file mode 100644 index 00000000..d654d81d --- /dev/null +++ b/internal/net/http/middleware/trace.go @@ -0,0 +1,99 @@ +package middleware + +import ( + "fmt" + "net/http" + "sync" + "time" + + U "github.com/yusing/go-proxy/internal/utils" +) + +type Trace struct { + Time string `json:"time,omitempty"` + Caller string `json:"caller,omitempty"` + URL string `json:"url,omitempty"` + Message string `json:"msg"` + ReqHeaders http.Header `json:"req_headers,omitempty"` + RespHeaders http.Header `json:"resp_headers,omitempty"` + Additional map[string]any `json:"additional,omitempty"` +} + +type Traces []*Trace + +var traces = Traces{} +var tracesMu sync.Mutex + +const MaxTraceNum = 1000 + +func GetAllTrace() []*Trace { + return traces +} + +func (tr *Trace) WithRequest(req *Request) *Trace { + if tr == nil { + return nil + } + tr.URL = req.RequestURI + tr.ReqHeaders = req.Header.Clone() + return tr +} + +func (tr *Trace) WithResponse(resp *Response) *Trace { + if tr == nil { + return nil + } + tr.URL = resp.Request.RequestURI + tr.ReqHeaders = resp.Request.Header.Clone() + tr.RespHeaders = resp.Header.Clone() + return tr +} + +func (tr *Trace) With(what string, additional any) *Trace { + if tr == nil { + return nil + } + + if tr.Additional == nil { + tr.Additional = map[string]any{} + } + tr.Additional[what] = additional + return tr +} + +func (m *Middleware) EnableTrace() { + m.trace = true + for _, child := range m.children { + child.parent = m + child.EnableTrace() + } +} + +func (m *Middleware) AddTracef(msg string, args ...any) *Trace { + if !m.trace { + return nil + } + return addTrace(&Trace{ + Time: U.FormatTime(time.Now()), + Caller: m.Fullname(), + Message: fmt.Sprintf(msg, args...), + }) +} + +func (m *Middleware) AddTraceRequest(msg string, req *Request) *Trace { + return m.AddTracef("%s", msg).WithRequest(req) +} + +func (m *Middleware) AddTraceResponse(msg string, resp *Response) *Trace { + return m.AddTracef("%s", msg).WithResponse(resp) +} + +func addTrace(t *Trace) *Trace { + tracesMu.Lock() + defer tracesMu.Unlock() + if len(traces) > MaxTraceNum { + traces = traces[1:] + } + traces = append(traces, t) + return t +} diff --git a/internal/net/http/middleware/x_forwarded.go b/internal/net/http/middleware/x_forwarded.go index 1bd59032..0712687d 100644 --- a/internal/net/http/middleware/x_forwarded.go +++ b/internal/net/http/middleware/x_forwarded.go @@ -2,6 +2,7 @@ package middleware import ( "net" + "net/http" ) const ( @@ -14,7 +15,7 @@ const ( ) var SetXForwarded = &Middleware{ - rewrite: func(req *Request) { + before: func(next http.HandlerFunc, w ResponseWriter, req *Request) { req.Header.Del("Forwarded") req.Header.Del(xForwardedFor) req.Header.Del(xForwardedHost) @@ -23,7 +24,7 @@ var SetXForwarded = &Middleware{ if err == nil { req.Header.Set(xForwardedFor, clientIP) } else { - req.Header.Del(xForwardedFor) + req.Header.Set(xForwardedFor, req.RemoteAddr) } req.Header.Set(xForwardedHost, req.Host) if req.TLS == nil { @@ -31,14 +32,16 @@ var SetXForwarded = &Middleware{ } else { req.Header.Set(xForwardedProto, "https") } + next(w, req) }, } var HideXForwarded = &Middleware{ - rewrite: func(req *Request) { + before: func(next http.HandlerFunc, w ResponseWriter, req *Request) { req.Header.Del("Forwarded") req.Header.Del(xForwardedFor) req.Header.Del(xForwardedHost) req.Header.Del(xForwardedProto) + next(w, req) }, } diff --git a/internal/route/http.go b/internal/route/http.go index 50d0cec4..23869012 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -68,7 +68,7 @@ func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { rp := NewReverseProxy(entry.URL, trans) if len(entry.Middlewares) > 0 { - err := middleware.PatchReverseProxy(rp, entry.Middlewares) + err := middleware.PatchReverseProxy(string(entry.Alias), rp, entry.Middlewares) if err != nil { return nil, err } diff --git a/internal/types/cidr.go b/internal/types/cidr.go index 8d3c4826..e3ab23fd 100644 --- a/internal/types/cidr.go +++ b/internal/types/cidr.go @@ -32,3 +32,7 @@ func (cidr *CIDR) Contains(ip net.IP) bool { func (cidr *CIDR) String() string { return (*net.IPNet)(cidr).String() } + +func (cidr *CIDR) Equals(other *CIDR) bool { + return (*net.IPNet)(cidr).IP.Equal(other.IP) && cidr.Mask.String() == other.Mask.String() +} diff --git a/internal/utils/format.go b/internal/utils/format.go index eaf9cc67..afa46f21 100644 --- a/internal/utils/format.go +++ b/internal/utils/format.go @@ -42,6 +42,10 @@ func FormatDuration(d time.Duration) string { return strings.Join(parts[:len(parts)-1], ", ") + " and " + parts[len(parts)-1] } +func FormatTime(t time.Time) string { + return t.Format("2006-01-02 15:04:05") +} + func ParseBool(s string) bool { switch strings.ToLower(s) { case "1", "true", "yes", "on": diff --git a/internal/utils/functional/slice.go b/internal/utils/functional/slice.go index e27f825e..d992aae9 100644 --- a/internal/utils/functional/slice.go +++ b/internal/utils/functional/slice.go @@ -1,19 +1,25 @@ package functional +import ( + "encoding/json" + "sync" +) + type Slice[T any] struct { - s []T + s []T + mu sync.Mutex } func NewSlice[T any]() *Slice[T] { - return &Slice[T]{make([]T, 0)} + return &Slice[T]{s: make([]T, 0)} } func NewSliceN[T any](n int) *Slice[T] { - return &Slice[T]{make([]T, n)} + return &Slice[T]{s: make([]T, n)} } func NewSliceFrom[T any](s []T) *Slice[T] { - return &Slice[T]{s} + return &Slice[T]{s: s} } func (s *Slice[T]) Size() int { @@ -46,6 +52,30 @@ func (s *Slice[T]) AddRange(other *Slice[T]) *Slice[T] { return s } +func (s *Slice[T]) SafeAdd(e T) *Slice[T] { + s.mu.Lock() + defer s.mu.Unlock() + return s.Add(e) +} + +func (s *Slice[T]) SafeAddRange(other *Slice[T]) *Slice[T] { + s.mu.Lock() + defer s.mu.Unlock() + return s.AddRange(other) +} + +func (s *Slice[T]) Pop() T { + v := s.s[len(s.s)-1] + s.s = s.s[:len(s.s)-1] + return v +} + +func (s *Slice[T]) SafePop() T { + s.mu.Lock() + defer s.mu.Unlock() + return s.Pop() +} + func (s *Slice[T]) ForEach(do func(T)) { for _, v := range s.s { do(v) @@ -57,7 +87,7 @@ func (s *Slice[T]) Map(m func(T) T) *Slice[T] { for i, v := range s.s { n[i] = m(v) } - return &Slice[T]{n} + return &Slice[T]{s: n} } func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] { @@ -67,5 +97,13 @@ func (s *Slice[T]) Filter(f func(T) bool) *Slice[T] { n = append(n, v) } } - return &Slice[T]{n} + return &Slice[T]{s: n} +} + +func (s *Slice[T]) String() string { + out, err := json.MarshalIndent(s.s, "", " ") + if err != nil { + panic(err) + } + return string(out) } diff --git a/internal/utils/testing/testing.go b/internal/utils/testing/testing.go index fb6e861b..bb184d35 100644 --- a/internal/utils/testing/testing.go +++ b/internal/utils/testing/testing.go @@ -2,10 +2,23 @@ package utils import ( "errors" + "os" "reflect" "testing" + + "github.com/yusing/go-proxy/internal/common" ) +func init() { + if common.IsTest { + os.Args = append([]string{os.Args[0], "-test.v"}, os.Args[1:]...) + } +} + +func IgnoreError[Result any](r Result, _ error) Result { + return r +} + func ExpectNoError(t *testing.T, err error) { t.Helper() if err != nil && !reflect.ValueOf(err).IsNil() {