From 4a2d42bfa95b3ba289f33c5d2c2de4db4cb437e9 Mon Sep 17 00:00:00 2001 From: yusing Date: Thu, 19 Sep 2024 20:40:03 +0800 Subject: [PATCH] v0.5.0-rc5: check release --- README.md | 9 +- compose.example.yml | 2 +- docs/docker.md | 28 ++- src/api/v1/checkhealth.go | 17 +- src/api/v1/file.go | 15 +- src/api/v1/reload.go | 2 +- src/api/v1/utils/error.go | 8 +- src/api/v1/utils/net.go | 2 +- src/autocert/config.go | 24 +- src/autocert/constants.go | 6 + src/autocert/provider.go | 74 +++--- src/autocert/provider_test/ovh_test.go | 7 +- src/common/args.go | 17 +- src/common/constants.go | 26 +- src/config/config.go | 217 ++++++++-------- src/docker/client.go | 61 +++-- src/docker/client_info.go | 40 +-- src/docker/container.go | 109 ++++++++ src/docker/idlewatcher/round_trip.go | 14 ++ src/docker/idlewatcher/watcher.go | 329 +++++++++++++++++++++++++ src/docker/inspect.go | 19 ++ src/docker/label.go | 8 +- src/docker/label_parser.go | 16 +- src/docker/label_parser_test.go | 59 +++-- src/docker/labels.go | 13 + src/error/builder.go | 48 +++- src/error/builder_test.go | 31 ++- src/error/error.go | 154 +++++++++--- src/error/error_test.go | 48 +++- src/error/errors.go | 39 ++- src/go.mod | 1 + src/go.sum | 2 + src/main.go | 59 +++-- src/models/proxy_entry.go | 64 +++-- src/proxy/entry.go | 142 +++++++---- src/proxy/fields/alias.go | 23 +- src/proxy/fields/headers.go | 4 +- src/proxy/fields/host.go | 4 +- src/proxy/fields/path_mode.go | 2 +- src/proxy/fields/path_pattern.go | 8 +- src/proxy/fields/port.go | 6 +- src/proxy/fields/scheme.go | 17 +- src/proxy/fields/signal.go | 17 ++ src/proxy/fields/stop_method.go | 23 ++ src/proxy/fields/stream_port.go | 12 +- src/proxy/fields/stream_scheme.go | 4 +- src/proxy/fields/timeout.go | 18 ++ src/proxy/provider/docker_provider.go | 222 ++++++++--------- src/proxy/provider/file_provider.go | 45 +++- src/proxy/provider/provider.go | 162 +++++------- src/proxy/reverse_proxy_mod.go | 2 +- src/route/http_route.go | 115 +++++---- src/route/route.go | 71 +++++- src/route/stream_route.go | 14 +- src/route/tcp_route.go | 5 +- src/utils/functional/map.go | 257 ++++++------------- src/utils/functional/map_test.go | 75 ++++++ src/utils/io.go | 104 +++----- src/utils/reflection.go | 2 +- src/utils/serialization.go | 41 ++- src/utils/string.go | 11 + src/utils/{ => testing}/testing.go | 22 +- src/watcher/docker_watcher.go | 14 +- src/watcher/event.go | 26 -- src/watcher/event/event.go | 34 +++ src/watcher/file_watcher.go | 1 + src/watcher/file_watcher_helper.go | 6 +- src/watcher/watcher.go | 1 + 68 files changed, 1971 insertions(+), 1107 deletions(-) create mode 100644 src/docker/container.go create mode 100644 src/docker/idlewatcher/round_trip.go create mode 100644 src/docker/idlewatcher/watcher.go create mode 100644 src/docker/inspect.go create mode 100644 src/docker/labels.go create mode 100644 src/proxy/fields/signal.go create mode 100644 src/proxy/fields/stop_method.go create mode 100644 src/proxy/fields/timeout.go create mode 100644 src/utils/functional/map_test.go create mode 100644 src/utils/string.go rename src/utils/{ => testing}/testing.go (67%) delete mode 100644 src/watcher/event.go create mode 100644 src/watcher/event/event.go diff --git a/README.md b/README.md index 9757d0d7..ddaf4b6a 100755 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr - [go-proxy](#go-proxy) - [Key Points](#key-points) - [Getting Started](#getting-started) + - [Setup](#setup) - [Commands line arguments](#commands-line-arguments) - [Environment variables](#environment-variables) - [Use JSON Schema in VSCode](#use-json-schema-in-vscode) @@ -27,10 +28,11 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr - Easy to use - Effortless configuration - - Error messages is clear and detailed + - Error messages is clear and detailed, easy troubleshooting - Auto certificate obtaining and renewal (See [Supported DNS Challenge Providers](docs/dns_providers.md)) - Auto configuration for docker containers - Auto hot-reload on container state / config file changes +- Stop containers on idle, wake it up on traffic _(optional)_ - Support HTTP(s), TCP and UDP - Web UI for configuration and monitoring (See [screenshots](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots)) - Written in **[Go](https://go.dev)** @@ -39,6 +41,8 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr ## Getting Started +### Setup + 1. Setup DNS Records, e.g. - A Record: `*.y.z` -> `10.0.10.1` @@ -60,6 +64,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr | `validate` | validate config and exit | | | `reload` | trigger a force reload of config | | | `ls-config` | list config and exit | `go-proxy ls-config \| jq` | +| `ls-route` | list proxy entries and exit | `go-proxy ls-route \| jq` | **run with `docker exec /app/go-proxy `** @@ -104,7 +109,7 @@ providers: ### Provider File -Fields are same as [docker labels](docs/docker.md#labels) starting from `scheme` +See [Fields](docs/docker.md#fields) See [providers.example.yml](providers.example.yml) for examples diff --git a/compose.example.yml b/compose.example.yml index 84de8cba..2ba19263 100755 --- a/compose.example.yml +++ b/compose.example.yml @@ -6,7 +6,7 @@ services: network_mode: host labels: - proxy.aliases=gp - - proxy.gp.port=8888 + - proxy.gp.port=3000 depends_on: - app app: diff --git a/docs/docker.md b/docs/docker.md index 0a60347c..9bc50818 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -85,12 +85,17 @@ ### Syntax -| Label | Description | Default | -| ----------------------- | -------------------------------------------------------- | ---------------- | -| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` | -| `proxy.exclude` | to be excluded from `go-proxy` | false | -| `proxy..` | set field for specific alias | N/A | -| `proxy.*.` | set field for all aliases | N/A | +| Label | Description | Default | Accepted values | +| ----------------------- | --------------------------------------------------------------------- | -------------------- | ------------------------------------------------------------------------- | +| `proxy.aliases` | comma separated aliases for subdomain and label matching | `container_name` | any | +| `proxy.exclude` | to be excluded from `go-proxy` | false | boolean | +| `proxy.idle_timeout` | time for idle (no traffic) before put it into sleep **(http/s only)** | empty **(disabled)** | `number[unit]...`, e.g. `1m30s` | +| `proxy.wake_timeout` | time to wait for container to start before responding a loading page | empty | `number[unit]...` | +| `proxy.stop_method` | method to stop after `idle_timeout` | `stop` | `stop`, `pause`, `kill` | +| `proxy.stop_timeout` | time to wait for stop command | `10s` | `number[unit]...` | +| `proxy.stop_signal` | signal sent to container for `stop` and `kill` methods | docker's default | `SIGINT`, `SIGTERM`, `SIGHUP`, `SIGQUIT` and those without **SIG** prefix | +| `proxy..` | set field for specific alias | N/A | N/A | +| `proxy.*.` | set field for all aliases | N/A | N/A | ### Fields @@ -228,12 +233,18 @@ services: volumes: - adg-work:/opt/adguardhome/work - adg-conf:/opt/adguardhome/conf + ports: + - 80 + - 3000 + - 53 mc: image: itzg/minecraft-server tty: true stdin_open: true container_name: mc restart: unless-stopped + ports: + - 25565 labels: - proxy.mc.scheme=tcp - proxy.mc.port=20001:25565 @@ -246,6 +257,9 @@ services: restart: unless-stopped container_name: pal stop_grace_period: 30s + ports: + - 8211 + - 27015 labels: - proxy.aliases=pal1,pal2 - proxy.*.scheme=udp @@ -261,6 +275,8 @@ services: - nginx:/usr/share/nginx/html ports: - 80 + labels: + proxy.idle_timeout: 1m go-proxy: image: ghcr.io/yusing/go-proxy:latest container_name: go-proxy diff --git a/src/api/v1/checkhealth.go b/src/api/v1/checkhealth.go index ab730a7c..19476d45 100644 --- a/src/api/v1/checkhealth.go +++ b/src/api/v1/checkhealth.go @@ -3,6 +3,7 @@ package v1 import ( "fmt" "net/http" + "strings" U "github.com/yusing/go-proxy/api/v1/utils" "github.com/yusing/go-proxy/config" @@ -17,17 +18,19 @@ func CheckHealth(cfg *config.Config, w http.ResponseWriter, r *http.Request) { } var ok bool + route := cfg.FindRoute(target) - switch route := cfg.FindRoute(target).(type) { - case nil: + switch { + case route == nil: U.HandleErr(w, r, U.ErrNotFound("target", target), http.StatusNotFound) return - case *R.HTTPRoute: - ok = U.IsSiteHealthy(route.TargetURL.String()) - case *R.StreamRoute: + case route.Type() == R.RouteTypeReverseProxy: + ok = U.IsSiteHealthy(route.URL().String()) + case route.Type() == R.RouteTypeStream: + entry := route.Entry() ok = U.IsStreamHealthy( - string(route.Scheme.ProxyScheme), - fmt.Sprintf("%s:%v", route.Host, route.Port.ProxyPort), + strings.Split(entry.Scheme, ":")[1], // target scheme + fmt.Sprintf("%s:%v", entry.Host, strings.Split(entry.Port, ":")[1]), ) } diff --git a/src/api/v1/file.go b/src/api/v1/file.go index 6edab084..a0adc094 100644 --- a/src/api/v1/file.go +++ b/src/api/v1/file.go @@ -9,7 +9,6 @@ import ( U "github.com/yusing/go-proxy/api/v1/utils" "github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/config" - E "github.com/yusing/go-proxy/error" "github.com/yusing/go-proxy/proxy/provider" ) @@ -32,25 +31,25 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { U.HandleErr(w, r, U.ErrMissingKey("filename"), http.StatusBadRequest) return } - content, err := E.Check(io.ReadAll(r.Body)) - if err.HasError() { + content, err := io.ReadAll(r.Body) + if err != nil { U.HandleErr(w, r, err) return } if filename == common.ConfigFileName { - err = config.Validate(content) + err = config.Validate(content).Error() } else { - err = provider.Validate(content) + err = provider.Validate(content).Error() } - if err.HasError() { + if err != nil { U.HandleErr(w, r, err, http.StatusBadRequest) return } - err = E.From(os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)) - if err.HasError() { + err = os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644) + if err != nil { U.HandleErr(w, r, err) return } diff --git a/src/api/v1/reload.go b/src/api/v1/reload.go index 18011476..44efaf59 100644 --- a/src/api/v1/reload.go +++ b/src/api/v1/reload.go @@ -8,7 +8,7 @@ import ( ) func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - if err := cfg.Reload(); err.HasError() { + if err := cfg.Reload().Error(); err != nil { U.HandleErr(w, r, err) return } diff --git a/src/api/v1/utils/error.go b/src/api/v1/utils/error.go index ae5b9b4d..3f4dc925 100644 --- a/src/api/v1/utils/error.go +++ b/src/api/v1/utils/error.go @@ -9,14 +9,14 @@ import ( E "github.com/yusing/go-proxy/error" ) -func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) { - err = E.From(err).Subjectf("%s %s", r.Method, r.URL) +func HandleErr(w http.ResponseWriter, r *http.Request, origErr error, code ...int) { + err := E.From(origErr).Subjectf("%s %s", r.Method, r.URL) logrus.WithField("module", "api").Error(err) if len(code) > 0 { - http.Error(w, err.Error(), code[0]) + http.Error(w, err.String(), code[0]) return } - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.String(), http.StatusInternalServerError) } func ErrMissingKey(k string) error { diff --git a/src/api/v1/utils/net.go b/src/api/v1/utils/net.go index c7858f6b..b4da5c6d 100644 --- a/src/api/v1/utils/net.go +++ b/src/api/v1/utils/net.go @@ -44,7 +44,7 @@ func ReloadServer() E.NestedError { if resp.StatusCode != http.StatusOK { return E.Failure("server reload").Subjectf("status code: %v", resp.StatusCode) } - return E.Nil() + return nil } var HttpClient = &http.Client{ diff --git a/src/autocert/config.go b/src/autocert/config.go index 60076eed..e61792cd 100644 --- a/src/autocert/config.go +++ b/src/autocert/config.go @@ -26,33 +26,35 @@ func NewConfig(cfg *M.AutoCertConfig) *Config { return (*Config)(cfg) } -func (cfg *Config) GetProvider() (*Provider, E.NestedError) { - errors := E.NewBuilder("cannot create autocert provider") +func (cfg *Config) GetProvider() (provider *Provider, res E.NestedError) { + b := E.NewBuilder("unable to initialize autocert") + defer b.To(&res) if cfg.Provider != ProviderLocal { if len(cfg.Domains) == 0 { - errors.Addf("no domains specified") + b.Addf("no domains specified") } if cfg.Provider == "" { - errors.Addf("no provider specified") + b.Addf("no provider specified") } if cfg.Email == "" { - errors.Addf("no email specified") + b.Addf("no email specified") } // check if provider is implemented _, ok := providersGenMap[cfg.Provider] if !ok { - errors.Addf("unknown provider: %q", cfg.Provider) + b.Addf("unknown provider: %q", cfg.Provider) } } - if err := errors.Build(); err.HasError() { - return nil, err + if b.HasError() { + return } privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) if err.HasError() { - return nil, E.Failure("generate private key").With(err) + b.Add(E.FailWith("generate private key", err)) + return } user := &User{ @@ -63,11 +65,11 @@ func (cfg *Config) GetProvider() (*Provider, E.NestedError) { legoCfg := lego.NewConfig(user) legoCfg.Certificate.KeyType = certcrypto.RSA2048 - base := &Provider{ + provider = &Provider{ cfg: cfg, user: user, legoCfg: legoCfg, } - return base, E.Nil() + return } diff --git a/src/autocert/constants.go b/src/autocert/constants.go index 8dadbd9f..19b726fc 100644 --- a/src/autocert/constants.go +++ b/src/autocert/constants.go @@ -1,6 +1,8 @@ package autocert import ( + "errors" + "github.com/go-acme/lego/v4/providers/dns/clouddns" "github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/providers/dns/duckdns" @@ -31,4 +33,8 @@ var providersGenMap = map[string]ProviderGenerator{ ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig), } +var ( + ErrGetCertFailure = errors.New("get certificate failed") +) + var logger = logrus.WithField("module", "autocert") diff --git a/src/autocert/provider.go b/src/autocert/provider.go index 96c38a24..e78116a8 100644 --- a/src/autocert/provider.go +++ b/src/autocert/provider.go @@ -33,7 +33,7 @@ type CertExpiries map[string]time.Time func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { if p.tlsCert == nil { - return nil, E.Failure("get certificate") + return nil, ErrGetCertFailure } return p.tlsCert, nil } @@ -54,52 +54,60 @@ func (p *Provider) GetExpiries() CertExpiries { return p.certExpiries } -func (p *Provider) ObtainCert() E.NestedError { +func (p *Provider) ObtainCert() (res E.NestedError) { + b := E.NewBuilder("failed to obtain certificate") + defer b.To(&res) + if p.cfg.Provider == ProviderLocal { - return E.FailureWhy("obtain cert", "provider is set to \"local\"") + b.Addf("provider is set to %q", ProviderLocal) + return } if p.client == nil { if err := p.initClient(); err.HasError() { - return E.Failure("obtain cert").With(err) + b.Add(E.FailWith("init autocert client", err)) + return } } - ne := E.Failure("obtain certificate") - - client := p.client if p.user.Registration == nil { if err := p.loadRegistration(); err.HasError() { - ne = ne.With(err) if err := p.registerACME(); err.HasError() { - return ne.With(err) + b.Add(E.FailWith("register ACME", err)) + return } } } + + client := p.client req := certificate.ObtainRequest{ Domains: p.cfg.Domains, Bundle: true, } cert, err := E.Check(client.Certificate.Obtain(req)) if err.HasError() { - return ne.With(err) + b.Add(err) + return } err = p.saveCert(cert) if err.HasError() { - return ne.With(E.Failure("save certificate").With(err)) + b.Add(E.FailWith("save certificate", err)) + return } tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey)) if err.HasError() { - return ne.With(E.Failure("parse obtained certificate").With(err)) + b.Add(E.FailWith("parse obtained certificate", err)) + return } expiries, err := getCertExpiries(&tlsCert) if err.HasError() { - return ne.With(E.Failure("get certificate expiry").With(err)) + b.Add(E.FailWith("get certificate expiry", err)) + return } p.tlsCert = &tlsCert p.certExpiries = expiries - return E.Nil() + return nil } func (p *Provider) LoadCert() E.NestedError { @@ -152,50 +160,50 @@ func (p *Provider) ScheduleRenewal(ctx context.Context) { func (p *Provider) initClient() E.NestedError { legoClient, err := E.Check(lego.NewClient(p.legoCfg)) if err.HasError() { - return E.Failure("create lego client").With(err) + return E.FailWith("create lego client", err) } legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options) if err.HasError() { - return E.Failure("create lego provider").With(err) + return E.FailWith("create lego provider", err) } err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) if err.HasError() { - return E.Failure("set challenge provider").With(err) + return E.FailWith("set challenge provider", err) } p.client = legoClient - return E.Nil() + return nil } func (p *Provider) registerACME() E.NestedError { if p.user.Registration != nil { - return E.Nil() + return nil } reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) if err.HasError() { - return E.Failure("register ACME").With(err) + return err } p.user.Registration = reg if err := p.saveRegistration(); err.HasError() { logger.Warn(err) } - return E.Nil() + return nil } func (p *Provider) loadRegistration() E.NestedError { if p.user.Registration != nil { - return E.Nil() + return nil } reg := ®istration.Resource{} err := U.LoadJson(RegistrationFile, reg) if err.HasError() { - return E.Failure("parse registration file").With(err) + return E.FailWith("parse registration file", err) } p.user.Registration = reg - return E.Nil() + return nil } func (p *Provider) saveRegistration() E.NestedError { @@ -205,13 +213,13 @@ func (p *Provider) saveRegistration() E.NestedError { func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- if err != nil { - return E.Failure("write key file").With(err) + return E.FailWith("write key file", err) } err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r-- if err != nil { - return E.Failure("write cert file").With(err) + return E.FailWith("write cert file", err) } - return E.Nil() + return nil } func (p *Provider) certState() CertState { @@ -245,13 +253,13 @@ func (p *Provider) renewIfNeeded() E.NestedError { case CertStateMismatch: logger.Info("cert domains mismatch with config, renewing") default: - return E.Nil() + return nil } if err := p.ObtainCert(); err.HasError() { - return E.Failure("renew certificate").With(err) + return E.FailWith("renew certificate", err) } - return E.Nil() + return nil } func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { @@ -259,7 +267,7 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { for _, cert := range cert.Certificate { x509Cert, err := E.Check(x509.ParseCertificate(cert)) if err.HasError() { - return nil, E.Failure("parse certificate").With(err) + return nil, E.FailWith("parse certificate", err) } if x509Cert.IsCA { continue @@ -269,7 +277,7 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { r[x509Cert.DNSNames[i]] = x509Cert.NotAfter } } - return r, E.Nil() + return r, nil } func providerGenerator[CT any, PT challenge.Provider]( @@ -286,6 +294,6 @@ func providerGenerator[CT any, PT challenge.Provider]( if err.HasError() { return nil, err } - return p, E.Nil() + return p, nil } } diff --git a/src/autocert/provider_test/ovh_test.go b/src/autocert/provider_test/ovh_test.go index e5e40add..032c0a42 100644 --- a/src/autocert/provider_test/ovh_test.go +++ b/src/autocert/provider_test/ovh_test.go @@ -4,7 +4,8 @@ import ( "testing" "github.com/go-acme/lego/v4/providers/dns/ovh" - . "github.com/yusing/go-proxy/utils" + U "github.com/yusing/go-proxy/utils" + . "github.com/yusing/go-proxy/utils/testing" "gopkg.in/yaml.v3" ) @@ -44,6 +45,6 @@ oauth2_config: testYaml = testYaml[1:] // remove first \n opt := make(map[string]any) ExpectNoError(t, yaml.Unmarshal([]byte(testYaml), opt)) - ExpectNoError(t, Deserialize(opt, cfg)) - ExpectEqual(t, cfg, cfgExpected) + ExpectTrue(t, U.Deserialize(opt, cfg).NoError()) + ExpectDeepEqual(t, cfg, cfgExpected) } diff --git a/src/common/args.go b/src/common/args.go index b91a857b..c01f336d 100644 --- a/src/common/args.go +++ b/src/common/args.go @@ -15,25 +15,32 @@ const ( CommandStart = "" CommandValidate = "validate" CommandListConfigs = "ls-config" + CommandListRoutes = "ls-routes" CommandReload = "reload" ) -var ValidCommands = []string{CommandStart, CommandValidate, CommandListConfigs, CommandReload} +var ValidCommands = []string{ + CommandStart, + CommandValidate, + CommandListConfigs, + CommandListRoutes, + CommandReload, +} func GetArgs() Args { var args Args flag.Parse() args.Command = flag.Arg(0) - if err := validateArgs(args.Command, ValidCommands); err.HasError() { + if err := validateArg(args.Command); err.HasError() { logrus.Fatal(err) } return args } -func validateArgs[T comparable](arg T, validArgs []T) E.NestedError { - for _, v := range validArgs { +func validateArg(arg string) E.NestedError { + for _, v := range ValidCommands { if arg == v { - return E.Nil() + return nil } } return E.Invalid("argument", arg) diff --git a/src/common/constants.go b/src/common/constants.go index 344502b7..08c1b3c1 100644 --- a/src/common/constants.go +++ b/src/common/constants.go @@ -41,7 +41,6 @@ const ( ProxyHTTPPort = ":80" ProxyHTTPSPort = ":443" APIHTTPPort = ":8888" - PanelHTTPPort = ":8080" ) var WellKnownHTTPPorts = map[uint16]bool{ @@ -53,7 +52,7 @@ var WellKnownHTTPPorts = map[uint16]bool{ } var ( - ImageNamePortMapTCP = map[string]int{ + ServiceNamePortMapTCP = map[string]int{ "postgres": 5432, "mysql": 3306, "mariadb": 3306, @@ -62,8 +61,7 @@ var ( "memcached": 11211, "rabbitmq": 5672, "mongo": 27017, - } - ExtraNamePortMapTCP = map[string]int{ + "dns": 53, "ssh": 22, "ftp": 21, @@ -71,20 +69,9 @@ var ( "pop3": 110, "imap": 143, } - NamePortMapTCP = func() map[string]int { - m := make(map[string]int) - for k, v := range ImageNamePortMapTCP { - m[k] = v - } - for k, v := range ExtraNamePortMapTCP { - m[k] = v - } - return m - }() ) -// docker library uses uint16, so followed here -var ImageNamePortMapHTTP = map[string]uint16{ +var ImageNamePortMapHTTP = map[string]int{ "nginx": 80, "httpd": 80, "adguardhome": 3000, @@ -101,3 +88,10 @@ var ImageNamePortMapHTTP = map[string]uint16{ "dockge": 5001, "nginx-proxy-manager": 81, } + +const ( + IdleTimeoutDefault = "0" + WakeTimeoutDefault = "10s" + StopTimeoutDefault = "10s" + StopMethodDefault = "stop" +) diff --git a/src/config/config.go b/src/config/config.go index 624f2778..17a53b50 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -2,6 +2,7 @@ package config import ( "context" + "os" "github.com/sirupsen/logrus" "github.com/yusing/go-proxy/autocert" @@ -17,32 +18,26 @@ import ( ) type Config struct { - value *M.Config - - l logrus.FieldLogger - reader U.Reader - proxyProviders *F.Map[string, *PR.Provider] + value *M.Config + proxyProviders F.Map[string, *PR.Provider] autocertProvider *autocert.Provider + l logrus.FieldLogger + watcher W.Watcher watcherCtx context.Context watcherCancel context.CancelFunc reloadReq chan struct{} } -func New() (*Config, E.NestedError) { +func Load() (*Config, E.NestedError) { cfg := &Config{ - l: logrus.WithField("module", "config"), - reader: U.NewFileReader(common.ConfigPath), - watcher: W.NewFileWatcher(common.ConfigFileName), - reloadReq: make(chan struct{}, 1), + proxyProviders: F.NewMapOf[string, *PR.Provider](), + l: logrus.WithField("module", "config"), + watcher: W.NewFileWatcher(common.ConfigFileName), + reloadReq: make(chan struct{}, 1), } - if err := cfg.load(); err.HasError() { - return nil, err - } - cfg.startProviders() - cfg.watchChanges() - return cfg, E.Nil() + return cfg, cfg.load() } func Validate(data []byte) E.NestedError { @@ -57,11 +52,17 @@ func (cfg *Config) GetAutoCertProvider() *autocert.Provider { return cfg.autocertProvider } +func (cfg *Config) StartProxyProviders() { + cfg.startProviders() + cfg.watchChanges() +} + func (cfg *Config) Dispose() { - cfg.watcherCancel() - cfg.l.Debug("stopped watcher") + if cfg.watcherCancel != nil { + cfg.watcherCancel() + cfg.l.Debug("stopped watcher") + } cfg.stopProviders() - cfg.l.Debug("stopped providers") } func (cfg *Config) Reload() E.NestedError { @@ -70,46 +71,31 @@ func (cfg *Config) Reload() E.NestedError { return err } cfg.startProviders() - return E.Nil() + return nil } func (cfg *Config) FindRoute(alias string) R.Route { - r := cfg.proxyProviders.Find( - func(p *PR.Provider) (any, bool) { - rs := p.GetCurrentRoutes() - if rs.Contains(alias) { - return rs.Get(alias), true + return F.MapFind(cfg.proxyProviders, + func(p *PR.Provider) (R.Route, bool) { + if route, ok := p.GetRoute(alias); ok { + return route, true } return nil, false }, ) - if r == nil { - return nil - } - return r.(R.Route) } func (cfg *Config) RoutesByAlias() map[string]U.SerializedObject { routes := make(map[string]U.SerializedObject) - cfg.proxyProviders.Each(func(p *PR.Provider) { - prName := p.GetName() - p.GetCurrentRoutes().EachKV(func(a string, r R.Route) { - obj, err := U.Serialize(r) - if err.HasError() { - cfg.l.Error(err) - return - } - obj["provider"] = prName - switch r.(type) { - case *R.StreamRoute: - obj["type"] = "stream" - case *R.HTTPRoute: - obj["type"] = "reverse_proxy" - default: - panic("bug: should not reach here") - } - routes[a] = obj - }) + cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + obj, err := U.Serialize(r) + if err.HasError() { + cfg.l.Error(err) + return + } + obj["provider"] = p.GetName() + obj["type"] = string(r.Type()) + routes[alias] = obj }) return routes } @@ -119,26 +105,23 @@ func (cfg *Config) Statistics() map[string]any { nTotalRPs := 0 providerStats := make(map[string]any) - cfg.proxyProviders.Each(func(p *PR.Provider) { - stats := make(map[string]any) - nStreams := 0 - nRPs := 0 - p.GetCurrentRoutes().EachKV(func(a string, r R.Route) { - switch r.(type) { - case *R.StreamRoute: - nStreams++ - nTotalStreams++ - case *R.HTTPRoute: - nRPs++ - nTotalRPs++ - default: - panic("bug: should not reach here") - } - }) - stats["type"] = p.GetType() - stats["num_streams"] = nStreams - stats["num_reverse_proxies"] = nRPs - providerStats[p.GetName()] = stats + cfg.forEachRoute(func(alias string, r R.Route, p *PR.Provider) { + s, ok := providerStats[p.GetName()] + if !ok { + s = make(map[string]int) + } + + stats := s.(map[string]int) + switch r.Type() { + case R.RouteTypeStream: + stats["num_streams"]++ + nTotalStreams++ + case R.RouteTypeReverseProxy: + stats["num_reverse_proxies"]++ + nTotalRPs++ + default: + panic("bug: should not reach here") + } }) return map[string]any{ @@ -148,6 +131,14 @@ func (cfg *Config) Statistics() map[string]any { } } +func (cfg *Config) forEachRoute(do func(alias string, r R.Route, p *PR.Provider)) { + cfg.proxyProviders.RangeAll(func(_ string, p *PR.Provider) { + p.RangeRoutes(func(a string, r R.Route) { + do(a, r, p) + }) + }) +} + func (cfg *Config) watchChanges() { cfg.watcherCtx, cfg.watcherCancel = context.WithCancel(context.Background()) go func() { @@ -182,64 +173,82 @@ func (cfg *Config) watchChanges() { }() } -func (cfg *Config) load() E.NestedError { +func (cfg *Config) load() (res E.NestedError) { + b := E.NewBuilder("errors loading config") + defer b.To(&res) + cfg.l.Debug("loading config") + defer cfg.l.Debug("loaded config") - data, err := cfg.reader.Read() + data, err := E.Check(os.ReadFile(common.ConfigPath)) if err.HasError() { - return E.Failure("read config").With(err) - } - - model := M.DefaultConfig() - if err := E.From(yaml.Unmarshal(data, model)); err.HasError() { - return E.Failure("parse config").With(err) + b.Add(E.FailWith("read config", err)) + return } if !common.NoSchemaValidation { if err = Validate(data); err.HasError() { - return err + b.Add(E.FailWith("schema validation", err)) + return } } - warnings := E.NewBuilder("errors loading config") + model := M.DefaultConfig() + if err := E.From(yaml.Unmarshal(data, model)); err.HasError() { + b.Add(E.FailWith("parse config", err)) + return + } - cfg.l.Debug("initializing autocert") - ap, err := autocert.NewConfig(&model.AutoCert).GetProvider() - if err.HasError() { - warnings.Add(E.Failure("autocert provider").With(err)) - } else { - cfg.l.Debug("initialized autocert") - } - cfg.autocertProvider = ap - - cfg.l.Debug("loading providers") - cfg.proxyProviders = F.NewMap[string, *PR.Provider]() - for _, filename := range model.Providers.Files { - p := PR.NewFileProvider(filename) - cfg.proxyProviders.Set(p.GetName(), p) - } - for name, dockerHost := range model.Providers.Docker { - p := PR.NewDockerProvider(name, dockerHost) - cfg.proxyProviders.Set(p.GetName(), p) - } - cfg.l.Debug("loaded providers") + // errors are non fatal below + b.WithSeverity(E.SeverityWarning) + b.Add(cfg.initAutoCert(&model.AutoCert)) + b.Add(cfg.loadProviders(&model.Providers)) cfg.value = model + return +} - if err := warnings.Build(); err.HasError() { - cfg.l.Warn(err) +func (cfg *Config) initAutoCert(autocertCfg *M.AutoCertConfig) (err E.NestedError) { + if cfg.autocertProvider != nil { + return } - cfg.l.Debug("loaded config") - return E.Nil() + cfg.l.Debug("initializing autocert") + defer cfg.l.Debug("initialized autocert") + + cfg.autocertProvider, err = autocert.NewConfig(autocertCfg).GetProvider() + if err.HasError() { + err = E.FailWith("autocert provider", err) + } + return +} + +func (cfg *Config) loadProviders(providers *M.ProxyProviders) (res E.NestedError) { + cfg.l.Debug("loading providers") + defer cfg.l.Debug("loaded providers") + + b := E.NewBuilder("errors loading providers") + defer b.To(&res) + + for _, filename := range providers.Files { + p := PR.NewFileProvider(filename) + cfg.proxyProviders.Store(p.GetName(), p) + b.Add(p.LoadRoutes()) + } + for name, dockerHost := range providers.Docker { + p := PR.NewDockerProvider(name, dockerHost) + cfg.proxyProviders.Store(p.GetName(), p) + b.Add(p.LoadRoutes()) + } + return } func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.NestedError) { errors := E.NewBuilder("cannot %s these providers", action) - cfg.proxyProviders.EachKVParallel(func(name string, p *PR.Provider) { + cfg.proxyProviders.RangeAll(func(name string, p *PR.Provider) { if err := do(p); err.HasError() { - errors.Add(E.From(err).Subject(p)) + errors.Add(err.Subject(p)) } }) diff --git a/src/docker/client.go b/src/docker/client.go index 70ff22d9..e6f654ad 100644 --- a/src/docker/client.go +++ b/src/docker/client.go @@ -3,6 +3,7 @@ package docker import ( "net/http" "sync" + "sync/atomic" "github.com/docker/cli/cli/connhelper" "github.com/docker/docker/client" @@ -11,14 +12,37 @@ import ( E "github.com/yusing/go-proxy/error" ) -type Client = *client.Client +type Client struct { + key string + refCount *atomic.Int32 + *client.Client +} + +func (c Client) DaemonHostname() string { + url, _ := client.ParseHostURL(c.DaemonHost()) + return url.Hostname() +} + +// if the client is still referenced, this is no-op +func (c Client) Close() error { + if c.refCount.Load() > 0 { + c.refCount.Add(-1) + return nil + } + + clientMapMu.Lock() + defer clientMapMu.Unlock() + delete(clientMap, c.key) + + return c.Client.Close() +} // ConnectClient creates a new Docker client connection to the specified host. // // Returns existing client if available. // // Parameters: -// - host: the host to connect to (either a URL or "FROM_ENV"). +// - host: the host to connect to (either a URL or common.DockerHostFromEnv). // // Returns: // - Client: the Docker client connection. @@ -29,7 +53,8 @@ func ConnectClient(host string) (Client, E.NestedError) { // check if client exists if client, ok := clientMap[host]; ok { - return client, E.Nil() + client.refCount.Add(1) + return client, nil } // create client @@ -41,7 +66,7 @@ func ConnectClient(host string) (Client, E.NestedError) { default: helper, err := E.Check(connhelper.GetConnectionHelper(host)) if err.HasError() { - logger.Fatalf("unexpected error: %s", err) + return Client{}, E.UnexpectedError(err.Error()) } if helper != nil { httpClient := &http.Client{ @@ -66,11 +91,16 @@ func ConnectClient(host string) (Client, E.NestedError) { client, err := E.Check(client.NewClientWithOpts(opt...)) if err.HasError() { - return nil, err + return Client{}, err } - clientMap[host] = client - return client, E.Nil() + clientMap[host] = Client{ + Client: client, + key: host, + refCount: &atomic.Int32{}, + } + clientMap[host].refCount.Add(1) + return clientMap[host], nil } func CloseAllClients() { @@ -83,12 +113,13 @@ func CloseAllClients() { logger.Debug("closed all clients") } -var clientMap map[string]Client = make(map[string]Client) -var clientMapMu sync.Mutex +var ( + clientMap map[string]Client = make(map[string]Client) + clientMapMu sync.Mutex + clientOptEnvHost = []client.Opt{ + client.WithHostFromEnv(), + client.WithAPIVersionNegotiation(), + } -var clientOptEnvHost = []client.Opt{ - client.WithHostFromEnv(), - client.WithAPIVersionNegotiation(), -} - -var logger = logrus.WithField("module", "docker") + logger = logrus.WithField("module", "docker") +) diff --git a/src/docker/client_info.go b/src/docker/client_info.go index 20253a0e..c09190a9 100644 --- a/src/docker/client_info.go +++ b/src/docker/client_info.go @@ -12,35 +12,41 @@ import ( ) type ClientInfo struct { - Host string + Client Client Containers []types.Container } -func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) { +var listOptions = container.ListOptions{ + // Filters: filters.NewArgs( + // filters.Arg("health", "healthy"), + // filters.Arg("health", "none"), + // filters.Arg("health", "starting"), + // ), + All: true, +} + +func GetClientInfo(clientHost string, getContainer bool) (*ClientInfo, E.NestedError) { dockerClient, err := ConnectClient(clientHost) if err.HasError() { - return nil, E.Failure("create docker client").With(err) + return nil, E.FailWith("connect to docker", err) } + defer dockerClient.Close() ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - containers, err := E.Check(dockerClient.ContainerList(ctx, container.ListOptions{})) - if err.HasError() { - return nil, E.Failure("list containers").With(err) + var containers []types.Container + if getContainer { + containers, err = E.Check(dockerClient.ContainerList(ctx, listOptions)) + if err.HasError() { + return nil, E.FailWith("list containers", err) + } } - // extract host from docker client url - // since the services being proxied to - // should have the same IP as the docker client - url, err := E.Check(client.ParseHostURL(dockerClient.DaemonHost())) - if err.HasError() { - return nil, E.Invalid("host url", dockerClient.DaemonHost()).With(err) - } - if url.Scheme == "unix" { - return &ClientInfo{Host: "localhost", Containers: containers}, E.Nil() - } - return &ClientInfo{Host: url.Hostname(), Containers: containers}, E.Nil() + return &ClientInfo{ + Client: dockerClient, + Containers: containers, + }, nil } func IsErrConnectionFailed(err error) bool { diff --git a/src/docker/container.go b/src/docker/container.go new file mode 100644 index 00000000..97e5e8a4 --- /dev/null +++ b/src/docker/container.go @@ -0,0 +1,109 @@ +package docker + +import ( + "fmt" + "strconv" + "strings" + + "github.com/docker/docker/api/types" + U "github.com/yusing/go-proxy/utils" +) + +type ProxyProperties struct { + DockerHost string `yaml:"docker_host" json:"docker_host"` + ContainerName string `yaml:"container_name" json:"container_name"` + ImageName string `yaml:"image_name" json:"image_name"` + Aliases []string `yaml:"aliases" json:"aliases"` + IsExcluded bool `yaml:"is_excluded" json:"is_excluded"` + FirstPort string `yaml:"first_port" json:"first_port"` + IdleTimeout string `yaml:"idle_timeout" json:"idle_timeout"` + WakeTimeout string `yaml:"wake_timeout" json:"wake_timeout"` + StopMethod string `yaml:"stop_method" json:"stop_method"` + StopTimeout string `yaml:"stop_timeout" json:"stop_timeout"` // stop_method = "stop" only + StopSignal string `yaml:"stop_signal" json:"stop_signal"` // stop_method = "stop" | "kill" only +} + +type Container struct { + *types.Container + *ProxyProperties +} + +func FromDocker(c *types.Container, dockerHost string) (res Container) { + res.Container = c + res.ProxyProperties = &ProxyProperties{ + DockerHost: dockerHost, + ContainerName: res.getName(), + ImageName: res.getImageName(), + Aliases: res.getAliases(), + IsExcluded: U.ParseBool(res.getDeleteLabel(LableExclude)), + FirstPort: res.firstPortOrEmpty(), + IdleTimeout: res.getDeleteLabel(LabelIdleTimeout), + WakeTimeout: res.getDeleteLabel(LabelWakeTimeout), + StopMethod: res.getDeleteLabel(LabelStopMethod), + StopTimeout: res.getDeleteLabel(LabelStopTimeout), + StopSignal: res.getDeleteLabel(LabelStopSignal), + } + return +} + +func FromJson(json types.ContainerJSON, dockerHost string) Container { + ports := make([]types.Port, 0) + for k, bindings := range json.NetworkSettings.Ports { + for _, v := range bindings { + pubPort, _ := strconv.Atoi(v.HostPort) + privPort, _ := strconv.Atoi(k.Port()) + ports = append(ports, types.Port{ + IP: v.HostIP, + PublicPort: uint16(pubPort), + PrivatePort: uint16(privPort), + }) + } + } + return FromDocker(&types.Container{ + ID: json.ID, + Names: []string{json.Name}, + Image: json.Image, + Ports: ports, + Labels: json.Config.Labels, + State: json.State.Status, + Status: json.State.Status, + }, dockerHost) +} + +func (c Container) getDeleteLabel(label string) string { + if l, ok := c.Labels[label]; ok { + delete(c.Labels, label) + return l + } + return "" +} + +func (c Container) getAliases() []string { + if l := c.getDeleteLabel(LableAliases); l != "" { + return U.CommaSeperatedList(l) + } else { + return []string{c.getName()} + } +} + +func (c Container) getName() string { + return strings.TrimPrefix(c.Names[0], "/") +} + +func (c Container) getImageName() string { + colonSep := strings.Split(c.Image, ":") + slashSep := strings.Split(colonSep[len(colonSep)-1], "/") + return slashSep[len(slashSep)-1] +} + +func (c Container) firstPortOrEmpty() string { + if len(c.Ports) == 0 { + return "" + } + for _, p := range c.Ports { + if p.PublicPort != 0 { + return fmt.Sprint(p.PublicPort) + } + } + return "" +} diff --git a/src/docker/idlewatcher/round_trip.go b/src/docker/idlewatcher/round_trip.go new file mode 100644 index 00000000..dc352d15 --- /dev/null +++ b/src/docker/idlewatcher/round_trip.go @@ -0,0 +1,14 @@ +package idlewatcher + +import "net/http" + +type ( + roundTripper struct { + patched roundTripFunc + } + roundTripFunc func(*http.Request) (*http.Response, error) +) + +func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return rt.patched(req) +} diff --git a/src/docker/idlewatcher/watcher.go b/src/docker/idlewatcher/watcher.go new file mode 100644 index 00000000..e02dd3c0 --- /dev/null +++ b/src/docker/idlewatcher/watcher.go @@ -0,0 +1,329 @@ +package idlewatcher + +import ( + "bytes" + "context" + "io" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/sirupsen/logrus" + D "github.com/yusing/go-proxy/docker" + E "github.com/yusing/go-proxy/error" + P "github.com/yusing/go-proxy/proxy" + PT "github.com/yusing/go-proxy/proxy/fields" +) + +type watcher struct { + *P.ReverseProxyEntry + client D.Client + + refCount atomic.Int32 + + stopByMethod StopCallback + wakeCh chan struct{} + wakeDone chan E.NestedError + + ctx context.Context + cancel context.CancelFunc + + l logrus.FieldLogger +} + +type ( + WakeDone <-chan error + WakeFunc func() WakeDone + StopCallback func() (bool, E.NestedError) +) + +func Register(entry *P.ReverseProxyEntry) (*watcher, E.NestedError) { + failure := E.Failure("idle_watcher register") + + if entry.IdleTimeout == 0 { + return nil, failure.With(E.Invalid("idle_timeout", 0)) + } + + watcherMapMu.Lock() + defer watcherMapMu.Unlock() + + if w, ok := watcherMap[entry.ContainerName]; ok { + w.refCount.Add(1) + return w, nil + } + + client, err := D.ConnectClient(entry.DockerHost) + if err.HasError() { + return nil, failure.With(err) + } + + w := &watcher{ + ReverseProxyEntry: entry, + client: client, + wakeCh: make(chan struct{}, 1), + wakeDone: make(chan E.NestedError, 1), + l: logger.WithField("container", entry.ContainerName), + } + w.refCount.Add(1) + + w.stopByMethod = w.getStopCallback() + watcherMap[w.ContainerName] = w + + go func() { + newWatcherCh <- w + }() + + return w, nil +} + +// If the container is not registered, this is no-op +func Unregister(containerName string) { + watcherMapMu.Lock() + defer watcherMapMu.Unlock() + + if w, ok := watcherMap[containerName]; ok { + if w.refCount.Load() == 0 { + w.cancel() + close(w.wakeCh) + delete(watcherMap, containerName) + } else { + w.refCount.Add(-1) + } + } +} + +func Start() { + logger.Debug("started") + defer logger.Debug("stopped") + + mainLoopCtx, mainLoopCancel = context.WithCancel(context.Background()) + + defer mainLoopWg.Wait() + + for { + select { + case <-mainLoopCtx.Done(): + return + case w := <-newWatcherCh: + w.l.Debug("registered") + mainLoopWg.Add(1) + go func() { + w.watch() + Unregister(w.ContainerName) + w.l.Debug("unregistered") + mainLoopWg.Done() + }() + } + } +} + +func Stop() { + mainLoopCancel() + mainLoopWg.Wait() +} + +func (w *watcher) PatchRoundTripper(rtp http.RoundTripper) roundTripper { + return roundTripper{patched: func(r *http.Request) (*http.Response, error) { + return w.roundTrip(rtp.RoundTrip, r) + }} +} + +func (w *watcher) roundTrip(origRoundTrip roundTripFunc, req *http.Request) (*http.Response, error) { + timeout := time.After(w.WakeTimeout) + w.wakeCh <- struct{}{} + for { + select { + case err := <-w.wakeDone: + if err != nil { + return nil, err.Error() + } + return origRoundTrip(req) + case <-timeout: + resp := loadingResponse + resp.TLS = req.TLS + return &resp, nil + } + } +} + +func (w *watcher) containerStop() error { + return w.client.ContainerStop(w.ctx, w.ContainerName, container.StopOptions{ + Signal: string(w.StopSignal), + Timeout: &w.StopTimeout}) +} + +func (w *watcher) containerPause() error { + return w.client.ContainerPause(w.ctx, w.ContainerName) +} + +func (w *watcher) containerKill() error { + return w.client.ContainerKill(w.ctx, w.ContainerName, string(w.StopSignal)) +} + +func (w *watcher) containerUnpause() error { + return w.client.ContainerUnpause(w.ctx, w.ContainerName) +} + +func (w *watcher) containerStart() error { + return w.client.ContainerStart(w.ctx, w.ContainerName, container.StartOptions{}) +} + +func (w *watcher) containerStatus() (string, E.NestedError) { + json, err := w.client.ContainerInspect(w.ctx, w.ContainerName) + if err != nil { + return "", E.FailWith("inspect container", err) + } + return json.State.Status, nil +} + +func (w *watcher) wakeIfStopped() (bool, E.NestedError) { + failure := E.Failure("wake") + status, err := w.containerStatus() + + if err.HasError() { + return false, failure.With(err) + } + // "created", "running", "paused", "restarting", "removing", "exited", or "dead" + switch status { + case "exited", "dead": + err = E.From(w.containerStart()) + case "paused": + err = E.From(w.containerUnpause()) + case "running": + return false, nil + default: + return false, failure.With(E.Unexpected("container state", status)) + } + + if err.HasError() { + return false, failure.With(err) + } + + status, err = w.containerStatus() + if err.HasError() { + return false, failure.With(err) + } else if status != "running" { + return false, failure.With(E.Unexpected("container state", status)) + } else { + return true, nil + } +} + +func (w *watcher) getStopCallback() StopCallback { + var cb func() error + switch w.StopMethod { + case PT.StopMethodPause: + cb = w.containerPause + case PT.StopMethodStop: + cb = w.containerStop + case PT.StopMethodKill: + cb = w.containerKill + default: + panic("should not reach here") + } + return func() (bool, E.NestedError) { + status, err := w.containerStatus() + if err.HasError() { + return false, E.FailWith("stop", err) + } + if status != "running" { + return false, nil + } + err = E.From(cb()) + if err.HasError() { + return false, E.FailWith("stop", err) + } + return true, nil + } +} + +func (w *watcher) watch() { + watcherCtx, watcherCancel := context.WithCancel(context.Background()) + w.ctx = watcherCtx + w.cancel = watcherCancel + + ticker := time.NewTicker(w.IdleTimeout) + defer ticker.Stop() + + for { + select { + case <-mainLoopCtx.Done(): + watcherCancel() + case <-watcherCtx.Done(): + w.l.Debug("stopped") + return + case <-ticker.C: + w.l.Debug("timeout") + stopped, err := w.stopByMethod() + if err.HasError() { + w.l.Error(err.Extraf("stop method: %s", w.StopMethod)) + } else if stopped { + w.l.Infof("%s: ok", w.StopMethod) + } else { + ticker.Stop() + } + case <-w.wakeCh: + w.l.Debug("wake received") + go func() { + started, err := w.wakeIfStopped() + if err != nil { + w.l.Error(err) + } else if started { + w.l.Infof("awaken") + ticker.Reset(w.IdleTimeout) + } + w.wakeDone <- err // this is passed to roundtrip + }() + } + } +} + +var ( + mainLoopCtx context.Context + mainLoopCancel context.CancelFunc + mainLoopWg sync.WaitGroup + + watcherMap = make(map[string]*watcher) + watcherMapMu sync.Mutex + + newWatcherCh = make(chan *watcher) + + logger = logrus.WithField("module", "idle_watcher") + + loadingResponse = http.Response{ + StatusCode: http.StatusAccepted, + Header: http.Header{ + "Content-Type": {"text/html"}, + "Cache-Control": { + "no-cache", + "no-store", + "must-revalidate", + }, + }, + Body: io.NopCloser(bytes.NewReader((loadingPage))), + ContentLength: int64(len(loadingPage)), + } + + loadingPage = []byte(` + + + + + + Loading... + + + +

Container is starting... Please wait

+ + +`[1:]) +) diff --git a/src/docker/inspect.go b/src/docker/inspect.go new file mode 100644 index 00000000..ae7f8f5c --- /dev/null +++ b/src/docker/inspect.go @@ -0,0 +1,19 @@ +package docker + +import ( + "context" + "time" + + E "github.com/yusing/go-proxy/error" +) + +func (c Client) Inspect(containerID string) (Container, E.NestedError) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + json, err := c.ContainerInspect(ctx, containerID) + if err != nil { + return Container{}, E.From(err) + } + return FromJson(json, c.key), nil +} diff --git a/src/docker/label.go b/src/docker/label.go index 1697f99c..b049509e 100644 --- a/src/docker/label.go +++ b/src/docker/label.go @@ -36,7 +36,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { return &Label{ Namespace: label, Value: value, - }, E.Nil() + }, nil } l := &Label{ @@ -54,12 +54,12 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { // find if namespace has value parser pm, ok := labelValueParserMap[l.Namespace] if !ok { - return l, E.Nil() + return l, nil } // find if attribute has value parser p, ok := pm[l.Attribute] if !ok { - return l, E.Nil() + return l, nil } // try to parse value v, err := p(value) @@ -67,7 +67,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { return nil, err } l.Value = v - return l, E.Nil() + return l, nil } func RegisterNamespace(namespace string, pm ValueParserMap) { diff --git a/src/docker/label_parser.go b/src/docker/label_parser.go index 1d34d05c..b89be9fd 100644 --- a/src/docker/label_parser.go +++ b/src/docker/label_parser.go @@ -10,7 +10,7 @@ import ( func yamlListParser(value string) (any, E.NestedError) { value = strings.TrimSpace(value) if value == "" { - return []string{}, E.Nil() + return []string{}, nil } var data []string err := E.From(yaml.Unmarshal([]byte(value), &data)) @@ -34,23 +34,15 @@ func yamlStringMappingParser(value string) (any, E.NestedError) { h[key] = val } } - return h, E.Nil() -} - -func commaSepParser(value string) (any, E.NestedError) { - v := strings.Split(value, ",") - for i := range v { - v[i] = strings.TrimSpace(v[i]) - } - return v, E.Nil() + return h, nil } func boolParser(value string) (any, E.NestedError) { switch strings.ToLower(value) { case "true", "yes", "1": - return true, E.Nil() + return true, nil case "false", "no", "0": - return false, E.Nil() + return false, nil default: return nil, E.Invalid("boolean value", value) } diff --git a/src/docker/label_parser_test.go b/src/docker/label_parser_test.go index 7aac189f..57faf939 100644 --- a/src/docker/label_parser_test.go +++ b/src/docker/label_parser_test.go @@ -7,7 +7,7 @@ import ( "testing" E "github.com/yusing/go-proxy/error" - . "github.com/yusing/go-proxy/utils" + . "github.com/yusing/go-proxy/utils/testing" ) func makeLabel(namespace string, alias string, field string) string { @@ -19,7 +19,7 @@ func TestHomePageLabel(t *testing.T) { field := "ip" v := "bar" pl, err := ParseLabel(makeLabel(NSHomePage, alias, field), v) - ExpectNoError(t, err) + ExpectNoError(t, err.Error()) if pl.Target != alias { t.Errorf("Expected alias=%s, got %s", alias, pl.Target) } @@ -34,8 +34,8 @@ func TestHomePageLabel(t *testing.T) { func TestStringProxyLabel(t *testing.T) { v := "bar" pl, err := ParseLabel(makeLabel(NSProxy, "foo", "ip"), v) - ExpectNoError(t, err) - ExpectEqual(t, pl.Value, v) + ExpectNoError(t, err.Error()) + ExpectEqual(t, pl.Value.(string), v) } func TestBoolProxyLabelValid(t *testing.T) { @@ -52,8 +52,8 @@ func TestBoolProxyLabelValid(t *testing.T) { for k, v := range tests { pl, err := ParseLabel(makeLabel(NSProxy, "foo", "no_tls_verify"), k) - ExpectNoError(t, err) - ExpectEqual(t, pl.Value, v) + ExpectNoError(t, err.Error()) + ExpectEqual(t, pl.Value.(bool), v) } } @@ -78,7 +78,7 @@ X-Custom-Header2: boo` } pl, err := ParseLabel(makeLabel(NSProxy, "foo", "set_headers"), v) - ExpectNoError(t, err) + ExpectNoError(t, err.Error()) hGot := ExpectType[map[string]string](t, pl.Value) if hGot != nil && !reflect.DeepEqual(h, hGot) { t.Errorf("Expected %v, got %v", h, hGot) @@ -109,33 +109,32 @@ func TestHideHeadersProxyLabel(t *testing.T) { ` v = strings.TrimPrefix(v, "\n") pl, err := ParseLabel(makeLabel(NSProxy, "foo", "hide_headers"), v) - ExpectNoError(t, err) + ExpectNoError(t, err.Error()) sGot := ExpectType[[]string](t, pl.Value) sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} if sGot != nil { - ExpectEqual(t, sGot, sWant) + ExpectDeepEqual(t, sGot, sWant) } } -func TestCommaSepProxyLabelSingle(t *testing.T) { - v := "a" - pl, err := ParseLabel("proxy.aliases", v) - ExpectNoError(t, err) - sGot := ExpectType[[]string](t, pl.Value) - sWant := []string{"a"} - if sGot != nil { - ExpectEqual(t, sGot, sWant) - } +// func TestCommaSepProxyLabelSingle(t *testing.T) { +// v := "a" +// pl, err := ParseLabel("proxy.aliases", v) +// ExpectNoError(t, err) +// sGot := ExpectType[[]string](t, pl.Value) +// sWant := []string{"a"} +// if sGot != nil { +// ExpectEqual(t, sGot, sWant) +// } +// } -} - -func TestCommaSepProxyLabelMulti(t *testing.T) { - v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3" - pl, err := ParseLabel("proxy.aliases", v) - ExpectNoError(t, err) - sGot := ExpectType[[]string](t, pl.Value) - sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} - if sGot != nil { - ExpectEqual(t, sGot, sWant) - } -} +// func TestCommaSepProxyLabelMulti(t *testing.T) { +// v := "X-Custom-Header1, X-Custom-Header2,X-Custom-Header3" +// pl, err := ParseLabel("proxy.aliases", v) +// ExpectNoError(t, err) +// sGot := ExpectType[[]string](t, pl.Value) +// sWant := []string{"X-Custom-Header1", "X-Custom-Header2", "X-Custom-Header3"} +// if sGot != nil { +// ExpectEqual(t, sGot, sWant) +// } +// } diff --git a/src/docker/labels.go b/src/docker/labels.go new file mode 100644 index 00000000..6b26197a --- /dev/null +++ b/src/docker/labels.go @@ -0,0 +1,13 @@ +package docker + +const ( + WildcardAlias = "*" + + LableAliases = NSProxy + ".aliases" + LableExclude = NSProxy + ".exclude" + LabelIdleTimeout = NSProxy + ".idle_timeout" + LabelWakeTimeout = NSProxy + ".wake_timeout" + LabelStopMethod = NSProxy + ".stop_method" + LabelStopTimeout = NSProxy + ".stop_timeout" + LabelStopSignal = NSProxy + ".stop_signal" +) diff --git a/src/error/builder.go b/src/error/builder.go index 80d79b62..cfd29be9 100644 --- a/src/error/builder.go +++ b/src/error/builder.go @@ -6,16 +6,23 @@ import ( ) type Builder struct { - message string - errors []error + *builder +} + +type builder struct { + message string + errors []NestedError + severity Severity sync.Mutex } -func NewBuilder(format string, args ...any) *Builder { - return &Builder{message: fmt.Sprintf(format, args...)} +func NewBuilder(format string, args ...any) Builder { + return Builder{&builder{message: fmt.Sprintf(format, args...)}} } -func (b *Builder) Add(err error) *Builder { +// adding nil / nil is no-op, +// you may safely pass expressions returning error to it +func (b Builder) Add(err NestedError) Builder { if err != nil { b.Lock() b.errors = append(b.errors, err) @@ -24,8 +31,17 @@ func (b *Builder) Add(err error) *Builder { return b } -func (b *Builder) Addf(format string, args ...any) *Builder { - return b.Add(fmt.Errorf(format, args...)) +func (b Builder) AddE(err error) Builder { + return b.Add(From(err)) +} + +func (b Builder) Addf(format string, args ...any) Builder { + return b.Add(errorf(format, args...)) +} + +func (b Builder) WithSeverity(s Severity) Builder { + b.severity = s + return b } // Build builds a NestedError based on the errors collected in the Builder. @@ -35,9 +51,21 @@ func (b *Builder) Addf(format string, args ...any) *Builder { // // Returns: // - NestedError: the built NestedError. -func (b *Builder) Build() NestedError { +func (b Builder) Build() NestedError { if len(b.errors) == 0 { - return Nil() + return nil } - return Join(b.message, b.errors...) + return Join(b.message, b.errors...).Severity(b.severity) +} + +func (b Builder) To(ptr *NestedError) { + if *ptr == nil { + *ptr = b.Build() + } else { + **ptr = *b.Build() + } +} + +func (b Builder) HasError() bool { + return len(b.errors) > 0 } diff --git a/src/error/builder_test.go b/src/error/builder_test.go index d3424bf2..122a123c 100644 --- a/src/error/builder_test.go +++ b/src/error/builder_test.go @@ -1,13 +1,38 @@ package error -import "testing" +import ( + "testing" -func TestBuilder(t *testing.T) { + . "github.com/yusing/go-proxy/utils/testing" +) + +func TestBuilderEmpty(t *testing.T) { + eb := NewBuilder("qwer") + ExpectTrue(t, eb.Build() == nil) + ExpectTrue(t, eb.Build().NoError()) + ExpectFalse(t, eb.HasError()) +} + +func TestBuilderAddNil(t *testing.T) { + eb := NewBuilder("asdf") + var err NestedError + for range 3 { + eb.Add(nil) + } + for range 3 { + eb.Add(err) + } + ExpectTrue(t, eb.Build() == nil) + ExpectTrue(t, eb.Build().NoError()) + ExpectFalse(t, eb.HasError()) +} + +func TestBuilderNested(t *testing.T) { eb := NewBuilder("error occurred") eb.Add(Failure("Action 1").With(Invalid("Inner", "1")).With(Invalid("Inner", "2"))) eb.Add(Failure("Action 2").With(Invalid("Inner", "3"))) - got := eb.Build().Error() + got := eb.Build().String() expected1 := (`error occurred: - Action 1 failed: diff --git a/src/error/error.go b/src/error/error.go index 6cb0fdf0..15587ab1 100644 --- a/src/error/error.go +++ b/src/error/error.go @@ -7,35 +7,37 @@ import ( ) type ( - // NestedError is an error with an inner error - // and a list of extra nested errors. - // - // It is designed to be non nil. - // - // You can use it to join multiple errors, - // or to set a inner reason for a nested error. - // - // When a method returns both valid values and errors, - // You should return (Slice/Map, NestedError). - // Caller then should handle the nested error, - // and continue with the valid values. - NestedError struct { - subject string - err error // can be nil - extras []NestedError + NestedError = *nestedError + nestedError struct { + subject string + err error // can be nil + extras []nestedError + severity Severity } + errorInterface struct { + *nestedError + } + Severity uint8 ) -func Nil() NestedError { return NestedError{} } +const ( + SeverityFatal Severity = iota + SeverityWarning +) + +func (e errorInterface) Error() string { + return e.String() +} func From(err error) NestedError { + if IsNil(err) { + return nil + } switch err := err.(type) { - case nil: - return Nil() - case NestedError: - return err + case errorInterface: + return err.nestedError default: - return NestedError{err: err} + return &nestedError{err: err} } } @@ -45,40 +47,84 @@ func Check[T any](obj T, err error) (T, NestedError) { return obj, From(err) } -func Join(message string, err ...error) NestedError { - extras := make([]NestedError, 0, len(err)) +func Join(message string, err ...NestedError) NestedError { + extras := make([]nestedError, len(err)) nErr := 0 - for _, e := range err { - if err == nil { + for i, e := range err { + if e == nil { continue } - extras = append(extras, From(e)) + extras[i] = *e nErr += 1 } if nErr == 0 { - return Nil() + return nil } - return NestedError{ + return &nestedError{ err: errors.New(message), extras: extras, } } -func (ne NestedError) Error() string { +func JoinE(message string, err ...error) NestedError { + b := NewBuilder(message) + for _, e := range err { + b.AddE(e) + } + return b.Build() +} + +func IsNil(err error) bool { + return err == nil +} + +func IsNotNil(err error) bool { + return err != nil +} + +func (ne NestedError) String() string { var buf strings.Builder ne.writeToSB(&buf, 0, "") return buf.String() } func (ne NestedError) Is(err error) bool { - return errors.Is(ne.err, err) + if ne == nil { + return err == nil + } + // return errors.Is(ne.err, err) + if errors.Is(ne.err, err) { + return true + } + for _, e := range ne.extras { + if e.Is(err) { + return true + } + } + return false +} + +func (ne NestedError) IsNot(err error) bool { + return !ne.Is(err) +} + +func (ne NestedError) Error() error { + if ne == nil { + return nil + } + return errorInterface{ne} } func (ne NestedError) With(s any) NestedError { + if ne == nil { + return ne + } var msg string switch ss := s.(type) { case nil: return ne + case *nestedError: + return ne.withError(ss.Error()) case error: return ne.withError(ss) case string: @@ -92,10 +138,13 @@ func (ne NestedError) With(s any) NestedError { } func (ne NestedError) Extraf(format string, args ...any) NestedError { - return ne.With(fmt.Errorf(format, args...)) + return ne.With(errorf(format, args...)) } func (ne NestedError) Subject(s any) NestedError { + if ne == nil { + return ne + } switch ss := s.(type) { case string: ne.subject = ss @@ -108,6 +157,9 @@ func (ne NestedError) Subject(s any) NestedError { } func (ne NestedError) Subjectf(format string, args ...any) NestedError { + if ne == nil { + return ne + } if strings.Contains(format, "%q") { panic("Subjectf format should not contain %q") } @@ -118,12 +170,36 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError { return ne } +func (ne NestedError) Severity(s Severity) NestedError { + if ne == nil { + return ne + } + ne.severity = s + return ne +} + +func (ne NestedError) Warn() NestedError { + if ne == nil { + return ne + } + ne.severity = SeverityWarning + return ne +} + func (ne NestedError) NoError() bool { - return ne.err == nil + return ne == nil } func (ne NestedError) HasError() bool { - return ne.err != nil + return ne != nil +} + +func (ne NestedError) IsFatal() bool { + return ne != nil && ne.severity == SeverityFatal +} + +func (ne NestedError) IsWarning() bool { + return ne != nil && ne.severity == SeverityWarning } func errorf(format string, args ...any) NestedError { @@ -131,11 +207,13 @@ func errorf(format string, args ...any) NestedError { } func (ne NestedError) withError(err error) NestedError { - ne.extras = append(ne.extras, From(err)) + if ne != nil && IsNotNil(err) { + ne.extras = append(ne.extras, *From(err)) + } return ne } -func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { +func (ne NestedError) writeToSB(sb *strings.Builder, level int, prefix string) { ne.writeIndents(sb, level) sb.WriteString(prefix) @@ -146,7 +224,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) sb.WriteString(ne.err.Error()) if ne.subject != "" { - if ne.err != nil { + if IsNotNil(ne.err) { sb.WriteString(fmt.Sprintf(" for %q", ne.subject)) } else { sb.WriteString(fmt.Sprint(ne.subject)) @@ -161,7 +239,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) } } -func (ne *NestedError) writeIndents(sb *strings.Builder, level int) { +func (ne NestedError) writeIndents(sb *strings.Builder, level int) { for i := 0; i < level; i++ { sb.WriteString(" ") } diff --git a/src/error/error_test.go b/src/error/error_test.go index e726abee..9fa2f3ec 100644 --- a/src/error/error_test.go +++ b/src/error/error_test.go @@ -4,7 +4,7 @@ import ( "testing" . "github.com/yusing/go-proxy/error" - . "github.com/yusing/go-proxy/utils" + . "github.com/yusing/go-proxy/utils/testing" ) func TestErrorIs(t *testing.T) { @@ -16,27 +16,53 @@ func TestErrorIs(t *testing.T) { ExpectTrue(t, Invalid("foo", "bar").Is(ErrInvalid)) ExpectFalse(t, Invalid("foo", "bar").Is(ErrFailure)) - ExpectTrue(t, Nil().Is(nil)) - ExpectFalse(t, Nil().Is(ErrInvalid)) ExpectFalse(t, Invalid("foo", "bar").Is(nil)) } -func TestNil(t *testing.T) { - ExpectTrue(t, Nil().NoError()) - ExpectFalse(t, Nil().HasError()) - ExpectEqual(t, Nil().Error(), "nil") +func TestErrorNestedIs(t *testing.T) { + var err NestedError + ExpectTrue(t, err.Is(nil)) + + err = Failure("some reason") + ExpectTrue(t, err.Is(ErrFailure)) + ExpectFalse(t, err.Is(ErrAlreadyExist)) + + err.With(AlreadyExist("something", "")) + ExpectTrue(t, err.Is(ErrFailure)) + ExpectTrue(t, err.Is(ErrAlreadyExist)) + ExpectFalse(t, err.Is(ErrInvalid)) +} + +func TestIsNil(t *testing.T) { + var err NestedError + ExpectTrue(t, err.Is(nil)) + ExpectFalse(t, err.HasError()) + ExpectTrue(t, err == nil) + ExpectTrue(t, err.NoError()) + + eb := NewBuilder("") + returnNil := func() error { + return eb.Build().Error() + } + ExpectTrue(t, IsNil(returnNil())) + ExpectTrue(t, returnNil() == nil) + + ExpectTrue(t, (err. + Subject("any"). + With("something"). + Extraf("foo %s", "bar")) == nil) } func TestErrorSimple(t *testing.T) { ne := Failure("foo bar") - ExpectEqual(t, ne.Error(), "foo bar failed") + ExpectEqual(t, ne.String(), "foo bar failed") ne = ne.Subject("baz") - ExpectEqual(t, ne.Error(), "foo bar failed for \"baz\"") + ExpectEqual(t, ne.String(), "foo bar failed for \"baz\"") } func TestErrorWith(t *testing.T) { ne := Failure("foo").With("bar").With("baz") - ExpectEqual(t, ne.Error(), "foo failed:\n - bar\n - baz") + ExpectEqual(t, ne.String(), "foo failed:\n - bar\n - baz") } func TestErrorNested(t *testing.T) { @@ -72,5 +98,5 @@ func TestErrorNested(t *testing.T) { - inner3 failed for "action 3": - 3 - 3` - ExpectEqual(t, ne.Error(), want) + ExpectEqual(t, ne.String(), want) } diff --git a/src/error/errors.go b/src/error/errors.go index 98213b98..14b34131 100644 --- a/src/error/errors.go +++ b/src/error/errors.go @@ -5,33 +5,48 @@ import ( ) var ( - ErrFailure = stderrors.New("failed") - ErrInvalid = stderrors.New("invalid") - ErrUnsupported = stderrors.New("unsupported") - ErrNotExists = stderrors.New("does not exist") - ErrDuplicated = stderrors.New("duplicated") + ErrFailure = stderrors.New("failed") + ErrInvalid = stderrors.New("invalid") + ErrUnsupported = stderrors.New("unsupported") + ErrUnexpected = stderrors.New("unexpected") + ErrNotExists = stderrors.New("does not exist") + ErrAlreadyExist = stderrors.New("already exist") ) +const fmtSubjectWhat = "%w %v: %v" + func Failure(what string) NestedError { return errorf("%s %w", what, ErrFailure) } -func FailureWhy(what string, why string) NestedError { +func FailedWhy(what string, why string) NestedError { return errorf("%s %w because %s", what, ErrFailure, why) } +func FailWith(what string, err any) NestedError { + return Failure(what).With(err) +} + func Invalid(subject, what any) NestedError { - return errorf("%w %v - %v", ErrInvalid, subject, what) + return errorf(fmtSubjectWhat, ErrInvalid, subject, what) } func Unsupported(subject, what any) NestedError { - return errorf("%w %v - %v", ErrUnsupported, subject, what) + return errorf(fmtSubjectWhat, ErrUnsupported, subject, what) } -func NotExists(subject, what any) NestedError { - return errorf("%s %v - %v", subject, ErrNotExists, what) +func Unexpected(subject, what any) NestedError { + return errorf(fmtSubjectWhat, ErrUnexpected, subject, what) } -func Duplicated(subject, what any) NestedError { - return errorf("%w %v: %v", ErrDuplicated, subject, what) +func UnexpectedError(err error) NestedError { + return errorf("%w error: %w", ErrUnexpected, err) +} + +func NotExist(subject, what any) NestedError { + return errorf("%v %w: %v", subject, ErrNotExists, what) +} + +func AlreadyExist(subject, what any) NestedError { + return errorf("%v %w: %v", subject, ErrAlreadyExist, what) } diff --git a/src/go.mod b/src/go.mod index 25ea342c..2d471d22 100644 --- a/src/go.mod +++ b/src/go.mod @@ -7,6 +7,7 @@ require ( github.com/docker/docker v27.2.1+incompatible github.com/fsnotify/fsnotify v1.7.0 github.com/go-acme/lego/v4 v4.18.0 + github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/santhosh-tekuri/jsonschema v1.2.4 github.com/sirupsen/logrus v1.9.3 golang.org/x/net v0.29.0 diff --git a/src/go.sum b/src/go.sum index 4fbab959..8a92f6d4 100644 --- a/src/go.sum +++ b/src/go.sum @@ -73,6 +73,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4= +github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis= diff --git a/src/main.go b/src/main.go index ecb1c0cf..2a6a80c2 100755 --- a/src/main.go +++ b/src/main.go @@ -18,6 +18,7 @@ import ( "github.com/yusing/go-proxy/common" "github.com/yusing/go-proxy/config" "github.com/yusing/go-proxy/docker" + "github.com/yusing/go-proxy/docker/idlewatcher" E "github.com/yusing/go-proxy/error" R "github.com/yusing/go-proxy/route" "github.com/yusing/go-proxy/server" @@ -53,37 +54,40 @@ func main() { // exit if only validate config if args.Command == common.CommandValidate { - var err E.NestedError - data, err := E.Check(os.ReadFile(common.ConfigPath)) - if err.HasError() { - l.WithError(err).Fatalf("config error") + data, err := os.ReadFile(common.ConfigPath) + if err == nil { + err = config.Validate(data).Error() } - if err = config.Validate(data); err.HasError() { - l.WithError(err).Fatalf("config error") + if err != nil { + l.Fatal("config error: ", err) } l.Printf("config OK") return } - cfg, err := config.New() - if err.HasError() { - l.Fatalf("config error: %s", err) + cfg, err := config.Load() + if err.IsFatal() { + l.Fatal(err) } if args.Command == common.CommandListConfigs { - yml, err := E.Check(json.Marshal(cfg.Value())) - if err.HasError() { - panic(err) - } - rawLogger := log.New(os.Stdout, "", 0) - rawLogger.Printf("%s", yml) // raw output for convenience using "jq" + printJSON(cfg.Value()) return } - onShutdown.Add(func() { - docker.CloseAllClients() - cfg.Dispose() - }) + cfg.StartProxyProviders() + + if args.Command == common.CommandListRoutes { + printJSON(cfg.RoutesByAlias()) + return + } + + if err.HasError() { + l.Warn(err) + } + + onShutdown.Add(docker.CloseAllClients) + onShutdown.Add(cfg.Dispose) sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) @@ -109,8 +113,9 @@ func main() { onShutdown.Add(certRenewalCancel) } - for name, expiry := range autocert.GetExpiries() { - l.Infof("certificate %q: expire on %s", name, expiry) + for _, expiry := range autocert.GetExpiries() { + l.Infof("certificate expire on %s", expiry) + break } } else { l.Info("autocert not configured") @@ -137,6 +142,9 @@ func main() { onShutdown.Add(proxyServer.Stop) onShutdown.Add(apiServer.Stop) + go idlewatcher.Start() + onShutdown.Add(idlewatcher.Stop) + // wait for signal <-sig @@ -164,3 +172,12 @@ func main() { logrus.Info("timeout waiting for shutdown") } } + +func printJSON(obj any) { + j, err := E.Check(json.Marshal(obj)) + if err.HasError() { + logrus.Fatal(err) + } + rawLogger := log.New(os.Stdout, "", 0) + rawLogger.Printf("%s", j) // raw output for convenience using "jq" +} diff --git a/src/models/proxy_entry.go b/src/models/proxy_entry.go index e8442173..4f41dcd1 100644 --- a/src/models/proxy_entry.go +++ b/src/models/proxy_entry.go @@ -1,13 +1,16 @@ package model import ( + "strconv" "strings" + . "github.com/yusing/go-proxy/common" + D "github.com/yusing/go-proxy/docker" F "github.com/yusing/go-proxy/utils/functional" ) type ( - ProxyEntry struct { + ProxyEntry struct { // raw entry object before validation Alias string `yaml:"-" json:"-"` Scheme string `yaml:"scheme" json:"scheme"` Host string `yaml:"host" json:"host"` @@ -16,35 +19,66 @@ type ( PathPatterns []string `yaml:"path_patterns" json:"path_patterns"` // http(s) proxy only SetHeaders map[string]string `yaml:"set_headers" json:"set_headers"` // http(s) proxy only HideHeaders []string `yaml:"hide_headers" json:"hide_headers"` // http(s) proxy only + + /* Docker only */ + *D.ProxyProperties `yaml:"-" json:"-"` } - ProxyEntries = *F.Map[string, *ProxyEntry] + ProxyEntries = F.Map[string, *ProxyEntry] ) -var NewProxyEntries = F.NewMap[string, *ProxyEntry] +var NewProxyEntries = F.NewMapOf[string, *ProxyEntry] func (e *ProxyEntry) SetDefaults() { if e.Scheme == "" { - if strings.ContainsRune(e.Port, ':') { + switch { + case strings.ContainsRune(e.Port, ':'): e.Scheme = "tcp" - } else { - switch e.Port { - case "443", "8443": - e.Scheme = "https" - default: - e.Scheme = "http" + case e.ProxyProperties != nil: + if _, ok := ServiceNamePortMapTCP[e.ImageName]; ok { + e.Scheme = "tcp" } } } + + if e.Scheme == "" { + switch e.Port { + case "443", "8443": + e.Scheme = "https" + default: + e.Scheme = "http" + } + } if e.Host == "" { e.Host = "localhost" } if e.Port == "" { - switch e.Scheme { - case "http": - e.Port = "80" - case "https": - e.Port = "443" + e.Port = e.FirstPort + } + if e.Port == "" { + if port, ok := ServiceNamePortMapTCP[e.Port]; ok { + e.Port = strconv.Itoa(port) + } else if port, ok := ImageNamePortMapHTTP[e.Port]; ok { + e.Port = strconv.Itoa(port) + } else { + switch e.Scheme { + case "http": + e.Port = "80" + case "https": + e.Port = "443" + } } } + if e.IdleTimeout == "" { + e.IdleTimeout = IdleTimeoutDefault + } + if e.WakeTimeout == "" { + e.WakeTimeout = WakeTimeoutDefault + } + if e.StopTimeout == "" { + e.StopTimeout = StopTimeoutDefault + } + if e.StopMethod == "" { + e.StopMethod = StopMethodDefault + } } diff --git a/src/proxy/entry.go b/src/proxy/entry.go index dc625cc0..896367b6 100644 --- a/src/proxy/entry.go +++ b/src/proxy/entry.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "net/url" + "time" E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" @@ -11,16 +12,23 @@ import ( ) type ( - Entry struct { // real model after validation + ReverseProxyEntry struct { // real model after validation Alias T.Alias Scheme T.Scheme - Host T.Host - Port T.Port URL *url.URL NoTLSVerify bool PathPatterns T.PathPatterns SetHeaders http.Header HideHeaders []string + + /* Docker only */ + IdleTimeout time.Duration + WakeTimeout time.Duration + StopMethod T.StopMethod + StopTimeout int + StopSignal T.Signal + DockerHost string + ContainerName string } StreamEntry struct { Alias T.Alias `json:"alias"` @@ -30,69 +38,105 @@ type ( } ) -func NewEntry(m *M.ProxyEntry) (any, E.NestedError) { +func (rp *ReverseProxyEntry) UseIdleWatcher() bool { + return rp.IdleTimeout > 0 && rp.DockerHost != "" +} + +func ValidateEntry(m *M.ProxyEntry) (any, E.NestedError) { m.SetDefaults() scheme, err := T.NewScheme(m.Scheme) if err.HasError() { return nil, err } + + var entry any + e := E.NewBuilder("error validating proxy entry") if scheme.IsStream() { - return validateStreamEntry(m) + entry = validateStreamEntry(m, e) + } else { + entry = validateRPEntry(m, scheme, e) } - return validateEntry(m, scheme) + if err := e.Build(); err.HasError() { + return nil, err + } + return entry, nil } -func validateEntry(m *M.ProxyEntry, s T.Scheme) (*Entry, E.NestedError) { - host, err := T.NewHost(m.Host) - if err.HasError() { - return nil, err - } - port, err := T.NewPort(m.Port) - if err.HasError() { - return nil, err - } - pathPatterns, err := T.NewPathPatterns(m.PathPatterns) - if err.HasError() { - return nil, err - } - setHeaders, err := T.NewHTTPHeaders(m.SetHeaders) - if err.HasError() { - return nil, err - } +func validateRPEntry(m *M.ProxyEntry, s T.Scheme, b E.Builder) *ReverseProxyEntry { + var stopTimeOut time.Duration + + host, err := T.ValidateHost(m.Host) + b.Add(err) + + port, err := T.ValidatePort(m.Port) + b.Add(err) + + pathPatterns, err := T.ValidatePathPatterns(m.PathPatterns) + b.Add(err) + + setHeaders, err := T.ValidateHTTPHeaders(m.SetHeaders) + b.Add(err) + url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) - if err.HasError() { - return nil, err + b.Add(err) + + idleTimeout, err := T.ValidateDurationPostitive(m.IdleTimeout) + b.Add(err) + + wakeTimeout, err := T.ValidateDurationPostitive(m.WakeTimeout) + b.Add(err) + + stopMethod, err := T.ValidateStopMethod(m.StopMethod) + b.Add(err) + + if stopMethod == T.StopMethodStop { + stopTimeOut, err = T.ValidateDurationPostitive(m.StopTimeout) + b.Add(err) + } + + stopSignal, err := T.ValidateSignal(m.StopSignal) + b.Add(err) + + if err.HasError() { + return nil + } + + return &ReverseProxyEntry{ + Alias: T.NewAlias(m.Alias), + Scheme: s, + URL: url, + NoTLSVerify: m.NoTLSVerify, + PathPatterns: pathPatterns, + SetHeaders: setHeaders, + HideHeaders: m.HideHeaders, + IdleTimeout: idleTimeout, + WakeTimeout: wakeTimeout, + StopMethod: stopMethod, + StopTimeout: int(stopTimeOut.Seconds()), // docker api takes integer seconds for timeout argument + StopSignal: stopSignal, + DockerHost: m.DockerHost, + ContainerName: m.ContainerName, } - return &Entry{ - Alias: T.NewAlias(m.Alias), - Scheme: s, - Host: host, - Port: port, - URL: url, - NoTLSVerify: m.NoTLSVerify, - PathPatterns: pathPatterns, - SetHeaders: setHeaders, - HideHeaders: m.HideHeaders, - }, E.Nil() } -func validateStreamEntry(m *M.ProxyEntry) (*StreamEntry, E.NestedError) { - host, err := T.NewHost(m.Host) - if err.HasError() { - return nil, err - } - port, err := T.NewStreamPort(m.Port) - if err.HasError() { - return nil, err - } - scheme, err := T.NewStreamScheme(m.Scheme) - if err.HasError() { - return nil, err +func validateStreamEntry(m *M.ProxyEntry, b E.Builder) *StreamEntry { + host, err := T.ValidateHost(m.Host) + b.Add(err) + + port, err := T.ValidateStreamPort(m.Port) + b.Add(err) + + scheme, err := T.ValidateStreamScheme(m.Scheme) + b.Add(err) + + if b.HasError() { + return nil } + return &StreamEntry{ Alias: T.NewAlias(m.Alias), Scheme: *scheme, Host: host, Port: port, - }, E.Nil() + } } diff --git a/src/proxy/fields/alias.go b/src/proxy/fields/alias.go index 3b56810f..289f9648 100644 --- a/src/proxy/fields/alias.go +++ b/src/proxy/fields/alias.go @@ -1,23 +1,6 @@ package fields -import ( - "strings" - - F "github.com/yusing/go-proxy/utils/functional" +type ( + Alias string + NewAlias = Alias ) - -type Alias string -type Aliases struct{ *F.Slice[Alias] } - -func NewAlias(s string) Alias { - return Alias(s) -} - -func NewAliases(s string) Aliases { - split := strings.Split(s, ",") - a := Aliases{F.NewSliceN[Alias](len(split))} - for i, v := range split { - a.Set(i, NewAlias(v)) - } - return a -} diff --git a/src/proxy/fields/headers.go b/src/proxy/fields/headers.go index 173767f8..fd1483ed 100644 --- a/src/proxy/fields/headers.go +++ b/src/proxy/fields/headers.go @@ -7,7 +7,7 @@ import ( E "github.com/yusing/go-proxy/error" ) -func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) { +func ValidateHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) { h := make(http.Header) for k, v := range headers { vSplit := strings.Split(v, ",") @@ -15,5 +15,5 @@ func NewHTTPHeaders(headers map[string]string) (http.Header, E.NestedError) { h.Add(k, strings.TrimSpace(header)) } } - return h, E.Nil() + return h, nil } diff --git a/src/proxy/fields/host.go b/src/proxy/fields/host.go index dda158c3..39b7b7b0 100644 --- a/src/proxy/fields/host.go +++ b/src/proxy/fields/host.go @@ -7,6 +7,6 @@ import ( type Host string type Subdomain = Alias -func NewHost(s string) (Host, E.NestedError) { - return Host(s), E.Nil() +func ValidateHost(s string) (Host, E.NestedError) { + return Host(s), nil } diff --git a/src/proxy/fields/path_mode.go b/src/proxy/fields/path_mode.go index f632cf9b..f4d4889d 100644 --- a/src/proxy/fields/path_mode.go +++ b/src/proxy/fields/path_mode.go @@ -9,7 +9,7 @@ type PathMode string func NewPathMode(pm string) (PathMode, E.NestedError) { switch pm { case "", "forward": - return PathMode(pm), E.Nil() + return PathMode(pm), nil default: return "", E.Invalid("path mode", pm) } diff --git a/src/proxy/fields/path_pattern.go b/src/proxy/fields/path_pattern.go index 4d68ec81..eec599f5 100644 --- a/src/proxy/fields/path_pattern.go +++ b/src/proxy/fields/path_pattern.go @@ -16,12 +16,12 @@ func NewPathPattern(s string) (PathPattern, E.NestedError) { if !pathPattern.MatchString(string(s)) { return "", E.Invalid("path pattern", s) } - return PathPattern(s), E.Nil() + return PathPattern(s), nil } -func NewPathPatterns(s []string) (PathPatterns, E.NestedError) { +func ValidatePathPatterns(s []string) (PathPatterns, E.NestedError) { if len(s) == 0 { - return []PathPattern{"/"}, E.Nil() + return []PathPattern{"/"}, nil } pp := make(PathPatterns, len(s)) for i, v := range s { @@ -31,7 +31,7 @@ func NewPathPatterns(s []string) (PathPatterns, E.NestedError) { pp[i] = pattern } } - return pp, E.Nil() + return pp, nil } var pathPattern = regexp.MustCompile("^((GET|POST|DELETE|PUT|PATCH|HEAD|OPTIONS|CONNECT)\\s)?(/\\w*)+/?$") diff --git a/src/proxy/fields/port.go b/src/proxy/fields/port.go index 783017e0..ce7582d5 100644 --- a/src/proxy/fields/port.go +++ b/src/proxy/fields/port.go @@ -8,7 +8,7 @@ import ( type Port int -func NewPort(v string) (Port, E.NestedError) { +func ValidatePort(v string) (Port, E.NestedError) { p, err := strconv.Atoi(v) if err != nil { return ErrPort, E.Invalid("port number", v).With(err) @@ -21,14 +21,14 @@ func NewPortInt[Int int | uint16](v Int) (Port, E.NestedError) { if err := pp.boundCheck(); err.HasError() { return ErrPort, err } - return pp, E.Nil() + return pp, nil } func (p Port) boundCheck() E.NestedError { if p < MinPort || p > MaxPort { return E.Invalid("port", p) } - return E.Nil() + return nil } const ( diff --git a/src/proxy/fields/scheme.go b/src/proxy/fields/scheme.go index f0dc510d..2c60178b 100644 --- a/src/proxy/fields/scheme.go +++ b/src/proxy/fields/scheme.go @@ -1,8 +1,6 @@ package fields import ( - "strings" - E "github.com/yusing/go-proxy/error" ) @@ -11,24 +9,11 @@ type Scheme string func NewScheme(s string) (Scheme, E.NestedError) { switch s { case "http", "https", "tcp", "udp": - return Scheme(s), E.Nil() + return Scheme(s), nil } return "", E.Invalid("scheme", s) } -func NewSchemeFromPort(p string) (Scheme, E.NestedError) { - var s string - switch { - case strings.ContainsRune(p, ':'): - s = "tcp" - case strings.HasSuffix(p, "443"): - s = "https" - default: - s = "http" - } - return Scheme(s), E.Nil() -} - func (s Scheme) IsHTTP() bool { return s == "http" } func (s Scheme) IsHTTPS() bool { return s == "https" } func (s Scheme) IsTCP() bool { return s == "tcp" } diff --git a/src/proxy/fields/signal.go b/src/proxy/fields/signal.go new file mode 100644 index 00000000..c2bae6c4 --- /dev/null +++ b/src/proxy/fields/signal.go @@ -0,0 +1,17 @@ +package fields + +import ( + E "github.com/yusing/go-proxy/error" +) + +type Signal string + +func ValidateSignal(s string) (Signal, E.NestedError) { + switch s { + case "", "SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT", + "INT", "TERM", "HUP", "QUIT": + return Signal(s), nil + } + + return "", E.Invalid("signal", s) +} diff --git a/src/proxy/fields/stop_method.go b/src/proxy/fields/stop_method.go new file mode 100644 index 00000000..6d4e7471 --- /dev/null +++ b/src/proxy/fields/stop_method.go @@ -0,0 +1,23 @@ +package fields + +import ( + E "github.com/yusing/go-proxy/error" +) + +type StopMethod string + +const ( + StopMethodPause StopMethod = "pause" + StopMethodStop StopMethod = "stop" + StopMethodKill StopMethod = "kill" +) + +func ValidateStopMethod(s string) (StopMethod, E.NestedError) { + sm := StopMethod(s) + switch sm { + case StopMethodPause, StopMethodStop, StopMethodKill: + return sm, nil + default: + return "", E.Invalid("stop_method", sm) + } +} diff --git a/src/proxy/fields/stream_port.go b/src/proxy/fields/stream_port.go index fdec1f04..f853f2cf 100644 --- a/src/proxy/fields/stream_port.go +++ b/src/proxy/fields/stream_port.go @@ -12,13 +12,13 @@ type StreamPort struct { ProxyPort Port `json:"proxy"` } -func NewStreamPort(p string) (StreamPort, E.NestedError) { +func ValidateStreamPort(p string) (StreamPort, E.NestedError) { split := strings.Split(p, ":") if len(split) != 2 { return StreamPort{}, E.Invalid("stream port", p).With("should be in 'x:y' format") } - listeningPort, err := NewPort(split[0]) + listeningPort, err := ValidatePort(split[0]) if err.HasError() { return StreamPort{}, err } @@ -26,7 +26,7 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) { return StreamPort{}, err } - proxyPort, err := NewPort(split[1]) + proxyPort, err := ValidatePort(split[1]) if err.HasError() { proxyPort, err = parseNameToPort(split[1]) if err.HasError() { @@ -37,13 +37,13 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) { return StreamPort{}, err } - return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, E.Nil() + return StreamPort{ListeningPort: listeningPort, ProxyPort: proxyPort}, nil } func parseNameToPort(name string) (Port, E.NestedError) { - port, ok := common.NamePortMapTCP[name] + port, ok := common.ServiceNamePortMapTCP[name] if !ok { return -1, E.Unsupported("service", name) } - return Port(port), E.Nil() + return Port(port), nil } diff --git a/src/proxy/fields/stream_scheme.go b/src/proxy/fields/stream_scheme.go index 6b88a044..3287ab73 100644 --- a/src/proxy/fields/stream_scheme.go +++ b/src/proxy/fields/stream_scheme.go @@ -12,7 +12,7 @@ type StreamScheme struct { ProxyScheme Scheme `json:"proxy"` } -func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { +func ValidateStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { ss = &StreamScheme{} parts := strings.Split(s, ":") if len(parts) == 1 { @@ -28,7 +28,7 @@ func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { if err.HasError() { return nil, err } - return ss, E.Nil() + return ss, nil } func (s StreamScheme) String() string { diff --git a/src/proxy/fields/timeout.go b/src/proxy/fields/timeout.go new file mode 100644 index 00000000..4f70f4bb --- /dev/null +++ b/src/proxy/fields/timeout.go @@ -0,0 +1,18 @@ +package fields + +import ( + "time" + + E "github.com/yusing/go-proxy/error" +) + +func ValidateDurationPostitive(value string) (time.Duration, E.NestedError) { + d, err := time.ParseDuration(value) + if err != nil { + return 0, E.Invalid("duration", value) + } + if d < 0 { + return 0, E.Invalid("duration", "negative value") + } + return d, nil +} diff --git a/src/proxy/provider/docker_provider.go b/src/proxy/provider/docker_provider.go index 2fada9bd..25986524 100755 --- a/src/proxy/provider/docker_provider.go +++ b/src/proxy/provider/docker_provider.go @@ -1,168 +1,160 @@ package provider import ( - "fmt" - "strings" - - "github.com/docker/docker/api/types" - "github.com/sirupsen/logrus" D "github.com/yusing/go-proxy/docker" E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" - PT "github.com/yusing/go-proxy/proxy/fields" - U "github.com/yusing/go-proxy/utils" + R "github.com/yusing/go-proxy/route" W "github.com/yusing/go-proxy/watcher" + . "github.com/yusing/go-proxy/watcher/event" ) type DockerProvider struct { - dockerHost string + dockerHost, hostname string } func DockerProviderImpl(dockerHost string) ProviderImpl { return &DockerProvider{dockerHost: dockerHost} } -// GetProxyEntries returns proxy entries from a docker client. -// -// It retrieves the docker client information using the dockerhelper.GetClientInfo method. -// Then, it iterates over the containers in the docker client information and calls -// the getEntriesFromLabels method to get the proxy entries for each container. -// Any errors encountered during the process are added to the ne error object. -// Finally, it returns the collected proxy entries and the ne error object. -// -// Parameters: -// - p: A pointer to the DockerProvider struct. -// -// Returns: -// - P.EntryModelSlice: (non-nil) A slice of EntryModel structs representing the proxy entries. -// - error: An error object if there was an error retrieving the docker client information or parsing the labels. -func (p DockerProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) { +func (p *DockerProvider) NewWatcher() W.Watcher { + return W.NewDockerWatcher(p.dockerHost) +} + +func (p *DockerProvider) LoadRoutesImpl() (routes R.Routes, err E.NestedError) { entries := M.NewProxyEntries() - info, err := D.GetClientInfo(p.dockerHost) + info, err := D.GetClientInfo(p.dockerHost, true) if err.HasError() { - return entries, err + return routes, E.FailWith("connect to docker", err) } errors := E.NewBuilder("errors when parse docker labels") - for _, container := range info.Containers { - en, err := p.getEntriesFromLabels(&container, info.Host) + for _, c := range info.Containers { + container := D.FromDocker(&c, p.dockerHost) + if container.IsExcluded { + continue + } + + newEntries, err := p.entriesFromContainerLabels(container) if err.HasError() { errors.Add(err) } // although err is not nil // there may be some valid entries in `en` - dups := entries.MergeWith(en) + dups := entries.MergeFrom(newEntries) // add the duplicate proxy entries to the error - dups.EachKV(func(k string, v *M.ProxyEntry) { + dups.RangeAll(func(k string, v *M.ProxyEntry) { errors.Addf("duplicate alias %s", k) }) } - return entries, errors.Build() + entries.RangeAll(func(_ string, e *M.ProxyEntry) { + e.DockerHost = p.dockerHost + }) + + routes, err = R.FromEntries(entries) + errors.Add(err) + + return routes, errors.Build() } -func (p *DockerProvider) NewWatcher() W.Watcher { - return W.NewDockerWatcher(p.dockerHost) +func (p *DockerProvider) OnEvent(event Event, routes R.Routes) (res EventResult) { + b := E.NewBuilder("event %s error", event) + defer b.To(&res.err) + + routes.RangeAll(func(k string, v R.Route) { + if v.Entry().ContainerName == event.ActorName { + b.Add(v.Stop()) + routes.Delete(k) + res.nRemoved++ + } + }) + + switch event.Action { + case ActionStarted, ActionCreated, ActionModified: + client, err := D.ConnectClient(p.dockerHost) + if err.HasError() { + b.Add(E.FailWith("connect to docker", err)) + return + } + defer client.Close() + cont, err := client.Inspect(event.ActorID) + if err.HasError() { + b.Add(E.FailWith("inspect container", err)) + return + } + entries, err := p.entriesFromContainerLabels(cont) + b.Add(err) + + entries.RangeAll(func(alias string, entry *M.ProxyEntry) { + if routes.Has(alias) { + b.Add(E.AlreadyExist("alias", alias)) + } else { + if route, err := R.NewRoute(entry); err.HasError() { + b.Add(err) + } else { + routes.Store(alias, route) + b.Add(route.Start()) + res.nAdded++ + } + } + }) + } + + return } // Returns a list of proxy entries for a container. // Always non-nil -func (p *DockerProvider) getEntriesFromLabels(container *types.Container, clientHost string) (M.ProxyEntries, E.NestedError) { - var mainAlias string - var aliases PT.Aliases - - if exclude, ok := container.Labels[D.NSProxy+".exclude"]; ok { - if U.ParseBool(exclude) { - return M.NewProxyEntries(), E.Nil() - } - } - - // set mainAlias to docker compose service name if available - if serviceName, ok := container.Labels["com.docker.compose.service"]; ok { - mainAlias = serviceName - } - - // if mainAlias is not set, - // or container name is different from service name - // use container name - if containerName := strings.TrimPrefix(container.Names[0], "/"); containerName != mainAlias { - mainAlias = containerName - } - - if l, ok := container.Labels[D.NSProxy+".aliases"]; ok { - aliases = PT.NewAliases(l) - delete(container.Labels, D.NSProxy+"proxy.aliases") - } else { - aliases = PT.NewAliases(mainAlias) - } - +func (p *DockerProvider) entriesFromContainerLabels(container D.Container) (M.ProxyEntries, E.NestedError) { entries := M.NewProxyEntries() - // find first port, return if no port exposed - defaultPort, err := findFirstPort(container) - if err.HasError() { - logrus.Debug(mainAlias, " ", err.Error()) - } - // init entries map for all aliases - aliases.ForEach(func(a PT.Alias) { - entries.Set(string(a), &M.ProxyEntry{ - Alias: string(a), - Host: clientHost, - Port: defaultPort, + for _, a := range container.Aliases { + entries.Store(a, &M.ProxyEntry{ + Alias: a, + Host: p.hostname, + ProxyProperties: container.ProxyProperties, }) - }) + } - errors := E.NewBuilder("failed to apply label for %q", mainAlias) + errors := E.NewBuilder("failed to apply label") for key, val := range container.Labels { - lbl, err := D.ParseLabel(key, val) - if err.HasError() { - errors.Add(E.From(err).Subject(key)) - continue - } - if lbl.Namespace != D.NSProxy { - continue - } - if lbl.Target == wildcardAlias { - // apply label for all aliases - entries.EachKV(func(a string, e *M.ProxyEntry) { - if err = D.ApplyLabel(e, lbl); err.HasError() { - errors.Add(E.From(err).Subject(lbl.Target)) - } - }) - } else { - config, ok := entries.UnsafeGet(lbl.Target) - if !ok { - errors.Add(E.NotExists("alias", lbl.Target)) - continue - } - if err = D.ApplyLabel(config, lbl); err.HasError() { - errors.Add(err.Subject(lbl.Target)) - } - } + errors.Add(p.applyLabel(entries, key, val)) } - entries.EachKV(func(a string, e *M.ProxyEntry) { - if e.Port == "" { - entries.UnsafeDelete(a) - } - }) - - return entries, errors.Build() + return entries, errors.Build().Subject(container.ContainerName) } -func findFirstPort(c *types.Container) (string, E.NestedError) { - if len(c.Ports) == 0 { - return "", E.FailureWhy("findFirstPort", "no port exposed") +func (p *DockerProvider) applyLabel(entries M.ProxyEntries, key, val string) (res E.NestedError) { + b := E.NewBuilder("errors in label %s", key) + defer b.To(&res) + + lbl, err := D.ParseLabel(key, val) + if err.HasError() { + b.Add(err.Subject(key)) } - for _, p := range c.Ports { - if p.PublicPort != 0 { - return fmt.Sprint(p.PublicPort), E.Nil() + if lbl.Namespace != D.NSProxy { + return + } + if lbl.Target == D.WildcardAlias { + // apply label for all aliases + entries.RangeAll(func(a string, e *M.ProxyEntry) { + if err = D.ApplyLabel(e, lbl); err.HasError() { + b.Add(err.Subject(lbl.Target)) + } + }) + } else { + config, ok := entries.Load(lbl.Target) + if !ok { + b.Add(E.NotExist("alias", lbl.Target)) + return + } + if err = D.ApplyLabel(config, lbl); err.HasError() { + b.Add(err.Subject(lbl.Target)) } } - return "", E.Failure("findFirstPort") + return } - -const wildcardAlias = "*" diff --git a/src/proxy/provider/file_provider.go b/src/proxy/provider/file_provider.go index b69df6ed..10e9763e 100644 --- a/src/proxy/provider/file_provider.go +++ b/src/proxy/provider/file_provider.go @@ -7,8 +7,10 @@ import ( "github.com/yusing/go-proxy/common" E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" + R "github.com/yusing/go-proxy/route" U "github.com/yusing/go-proxy/utils" W "github.com/yusing/go-proxy/watcher" + . "github.com/yusing/go-proxy/watcher/event" ) type FileProvider struct { @@ -27,26 +29,53 @@ func Validate(data []byte) E.NestedError { return U.ValidateYaml(U.GetSchema(common.ProvidersSchemaPath), data) } -func (p *FileProvider) String() string { - return p.fileName +func (p FileProvider) OnEvent(event Event, routes R.Routes) (res EventResult) { + b := E.NewBuilder("event %s error", event) + defer b.To(&res.err) + + newRoutes, err := p.LoadRoutesImpl() + if err.HasError() { + b.Add(err) + return + } + + routes.RangeAll(func(_ string, v R.Route) { + b.Add(v.Stop()) + }) + routes.Clear() + + newRoutes.RangeAll(func(_ string, v R.Route) { + b.Add(v.Start()) + }) + + routes.MergeFrom(newRoutes) + return } -func (p *FileProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) { +func (p *FileProvider) LoadRoutesImpl() (routes R.Routes, res E.NestedError) { + b := E.NewBuilder("file %q validation failure", p.fileName) + defer b.To(&res) + entries := M.NewProxyEntries() + data, err := E.Check(os.ReadFile(p.path)) if err.HasError() { - return entries, E.Failure("read file").Subject(p).With(err) + b.Add(E.FailWith("read file", err)) + return } - ne := E.Failure("validation").Subject(p) + if !common.NoSchemaValidation { if err = Validate(data); err.HasError() { - return entries, ne.With(err) + b.Add(err) + return } } if err = entries.UnmarshalFromYAML(data); err.HasError() { - return entries, ne.With(err) + b.Add(err) + return } - return entries, E.Nil() + + return R.FromEntries(entries) } func (p *FileProvider) NewWatcher() W.Watcher { diff --git a/src/proxy/provider/provider.go b/src/proxy/provider/provider.go index 52366cfd..0063a6ef 100644 --- a/src/proxy/provider/provider.go +++ b/src/proxy/provider/provider.go @@ -4,38 +4,40 @@ import ( "context" "fmt" "path" - "time" "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/error" - M "github.com/yusing/go-proxy/models" R "github.com/yusing/go-proxy/route" W "github.com/yusing/go-proxy/watcher" + . "github.com/yusing/go-proxy/watcher/event" ) -type ProviderImpl interface { - GetProxyEntries() (M.ProxyEntries, E.NestedError) - NewWatcher() W.Watcher -} +type ( + Provider struct { + ProviderImpl -type Provider struct { - ProviderImpl + name string + t ProviderType + routes R.Routes - name string - t ProviderType - routes *R.Routes - reloadReqCh chan struct{} + watcher W.Watcher + watcherCtx context.Context + watcherCancel context.CancelFunc - watcher W.Watcher - watcherCtx context.Context - watcherCancel context.CancelFunc - - l *logrus.Entry - - cooldownCh chan struct{} -} - -type ProviderType string + l *logrus.Entry + } + ProviderImpl interface { + NewWatcher() W.Watcher + LoadRoutesImpl() (R.Routes, E.NestedError) + OnEvent(event Event, routes R.Routes) EventResult + } + ProviderType string + EventResult struct { + nRemoved int + nAdded int + err E.NestedError + } +) const ( ProviderTypeDocker ProviderType = "docker" @@ -44,16 +46,14 @@ const ( func newProvider(name string, t ProviderType) *Provider { p := &Provider{ - name: name, - t: t, - routes: R.NewRoutes(), - reloadReqCh: make(chan struct{}, 1), - cooldownCh: make(chan struct{}, 1), + name: name, + t: t, + routes: R.NewRoutes(), } p.l = logrus.WithField("provider", p) - go p.processReloadRequests() return p } + func NewFileProvider(filename string) *Provider { name := path.Base(filename) p := newProvider(name, ProviderTypeFile) @@ -78,25 +78,21 @@ func (p *Provider) GetType() ProviderType { } func (p *Provider) String() string { - return fmt.Sprintf("%s: %s", p.t, p.name) + return fmt.Sprintf("%s-%s", p.t, p.name) } -func (p *Provider) StartAllRoutes() E.NestedError { - err := p.loadRoutes() +func (p *Provider) StartAllRoutes() (res E.NestedError) { + errors := E.NewBuilder("errors in routes") + defer errors.To(&res) // start watcher no matter load success or not p.watcherCtx, p.watcherCancel = context.WithCancel(context.Background()) go p.watchEvents() - errors := E.NewBuilder("errors in routes") nStarted := 0 nFailed := 0 - if err.HasError() { - errors.Add(err) - } - - p.routes.EachKVParallel(func(alias string, r R.Route) { + p.routes.RangeAll(func(alias string, r R.Route) { if err := r.Start(); err.HasError() { errors.Add(err.Subject(r)) nFailed++ @@ -106,18 +102,21 @@ func (p *Provider) StartAllRoutes() E.NestedError { }) p.l.Debugf("%d routes started, %d failed", nStarted, nFailed) - return errors.Build() + return } -func (p *Provider) StopAllRoutes() E.NestedError { +func (p *Provider) StopAllRoutes() (res E.NestedError) { if p.watcherCancel != nil { p.watcherCancel() p.watcherCancel = nil } + errors := E.NewBuilder("errors stopping routes for provider %q", p.name) + defer errors.To(&res) + nStopped := 0 nFailed := 0 - p.routes.EachKVParallel(func(alias string, r R.Route) { + p.routes.RangeAll(func(alias string, r R.Route) { if err := r.Stop(); err.HasError() { errors.Add(err.Subject(r)) nFailed++ @@ -126,20 +125,24 @@ func (p *Provider) StopAllRoutes() E.NestedError { } }) p.l.Debugf("%d routes stopped, %d failed", nStopped, nFailed) - return errors.Build() + return } -func (p *Provider) ReloadRoutes() { - select { - case p.reloadReqCh <- struct{}{}: - // Successfully sent reload request - default: - // Reload request already in progress, ignore this request +func (p *Provider) RangeRoutes(do func(string, R.Route)) { + p.routes.RangeAll(do) +} + +func (p *Provider) GetRoute(alias string) (R.Route, bool) { + return p.routes.Load(alias) +} + +func (p *Provider) LoadRoutes() E.NestedError { + routes, err := p.LoadRoutesImpl() + if err != nil { + return err } -} - -func (p *Provider) GetCurrentRoutes() *R.Routes { - return p.routes + p.routes = routes + return nil } func (p *Provider) watchEvents() { @@ -151,11 +154,15 @@ func (p *Provider) watchEvents() { case <-p.watcherCtx.Done(): return case event, ok := <-events: - if !ok { + if !ok { // channel closed return } - l.Info(event) - p.ReloadRoutes() + res := p.OnEvent(event, p.routes) + l.Infof("%s event %q", event.Type, event) + l.Infof("%d route added, %d routes removed", res.nAdded, res.nRemoved) + if res.err.HasError() { + l.Error(res.err) + } case err, ok := <-errs: if !ok { return @@ -167,50 +174,3 @@ func (p *Provider) watchEvents() { } } } - -func (p *Provider) processReloadRequests() { - for range p.reloadReqCh { - // prevent busy loop caused by a container - // repeating crashing and restarting - select { - case p.cooldownCh <- struct{}{}: - p.l.Info("Starting to reload routes") - nRoutes := p.routes.Size() - - p.StopAllRoutes() - p.loadRoutes() - p.StartAllRoutes() - - p.l.Infof("Routes reloaded (%d -> %d)", nRoutes, p.routes.Size()) - - go func() { - time.Sleep(reloadCooldown) - <-p.cooldownCh - }() - default: - } - } -} - -func (p *Provider) loadRoutes() E.NestedError { - entries, err := p.GetProxyEntries() - - if err.HasError() { - p.l.Warn(err.Subject(p)) - } - p.routes = R.NewRoutes() - - errors := E.NewBuilder("errors loading routes from %s", p) - entries.EachKV(func(a string, e *M.ProxyEntry) { - e.Alias = a - r, err := R.NewRoute(e) - if err.HasError() { - errors.Add(err.Subject(a)) - } else { - p.routes.Set(a, r) - } - }) - return errors.Build() -} - -const reloadCooldown = 50 * time.Millisecond diff --git a/src/proxy/reverse_proxy_mod.go b/src/proxy/reverse_proxy_mod.go index e5d23c5e..811e2202 100644 --- a/src/proxy/reverse_proxy_mod.go +++ b/src/proxy/reverse_proxy_mod.go @@ -207,7 +207,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) { // } // // TODO: headers in ModifyResponse -func NewReverseProxy(target *url.URL, transport *http.Transport, entry *Entry) *ReverseProxy { +func NewReverseProxy(target *url.URL, transport http.RoundTripper, entry *ReverseProxyEntry) *ReverseProxy { // check on init rather than on request var setHeaders = func(r *http.Request) {} var hideHeaders = func(r *http.Request) {} diff --git a/src/route/http_route.go b/src/route/http_route.go index 6b6bbe4b..c8b2b72b 100755 --- a/src/route/http_route.go +++ b/src/route/http_route.go @@ -2,8 +2,8 @@ package route import ( "crypto/tls" - "fmt" "net" + "sync" "time" "net/http" @@ -11,6 +11,7 @@ import ( "strings" "github.com/sirupsen/logrus" + "github.com/yusing/go-proxy/docker/idlewatcher" E "github.com/yusing/go-proxy/error" P "github.com/yusing/go-proxy/proxy" PT "github.com/yusing/go-proxy/proxy/fields" @@ -23,57 +24,65 @@ type ( TargetURL *URL `json:"target_url"` PathPatterns PT.PathPatterns `json:"path_patterns"` + entry *P.ReverseProxyEntry mux *http.ServeMux handler *P.ReverseProxy + + regIdleWatcher func() E.NestedError + unregIdleWatcher func() } URL url.URL - PathKey = PT.PathPattern SubdomainKey = PT.Alias ) -var httpRoutes = F.NewMap[SubdomainKey, *HTTPRoute]() +func NewHTTPRoute(entry *P.ReverseProxyEntry) (*HTTPRoute, E.NestedError) { + var trans http.RoundTripper + var regIdleWatcher func() E.NestedError + var unregIdleWatcher func() -func NewHTTPRoute(entry *P.Entry) (*HTTPRoute, E.NestedError) { - var tr *http.Transport if entry.NoTLSVerify { - tr = transportNoTLS + trans = transportNoTLS } else { - tr = transport + trans = transport } - rp := P.NewReverseProxy(entry.URL, tr, entry) + rp := P.NewReverseProxy(entry.URL, trans, entry) - httpRoutes.Lock() - defer httpRoutes.Unlock() - - var r *HTTPRoute - r, ok := httpRoutes.UnsafeGet(entry.Alias) - if !ok { - r = &HTTPRoute{ - Alias: entry.Alias, - TargetURL: (*URL)(entry.URL), - PathPatterns: entry.PathPatterns, - handler: rp, + if entry.UseIdleWatcher() { + regIdleWatcher = func() E.NestedError { + watcher, err := idlewatcher.Register(entry) + if err.HasError() { + return err + } + // patch round-tripper + rp.Transport = watcher.PatchRoundTripper(trans) + return nil } - httpRoutes.UnsafeSet(entry.Alias, r) - } - - rewrite := rp.Rewrite - - if logrus.GetLevel() == logrus.DebugLevel { - l := logrus.WithField("alias", entry.Alias) - - rp.Rewrite = func(pr *P.ProxyRequest) { - l.Debug("request URL: ", pr.In.Host, pr.In.URL.Path) - l.Debug("request headers: ", pr.In.Header) - rewrite(pr) + unregIdleWatcher = func() { + idlewatcher.Unregister(entry.ContainerName) + rp.Transport = trans } - } else { - rp.Rewrite = rewrite } - return r, E.Nil() + httpRoutesMu.Lock() + defer httpRoutesMu.Unlock() + + _, exists := httpRoutes.Load(entry.Alias) + if exists { + return nil, E.AlreadyExist("HTTPRoute alias", entry.Alias) + } + + r := &HTTPRoute{ + Alias: entry.Alias, + TargetURL: (*URL)(entry.URL), + PathPatterns: entry.PathPatterns, + entry: entry, + handler: rp, + regIdleWatcher: regIdleWatcher, + unregIdleWatcher: unregIdleWatcher, + } + return r, nil } func (r *HTTPRoute) String() string { @@ -81,18 +90,35 @@ func (r *HTTPRoute) String() string { } func (r *HTTPRoute) Start() E.NestedError { + httpRoutesMu.Lock() + defer httpRoutesMu.Unlock() + + if r.regIdleWatcher != nil { + if err := r.regIdleWatcher(); err.HasError() { + return err + } + } + r.mux = http.NewServeMux() for _, p := range r.PathPatterns { r.mux.HandleFunc(string(p), r.handler.ServeHTTP) } - httpRoutes.Set(r.Alias, r) - return E.Nil() + + httpRoutes.Store(r.Alias, r) + return nil } func (r *HTTPRoute) Stop() E.NestedError { + httpRoutesMu.Lock() + defer httpRoutesMu.Unlock() + + if r.unregIdleWatcher != nil { + r.unregIdleWatcher() + } + r.mux = nil httpRoutes.Delete(r.Alias) - return E.Nil() + return nil } func (u *URL) String() string { @@ -104,27 +130,26 @@ func (u *URL) MarshalText() (text []byte, err error) { } func ProxyHandler(w http.ResponseWriter, r *http.Request) { - mux, err := findMux(r.Host, PathKey(r.URL.Path)) + mux, err := findMux(r.Host) if err != nil { err = E.Failure("request"). Subjectf("%s %s%s", r.Method, r.Host, r.URL.Path). With(err) - http.Error(w, err.Error(), http.StatusNotFound) + http.Error(w, err.String(), http.StatusNotFound) logrus.Error(err) return } mux.ServeHTTP(w, r) } -func findMux(host string, path PathKey) (*http.ServeMux, error) { +func findMux(host string) (*http.ServeMux, E.NestedError) { sd := strings.Split(host, ".")[0] - if r, ok := httpRoutes.UnsafeGet(PT.Alias(sd)); ok { + if r, ok := httpRoutes.Load(PT.Alias(sd)); ok { return r.mux, nil } - return nil, E.NotExists("route", fmt.Sprintf("subdomain: %s, path: %s", sd, path)) + return nil, E.NotExist("route", sd) } -// TODO: default + per proxy var ( transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -135,10 +160,12 @@ var ( MaxIdleConns: 1000, MaxIdleConnsPerHost: 1000, } - transportNoTLS = func() *http.Transport { var clone = transport.Clone() clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} return clone }() + + httpRoutes = F.NewMapOf[SubdomainKey, *HTTPRoute]() + httpRoutesMu sync.Mutex ) diff --git a/src/route/route.go b/src/route/route.go index 5c2d60da..e3c92092 100755 --- a/src/route/route.go +++ b/src/route/route.go @@ -1,6 +1,9 @@ package route import ( + "fmt" + "net/url" + E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" P "github.com/yusing/go-proxy/proxy" @@ -9,27 +12,81 @@ import ( type ( Route interface { + RouteImpl + Entry() *M.ProxyEntry + Type() RouteType + URL() *url.URL + } + Routes = F.Map[string, Route] + RouteType string + + RouteImpl interface { Start() E.NestedError Stop() E.NestedError String() string } - Routes = F.Map[string, Route] + route struct { + RouteImpl + type_ RouteType + entry *M.ProxyEntry + } +) + +const ( + RouteTypeStream RouteType = "stream" + RouteTypeReverseProxy RouteType = "reverse_proxy" ) // function alias -var NewRoutes = F.NewMap[string, Route] +var NewRoutes = F.NewMapOf[string, Route] func NewRoute(en *M.ProxyEntry) (Route, E.NestedError) { - entry, err := P.NewEntry(en) + rt, err := P.ValidateEntry(en) if err.HasError() { return nil, err } - switch e := entry.(type) { + + var t RouteType + + switch e := rt.(type) { case *P.StreamEntry: - return NewStreamRoute(e) - case *P.Entry: - return NewHTTPRoute(e) + rt, err = NewStreamRoute(e) + t = RouteTypeStream + case *P.ReverseProxyEntry: + rt, err = NewHTTPRoute(e) + t = RouteTypeReverseProxy default: panic("bug: should not reach here") } + return &route{RouteImpl: rt.(RouteImpl), entry: en, type_: t}, err +} + +func (rt *route) Entry() *M.ProxyEntry { + return rt.entry +} + +func (rt *route) Type() RouteType { + return rt.type_ +} + +func (rt *route) URL() *url.URL { + url, _ := url.Parse(fmt.Sprintf("%s://%s", rt.entry.Scheme, rt.entry.Host)) + return url +} + +func FromEntries(entries M.ProxyEntries) (Routes, E.NestedError) { + b := E.NewBuilder("errors in routes") + + routes := NewRoutes() + entries.RangeAll(func(alias string, entry *M.ProxyEntry) { + entry.Alias = alias + r, err := NewRoute(entry) + if err.HasError() { + b.Add(err.Subject(alias)) + } else { + routes.Store(alias, r) + } + }) + + return routes, b.Build() } diff --git a/src/route/stream_route.go b/src/route/stream_route.go index f11abdb9..56b1a83c 100755 --- a/src/route/stream_route.go +++ b/src/route/stream_route.go @@ -12,7 +12,7 @@ import ( ) type StreamRoute struct { - *P.StreamEntry + P.StreamEntry StreamImpl `json:"-"` wg sync.WaitGroup @@ -35,7 +35,7 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { return nil, E.Unsupported("scheme", fmt.Sprintf("%v -> %v", entry.Scheme.ListeningScheme, entry.Scheme.ProxyScheme)) } base := &StreamRoute{ - StreamEntry: entry, + StreamEntry: *entry, wg: sync.WaitGroup{}, connCh: make(chan any), } @@ -45,11 +45,11 @@ func NewStreamRoute(entry *P.StreamEntry) (*StreamRoute, E.NestedError) { base.StreamImpl = NewUDPRoute(base) } base.l = logrus.WithField("route", base.StreamImpl) - return base, E.Nil() + return base, nil } func (r *StreamRoute) String() string { - return fmt.Sprintf("%s-stream: %s", r.Scheme, r.Alias) + return fmt.Sprintf("%s stream: %s", r.Scheme, r.Alias) } func (r *StreamRoute) Start() E.NestedError { @@ -59,13 +59,13 @@ func (r *StreamRoute) Start() E.NestedError { r.stopCh = make(chan struct{}, 1) r.wg.Wait() if err := r.Setup(); err != nil { - return E.Failure("setup").With(err) + return E.FailWith("setup", err) } r.started.Store(true) r.wg.Add(2) go r.grAcceptConnections() go r.grHandleConnections() - return E.Nil() + return nil } func (r *StreamRoute) Stop() E.NestedError { @@ -88,7 +88,7 @@ func (r *StreamRoute) Stop() E.NestedError { case <-time.After(streamStopListenTimeout): l.Error("timed out waiting for connections") } - return E.Nil() + return nil } func (r *StreamRoute) grAcceptConnections() { diff --git a/src/route/tcp_route.go b/src/route/tcp_route.go index 6e80159f..2baab418 100755 --- a/src/route/tcp_route.go +++ b/src/route/tcp_route.go @@ -65,9 +65,10 @@ func (route *TCPRoute) Handle(c any) error { }() route.mu.Lock() + defer route.mu.Unlock() + pipe := U.NewBidirectionalPipe(pipeCtx, clientConn, serverConn) route.pipe = append(route.pipe, pipe) - route.mu.Unlock() return pipe.Start() } @@ -78,7 +79,7 @@ func (route *TCPRoute) CloseListeners() { route.listener.Close() route.listener = nil for _, pipe := range route.pipe { - if err := pipe.Stop(); err.HasError() { + if err := pipe.Stop(); err != nil { route.l.Error(err) } } diff --git a/src/utils/functional/map.go b/src/utils/functional/map.go index 1832c886..812cbe6b 100644 --- a/src/utils/functional/map.go +++ b/src/utils/functional/map.go @@ -1,229 +1,116 @@ package functional import ( - "context" - "sync" - + "github.com/puzpuzpuz/xsync/v3" "gopkg.in/yaml.v3" E "github.com/yusing/go-proxy/error" ) type Map[KT comparable, VT any] struct { - m map[KT]VT - defVals map[KT]VT - sync.RWMutex + *xsync.MapOf[KT, VT] } -// NewMap creates a new Map with the given map as its initial values. -// -// Parameters: -// - dv: optional default values for the Map -// -// Return: -// - *Map[KT, VT]: a pointer to the newly created Map. -func NewMap[KT comparable, VT any](dv ...map[KT]VT) *Map[KT, VT] { - return NewMapFrom(make(map[KT]VT), dv...) +func NewMapOf[KT comparable, VT any](options ...func(*xsync.MapConfig)) Map[KT, VT] { + return Map[KT, VT]{xsync.NewMapOf[KT, VT](options...)} } -// NewMapOf creates a new Map with the given map as its initial values. -// -// Type parameters: -// - M: type for the new map. -// -// Parameters: -// - dv: optional default values for the Map -// -// Return: -// - *Map[KT, VT]: a pointer to the newly created Map. -func NewMapOf[M Map[KT, VT], KT comparable, VT any](dv ...map[KT]VT) *Map[KT, VT] { - return NewMapFrom(make(map[KT]VT), dv...) -} - -// NewMapFrom creates a new Map with the given map as its initial values. -// -// Parameters: -// - from: a map of type KT to VT, which will be the initial values of the Map. -// - dv: optional default values for the Map -// -// Return: -// - *Map[KT, VT]: a pointer to the newly created Map. -func NewMapFrom[KT comparable, VT any](from map[KT]VT, dv ...map[KT]VT) *Map[KT, VT] { - if len(dv) > 0 { - return &Map[KT, VT]{m: from, defVals: dv[0]} +func NewMapFrom[KT comparable, VT any](m map[KT]VT) (res Map[KT, VT]) { + res = NewMapOf[KT, VT](xsync.WithPresize(len(m))) + for k, v := range m { + res.Store(k, v) } - return &Map[KT, VT]{m: from} + return } -func (m *Map[KT, VT]) Set(key KT, value VT) { - m.Lock() - m.m[key] = value - m.Unlock() -} +func MapFind[KT comparable, VT, CT any](m Map[KT, VT], criteria func(VT) (CT, bool)) (_ CT) { + result := make(chan CT, 1) -func (m *Map[KT, VT]) Get(key KT) VT { - m.RLock() - defer m.RUnlock() - value, ok := m.m[key] - if !ok && m.defVals != nil { - return m.defVals[key] - } - return value -} - -// Find searches for the first element in the map that satisfies the given criteria. -// -// Parameters: -// - criteria: a function that takes a value of type VT and returns a tuple of any type and a boolean. -// -// Return: -// - any: the first value that satisfies the criteria, or nil if no match is found. -func (m *Map[KT, VT]) Find(criteria func(VT) (any, bool)) any { - m.RLock() - defer m.RUnlock() - - result := make(chan any) - wg := sync.WaitGroup{} - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - for _, v := range m.m { - wg.Add(1) - go func(val VT) { - defer wg.Done() - if value, ok := criteria(val); ok { - select { - case result <- value: - cancel() // Cancel other goroutines if a result is found - case <-ctx.Done(): // If already cancelled - return - } + m.Range(func(key KT, value VT) bool { + select { + case <-result: // already have a result + return false // stop iteration + default: + if got, ok := criteria(value); ok { + result <- got + return false } - }(v) - } - - go func() { - wg.Wait() - close(result) - }() - - // The first valid match, if any - select { - case res, ok := <-result: - if ok { - return res + return true } - case <-ctx.Done(): + }) + + select { + case v := <-result: + return v + default: + return } - - return nil // Return nil if no matches found } -func (m *Map[KT, VT]) UnsafeGet(key KT) (VT, bool) { - value, ok := m.m[key] - return value, ok -} - -func (m *Map[KT, VT]) UnsafeSet(key KT, value VT) { - m.m[key] = value -} - -func (m *Map[KT, VT]) Delete(key KT) { - m.Lock() - delete(m.m, key) - m.Unlock() -} - -func (m *Map[KT, VT]) UnsafeDelete(key KT) { - delete(m.m, key) -} - -// MergeWith merges the contents of another Map[KT, VT] -// into the current Map[KT, VT] and -// returns a map that were duplicated. +// MergeFrom add contents from another `Map`, ignore duplicated keys // // Parameters: -// - other: a pointer to another Map[KT, VT] to be merged into the current Map[KT, VT]. +// - other: `Map` of values to add from // // Return: -// - Map[KT, VT]: a map of key-value pairs that were duplicated during the merge. -func (m *Map[KT, VT]) MergeWith(other *Map[KT, VT]) Map[KT, VT] { - dups := make(map[KT]VT) +// - Map: a `Map` of duplicated keys-value pairs +func (m Map[KT, VT]) MergeFrom(other Map[KT, VT]) Map[KT, VT] { + dups := NewMapOf[KT, VT]() - m.Lock() - for k, v := range other.m { - if _, isDup := m.m[k]; !isDup { - m.m[k] = v + other.Range(func(k KT, v VT) bool { + if _, ok := m.Load(k); ok { + dups.Store(k, v) } else { - dups[k] = v + m.Store(k, v) } - } - m.Unlock() - return Map[KT, VT]{m: dups} + return true + }) + return dups } -func (m *Map[KT, VT]) Clear() { - m.Lock() - m.m = make(map[KT]VT) - m.Unlock() +func (m Map[KT, VT]) RangeAll(do func(k KT, v VT)) { + m.Range(func(k KT, v VT) bool { + do(k, v) + return true + }) } -func (m *Map[KT, VT]) Size() int { - m.RLock() - defer m.RUnlock() - return len(m.m) +func (m Map[KT, VT]) RemoveAll(criteria func(VT) bool) { + m.Range(func(k KT, v VT) bool { + if criteria(v) { + m.Delete(k) + } + return true + }) } -func (m *Map[KT, VT]) Contains(key KT) bool { - m.RLock() - _, ok := m.m[key] - m.RUnlock() +func (m Map[KT, VT]) Has(k KT) bool { + _, ok := m.Load(k) return ok } -func (m *Map[KT, VT]) Clone() *Map[KT, VT] { - m.RLock() - defer m.RUnlock() - clone := make(map[KT]VT, len(m.m)) - for k, v := range m.m { - clone[k] = v +func (m Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError { + if m.Size() != 0 { + return E.FailedWhy("unmarshal from yaml", "map is not empty") } - return &Map[KT, VT]{m: clone, defVals: m.defVals} -} - -func (m *Map[KT, VT]) EachKV(fn func(k KT, v VT)) { - m.Lock() - for k, v := range m.m { - fn(k, v) + tmp := make(map[KT]VT) + if err := E.From(yaml.Unmarshal(data, tmp)); err.HasError() { + return err } - m.Unlock() -} - -func (m *Map[KT, VT]) Each(fn func(v VT)) { - m.Lock() - for _, v := range m.m { - fn(v) + for k, v := range tmp { + m.Store(k, v) } - m.Unlock() + return nil } -func (m *Map[KT, VT]) EachParallel(fn func(v VT)) { - m.Lock() - ParallelForEachValue(m.m, fn) - m.Unlock() -} - -func (m *Map[KT, VT]) EachKVParallel(fn func(k KT, v VT)) { - m.Lock() - ParallelForEachKV(m.m, fn) - m.Unlock() -} - -func (m *Map[KT, VT]) UnmarshalFromYAML(data []byte) E.NestedError { - return E.From(yaml.Unmarshal(data, m.m)) -} - -func (m *Map[KT, VT]) Iterator() map[KT]VT { - return m.m +func (m Map[KT, VT]) String() string { + tmp := make(map[KT]VT, m.Size()) + m.RangeAll(func(k KT, v VT) { + tmp[k] = v + }) + data, err := yaml.Marshal(tmp) + if err != nil { + return err.Error() + } + return string(data) } diff --git a/src/utils/functional/map_test.go b/src/utils/functional/map_test.go new file mode 100644 index 00000000..031993e6 --- /dev/null +++ b/src/utils/functional/map_test.go @@ -0,0 +1,75 @@ +package functional_test + +import ( + "testing" + + . "github.com/yusing/go-proxy/utils/functional" + . "github.com/yusing/go-proxy/utils/testing" +) + +func TestNewMapFrom(t *testing.T) { + m := NewMapFrom(map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + ExpectEqual(t, m.Size(), 3) + ExpectTrue(t, m.Has("a")) + ExpectTrue(t, m.Has("b")) + ExpectTrue(t, m.Has("c")) +} + +func TestMapFind(t *testing.T) { + m := NewMapFrom(map[string]map[string]int{ + "a": { + "a": 1, + }, + "b": { + "a": 1, + "b": 2, + }, + "c": { + "b": 2, + "c": 3, + }, + }) + res := MapFind(m, func(inner map[string]int) (int, bool) { + if _, ok := inner["c"]; ok && inner["c"] == 3 { + return inner["c"], true + } + return 0, false + }) + ExpectEqual(t, res, 3) +} + +func TestMergeFrom(t *testing.T) { + m1 := NewMapFrom(map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + }) + m2 := NewMapFrom(map[string]int{ + "a": 1, + "c": 123, + "e": 456, + "f": 6, + }) + dup := m1.MergeFrom(m2) + + ExpectEqual(t, m1.Size(), 6) + ExpectTrue(t, m1.Has("e")) + ExpectTrue(t, m1.Has("f")) + c, _ := m1.Load("c") + d, _ := m1.Load("d") + e, _ := m1.Load("e") + f, _ := m1.Load("f") + ExpectEqual(t, c, 3) + ExpectEqual(t, d, 4) + ExpectEqual(t, e, 456) + ExpectEqual(t, f, 6) + + ExpectEqual(t, dup.Size(), 2) + ExpectTrue(t, dup.Has("a")) + ExpectTrue(t, dup.Has("c")) +} diff --git a/src/utils/io.go b/src/utils/io.go index 988a1e21..10f4fd7c 100644 --- a/src/utils/io.go +++ b/src/utils/io.go @@ -10,15 +10,8 @@ import ( E "github.com/yusing/go-proxy/error" ) +// TODO: move to "utils/io" type ( - Reader interface { - Read() ([]byte, E.NestedError) - } - - StdReader struct { - r Reader - } - FileReader struct { Path string } @@ -29,13 +22,6 @@ type ( closed atomic.Bool } - StdReadCloser struct { - r *ReadCloser - } - - ByteReader []byte - NewByteReader = ByteReader - Pipe struct { r ReadCloser w io.WriteCloser @@ -44,49 +30,25 @@ type ( } BidirectionalPipe struct { - pSrcDst Pipe - pDstSrc Pipe + pSrcDst *Pipe + pDstSrc *Pipe } ) -func NewFileReader(path string) *FileReader { - return &FileReader{Path: path} -} - -func (r StdReader) Read() ([]byte, error) { - return r.r.Read() -} - -func (r *FileReader) Read() ([]byte, E.NestedError) { - return E.Check(os.ReadFile(r.Path)) -} - -func (r ByteReader) Read() ([]byte, E.NestedError) { - return r, E.Nil() -} - -func (r *ReadCloser) Read(p []byte) (int, E.NestedError) { +func (r *ReadCloser) Read(p []byte) (int, error) { select { case <-r.ctx.Done(): - return 0, E.From(r.ctx.Err()) + return 0, r.ctx.Err() default: - return E.Check(r.r.Read(p)) + return r.r.Read(p) } } -func (r *ReadCloser) Close() E.NestedError { +func (r *ReadCloser) Close() error { if r.closed.Load() { - return E.Nil() + return nil } r.closed.Store(true) - return E.From(r.r.Close()) -} - -func (r StdReadCloser) Read(p []byte) (int, error) { - return r.r.Read(p) -} - -func (r StdReadCloser) Close() error { return r.r.Close() } @@ -100,35 +62,35 @@ func NewPipe(ctx context.Context, r io.ReadCloser, w io.WriteCloser) *Pipe { } } -func (p *Pipe) Start() E.NestedError { - return Copy(p.ctx, p.w, &StdReadCloser{&p.r}) +func (p *Pipe) Start() error { + return Copy(p.ctx, p.w, &p.r) } -func (p *Pipe) Stop() E.NestedError { +func (p *Pipe) Stop() error { p.cancel() - return E.Join("error stopping pipe", p.r.Close(), p.w.Close()) + return E.JoinE("error stopping pipe", p.r.Close(), p.w.Close()).Error() } -func (p *Pipe) Write(b []byte) (int, E.NestedError) { - return E.Check(p.w.Write(b)) +func (p *Pipe) Write(b []byte) (int, error) { + return p.w.Write(b) } func NewBidirectionalPipe(ctx context.Context, rw1 io.ReadWriteCloser, rw2 io.ReadWriteCloser) *BidirectionalPipe { return &BidirectionalPipe{ - pSrcDst: *NewPipe(ctx, rw1, rw2), - pDstSrc: *NewPipe(ctx, rw2, rw1), + pSrcDst: NewPipe(ctx, rw1, rw2), + pDstSrc: NewPipe(ctx, rw2, rw1), } } func NewBidirectionalPipeIntermediate(ctx context.Context, listener io.ReadCloser, client io.ReadWriteCloser, target io.ReadWriteCloser) *BidirectionalPipe { return &BidirectionalPipe{ - pSrcDst: *NewPipe(ctx, listener, client), - pDstSrc: *NewPipe(ctx, client, target), + pSrcDst: NewPipe(ctx, listener, client), + pDstSrc: NewPipe(ctx, client, target), } } -func (p *BidirectionalPipe) Start() E.NestedError { - errCh := make(chan E.NestedError, 2) +func (p *BidirectionalPipe) Start() error { + errCh := make(chan error, 2) go func() { errCh <- p.pSrcDst.Start() }() @@ -136,34 +98,34 @@ func (p *BidirectionalPipe) Start() E.NestedError { errCh <- p.pDstSrc.Start() }() for err := range errCh { - if err.HasError() { + if err != nil { return err } } - return E.Nil() + return nil } -func (p *BidirectionalPipe) Stop() E.NestedError { - return E.Join("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()) +func (p *BidirectionalPipe) Stop() error { + return E.JoinE("error stopping pipe", p.pSrcDst.Stop(), p.pDstSrc.Stop()).Error() } -func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) E.NestedError { - _, err := io.Copy(dst, StdReadCloser{&ReadCloser{ctx: ctx, r: src}}) - return E.From(err) +func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) error { + _, err := io.Copy(dst, &ReadCloser{ctx: ctx, r: src}) + return err } func LoadJson[T any](path string, pointer *T) E.NestedError { - data, err := os.ReadFile(path) - if err != nil { - return E.From(err) + data, err := E.Check(os.ReadFile(path)) + if err.HasError() { + return err } return E.From(json.Unmarshal(data, pointer)) } func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError { - data, err := json.Marshal(pointer) - if err != nil { - return E.From(err) + data, err := E.Check(json.Marshal(pointer)) + if err.HasError() { + return err } return E.From(os.WriteFile(path, data, perm)) } diff --git a/src/utils/reflection.go b/src/utils/reflection.go index e0a1f80d..8eb525ca 100644 --- a/src/utils/reflection.go +++ b/src/utils/reflection.go @@ -20,5 +20,5 @@ func SetFieldFromSnake[T, VT any](obj *T, field string, value VT) E.NestedError return E.Invalid("field", field) } prop.Set(reflect.ValueOf(value)) - return E.Nil() + return nil } diff --git a/src/utils/serialization.go b/src/utils/serialization.go index f62f1bd6..d6ef4974 100644 --- a/src/utils/serialization.go +++ b/src/utils/serialization.go @@ -17,45 +17,26 @@ func ValidateYaml(schema *jsonschema.Schema, data []byte) E.NestedError { err := yaml.Unmarshal(data, &i) if err != nil { - return E.Failure("unmarshal yaml").With(err) + return E.FailWith("unmarshal yaml", err) } m, err := json.Marshal(i) if err != nil { - return E.Failure("marshal json").With(err) + return E.FailWith("marshal json", err) } err = schema.Validate(bytes.NewReader(m)) if err == nil { - return E.Nil() + return nil } errors := E.NewBuilder("yaml validation error") for _, e := range err.(*jsonschema.ValidationError).Causes { - errors.Add(e) + errors.AddE(e) } return errors.Build() } -// TryJsonStringify converts the given object to a JSON string. -// -// It takes an object of any type and attempts to marshal it into a JSON string. -// If the marshaling is successful, the JSON string is returned. -// If the marshaling fails, the object is converted to a string using fmt.Sprint and returned. -// -// Parameters: -// - o: The object to be converted to a JSON string. -// -// Return type: -// - string: The JSON string representation of the object. -func TryJsonStringify(o any) string { - b, err := json.Marshal(o) - if err != nil { - return fmt.Sprint(o) - } - return string(b) -} - // Serialize converts the given data into a map[string]any representation. // // It uses reflection to inspect the data type and handle different kinds of data. @@ -123,7 +104,7 @@ func Serialize(data any) (SerializedObject, E.NestedError) { return nil, E.Unsupported("type", value.Kind()) } - return result, E.Nil() + return result, nil } func Deserialize(src map[string]any, target any) E.NestedError { @@ -166,7 +147,7 @@ func Deserialize(src map[string]any, target any) E.NestedError { propNew := reflect.New(propType.Elem()) err := Deserialize(vSerialized, propNew.Interface()) if err.HasError() { - return E.Failure("set field").With(k).With(err) + return E.Failure("set field").With(err).Subject(k) } prop.Set(propNew) default: @@ -180,7 +161,15 @@ func Deserialize(src map[string]any, target any) E.NestedError { } } - return E.Nil() + return nil +} + +func DeserializeJson(j map[string]string, target any) E.NestedError { + data, err := E.Check(json.Marshal(j)) + if err.HasError() { + return err + } + return E.From(json.Unmarshal(data, target)) } func toLowerNoSnake(s string) string { diff --git a/src/utils/string.go b/src/utils/string.go new file mode 100644 index 00000000..7e62e89d --- /dev/null +++ b/src/utils/string.go @@ -0,0 +1,11 @@ +package utils + +import "strings" + +func CommaSeperatedList(s string) []string { + res := strings.Split(s, ",") + for i, part := range res { + res[i] = strings.TrimSpace(part) + } + return res +} diff --git a/src/utils/testing.go b/src/utils/testing/testing.go similarity index 67% rename from src/utils/testing.go rename to src/utils/testing/testing.go index 62c93876..f15726c1 100644 --- a/src/utils/testing.go +++ b/src/utils/testing/testing.go @@ -3,25 +3,23 @@ package utils import ( "reflect" "testing" - - E "github.com/yusing/go-proxy/error" ) func ExpectNoError(t *testing.T, err error) { t.Helper() - var noError bool - switch t := err.(type) { - case E.NestedError: - noError = t.NoError() - default: - noError = err == nil - } - if !noError { + if err != nil && !reflect.ValueOf(err).IsNil() { t.Errorf("expected err=nil, got %s", err.Error()) } } -func ExpectEqual(t *testing.T, got, want any) { +func ExpectEqual[T comparable](t *testing.T, got T, want T) { + t.Helper() + if got != want { + t.Errorf("expected:\n%v, got\n%v", want, got) + } +} + +func ExpectDeepEqual[T any](t *testing.T, got T, want T) { t.Helper() if !reflect.DeepEqual(got, want) { t.Errorf("expected:\n%v, got\n%v", want, got) @@ -47,7 +45,7 @@ func ExpectType[T any](t *testing.T, got any) T { tExpect := reflect.TypeFor[T]() _, ok := got.(T) if !ok { - t.Errorf("expected type %T, got %T", tExpect, got) + t.Errorf("expected type %s, got %T", tExpect, got) } return got.(T) } diff --git a/src/watcher/docker_watcher.go b/src/watcher/docker_watcher.go index f87ae720..72f8f97b 100644 --- a/src/watcher/docker_watcher.go +++ b/src/watcher/docker_watcher.go @@ -2,13 +2,13 @@ package watcher import ( "context" - "fmt" "time" "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" D "github.com/yusing/go-proxy/docker" E "github.com/yusing/go-proxy/error" + . "github.com/yusing/go-proxy/watcher/event" ) type DockerWatcher struct { @@ -34,13 +34,14 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest if err.NoError() { break } - errCh <- E.From(err) + errCh <- err time.Sleep(1 * time.Second) } if err.HasError() { errCh <- E.Failure("connecting to docker") return } + defer cl.Close() cEventCh, cErrCh := cl.Events(ctx, dwOptions) started <- struct{}{} @@ -58,13 +59,16 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest case events.ActionStart: Action = ActionCreated case events.ActionDie: - Action = ActionDeleted + Action = ActionStopped default: // NOTE: should not happen Action = ActionModified } eventCh <- Event{ - ActorName: fmt.Sprintf("container %q", msg.Actor.Attributes["name"]), - Action: Action, + Type: EventTypeDocker, + ActorID: msg.Actor.ID, + ActorAttributes: msg.Actor.Attributes, // labels + ActorName: msg.Actor.Attributes["name"], + Action: Action, } case err := <-cErrCh: if err == nil { diff --git a/src/watcher/event.go b/src/watcher/event.go deleted file mode 100644 index ba27d2e2..00000000 --- a/src/watcher/event.go +++ /dev/null @@ -1,26 +0,0 @@ -package watcher - -import "fmt" - -type ( - Event struct { - ActorName string - Action Action - } - Action string -) - -const ( - ActionModified Action = "MODIFIED" - ActionCreated Action = "CREATED" - ActionStarted Action = "STARTED" - ActionDeleted Action = "DELETED" -) - -func (e Event) String() string { - return fmt.Sprintf("%s %s", e.ActorName, e.Action) -} - -func (a Action) IsDelete() bool { - return a == ActionDeleted -} diff --git a/src/watcher/event/event.go b/src/watcher/event/event.go new file mode 100644 index 00000000..8dc1213f --- /dev/null +++ b/src/watcher/event/event.go @@ -0,0 +1,34 @@ +package event + +import "fmt" + +type ( + Event struct { + Type EventType + ActorName string + ActorID string + ActorAttributes map[string]string + Action Action + } + Action string + EventType string +) + +const ( + ActionModified Action = "modified" + ActionCreated Action = "created" + ActionStarted Action = "started" + ActionDeleted Action = "deleted" + ActionStopped Action = "stopped" + + EventTypeDocker EventType = "docker" + EventTypeFile EventType = "file" +) + +func (e Event) String() string { + return fmt.Sprintf("%s %s", e.ActorName, e.Action) +} + +func (a Action) IsDelete() bool { + return a == ActionDeleted +} diff --git a/src/watcher/file_watcher.go b/src/watcher/file_watcher.go index f4cda622..1da16a24 100644 --- a/src/watcher/file_watcher.go +++ b/src/watcher/file_watcher.go @@ -6,6 +6,7 @@ import ( "github.com/yusing/go-proxy/common" E "github.com/yusing/go-proxy/error" + . "github.com/yusing/go-proxy/watcher/event" ) type fileWatcher struct { diff --git a/src/watcher/file_watcher_helper.go b/src/watcher/file_watcher_helper.go index daaa32fe..98f9802a 100644 --- a/src/watcher/file_watcher_helper.go +++ b/src/watcher/file_watcher_helper.go @@ -9,6 +9,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/error" + . "github.com/yusing/go-proxy/watcher/event" ) type fileWatcherHelper struct { @@ -93,7 +94,10 @@ func (h *fileWatcherHelper) start() { continue } - msg := Event{ActorName: w.filename} + msg := Event{ + Type: EventTypeFile, + ActorName: w.filename, + } switch { case event.Has(fsnotify.Create): msg.Action = ActionCreated diff --git a/src/watcher/watcher.go b/src/watcher/watcher.go index 59c46cd7..9869ee19 100644 --- a/src/watcher/watcher.go +++ b/src/watcher/watcher.go @@ -4,6 +4,7 @@ import ( "context" E "github.com/yusing/go-proxy/error" + . "github.com/yusing/go-proxy/watcher/event" ) type Watcher interface {