diff --git a/Dockerfile b/Dockerfile index f17aa7f6..626785e8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ENV GOCACHE=/root/.cache/go-build WORKDIR /src RUN --mount=type=cache,target="/go/pkg/mod" \ --mount=type=cache,target="/root/.cache/go-build" \ - go mod download + go mod download && \ CGO_ENABLED=0 GOOS=linux go build -pgo=auto -o go-proxy github.com/yusing/go-proxy FROM alpine:3.20 diff --git a/Makefile b/Makefile index 99eb35ca..ed5eae18 100755 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ get: cd src && go get -u && go mod tidy && cd .. debug: - make build && GOPROXY_DEBUG=1 bin/go-proxy + make build && sudo GOPROXY_DEBUG=1 bin/go-proxy archive: git archive HEAD -o ../go-proxy-$$(date +"%Y%m%d%H%M").zip diff --git a/README.md b/README.md index 15fa5f0f..93c3f35c 100755 --- a/README.md +++ b/README.md @@ -26,8 +26,10 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr ## Key Points - Easy to use + - Effortless configuration + - Error messages is clear and detailed - Auto certificate obtaining and renewal (See [Supported DNS Challenge Providers](docs/dns_providers.md)) -- Auto configuration for docker contaienrs +- Auto configuration for docker containers - Auto hot-reload on container state / config file changes - Support HTTP(s), TCP and UDP - Web UI for configuration and monitoring (See [screenshots](https://github.com/yusing/go-proxy-frontend?tab=readme-ov-file#screenshots)) @@ -37,7 +39,7 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr ## Getting Started -1. Setup DNS Records +1. Setup DNS Records, e.g. - A Record: `*.y.z` -> `10.0.10.1` - AAAA Record: `*.y.z` -> `::ffff:a00:a01` @@ -45,18 +47,19 @@ A [lightweight](docs/benchmark_result.md), easy-to-use, and efficient reverse pr 2. Setup `go-proxy` [See here](docs/docker.md) 3. Configure `go-proxy` - - with text editor (i.e. Visual Studio Code) + - with text editor (e.g. Visual Studio Code) - or with web config editor via `http://gp.y.z` [🔼Back to top](#table-of-content) ### Commands line arguments -| Argument | Description | -| ---------- | -------------------------------- | -| empty | start proxy server | -| `validate` | validate config and exit | -| `reload` | trigger a force reload of config | +| Argument | Description | Example | +| ----------- | -------------------------------- | -------------------------- | +| empty | start proxy server | | +| `validate` | validate config and exit | | +| `reload` | trigger a force reload of config | | +| `ls-config` | list config and exit | `go-proxy ls-config \| jq` | **run with `docker exec /app/go-proxy `** diff --git a/docs/add_dns_provider.md b/docs/add_dns_provider.md index 1c049839..31734e10 100644 --- a/docs/add_dns_provider.md +++ b/docs/add_dns_provider.md @@ -7,7 +7,7 @@ ```go var providersGenMap = map[string]ProviderGenerator{ "cloudflare": providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig), - // add here, i.e. + // add here, e.g. "clouddns": providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig), } ``` diff --git a/docs/docker.md b/docs/docker.md index 7ed2244c..d2d8c378 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -172,7 +172,7 @@ service_a: - Container not showing up in proxies list - Please check that either `ports` or label `proxy..port` is declared, i.e. + Please check that either `ports` or label `proxy..port` is declared, e.g. ```yaml services: diff --git a/src/api/v1/file.go b/src/api/v1/file.go index d7b13ea8..6edab084 100644 --- a/src/api/v1/file.go +++ b/src/api/v1/file.go @@ -33,7 +33,7 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { return } content, err := E.Check(io.ReadAll(r.Body)) - if err.IsNotNil() { + if err.HasError() { U.HandleErr(w, r, err) return } @@ -44,13 +44,13 @@ func SetFileContent(w http.ResponseWriter, r *http.Request) { err = provider.Validate(content) } - if err.IsNotNil() { + if err.HasError() { U.HandleErr(w, r, err, http.StatusBadRequest) return } err = E.From(os.WriteFile(path.Join(common.ConfigBasePath, filename), content, 0644)) - if err.IsNotNil() { + if err.HasError() { U.HandleErr(w, r, err) return } diff --git a/src/api/v1/reload.go b/src/api/v1/reload.go index 430e61aa..18011476 100644 --- a/src/api/v1/reload.go +++ b/src/api/v1/reload.go @@ -8,7 +8,7 @@ import ( ) func Reload(cfg *config.Config, w http.ResponseWriter, r *http.Request) { - if err := cfg.Reload(); err.IsNotNil() { + if err := cfg.Reload(); err.HasError() { U.HandleErr(w, r, err) return } diff --git a/src/autocert/config.go b/src/autocert/config.go index 11bd63c7..60076eed 100644 --- a/src/autocert/config.go +++ b/src/autocert/config.go @@ -20,6 +20,9 @@ func NewConfig(cfg *M.AutoCertConfig) *Config { if cfg.KeyPath == "" { cfg.KeyPath = KeyFileDefault } + if cfg.Provider == "" { + cfg.Provider = ProviderLocal + } return (*Config)(cfg) } @@ -36,43 +39,35 @@ func (cfg *Config) GetProvider() (*Provider, E.NestedError) { if cfg.Email == "" { errors.Addf("no email specified") } + // check if provider is implemented + _, ok := providersGenMap[cfg.Provider] + if !ok { + errors.Addf("unknown provider: %q", cfg.Provider) + } } - gen, ok := providersGenMap[cfg.Provider] - if !ok { - errors.Addf("unknown provider: %q", cfg.Provider) - } - if err := errors.Build(); err.IsNotNil() { + if err := errors.Build(); err.HasError() { return nil, err } privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) - if err.IsNotNil() { + if err.HasError() { return nil, E.Failure("generate private key").With(err) } + user := &User{ Email: cfg.Email, key: privKey, } + legoCfg := lego.NewConfig(user) legoCfg.Certificate.KeyType = certcrypto.RSA2048 - legoClient, err := E.Check(lego.NewClient(legoCfg)) - if err.IsNotNil() { - return nil, E.Failure("create lego client").With(err) - } + base := &Provider{ cfg: cfg, user: user, legoCfg: legoCfg, - client: legoClient, - } - legoProvider, err := E.Check(gen(cfg.Options)) - if err.IsNotNil() { - return nil, E.Failure("create lego provider").With(err) - } - err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) - if err.IsNotNil() { - return nil, E.Failure("set challenge provider").With(err) } + return base, E.Nil() } diff --git a/src/autocert/constants.go b/src/autocert/constants.go index 72129bc8..30fd7029 100644 --- a/src/autocert/constants.go +++ b/src/autocert/constants.go @@ -8,9 +8,10 @@ import ( ) const ( - certBasePath = "certs/" - CertFileDefault = certBasePath + "cert.crt" - KeyFileDefault = certBasePath + "priv.key" + certBasePath = "certs/" + CertFileDefault = certBasePath + "cert.crt" + KeyFileDefault = certBasePath + "priv.key" + RegistrationFile = certBasePath + "registration.json" ) const ( @@ -21,11 +22,10 @@ const ( ) var providersGenMap = map[string]ProviderGenerator{ - "": providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig), ProviderLocal: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig), ProviderCloudflare: providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig), ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig), ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig), } -var Logger = logrus.WithField("module", "autocert") +var logger = logrus.WithField("module", "autocert") diff --git a/src/autocert/provider.go b/src/autocert/provider.go index 52128158..f0e70767 100644 --- a/src/autocert/provider.go +++ b/src/autocert/provider.go @@ -5,18 +5,17 @@ import ( "crypto/tls" "crypto/x509" "os" - "slices" - "sync" + "reflect" + "sort" "time" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" - "github.com/sirupsen/logrus" E "github.com/yusing/go-proxy/error" M "github.com/yusing/go-proxy/models" - "github.com/yusing/go-proxy/utils" + U "github.com/yusing/go-proxy/utils" ) type Provider struct { @@ -27,10 +26,9 @@ type Provider struct { tlsCert *tls.Certificate certExpiries CertExpiries - mutex sync.Mutex } -type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, error) +type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, E.NestedError) type CertExpiries map[string]time.Time func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -57,59 +55,72 @@ func (p *Provider) GetExpiries() CertExpiries { } func (p *Provider) ObtainCert() E.NestedError { + if p.cfg.Provider == ProviderLocal { + return E.FailureWhy("obtain cert", "provider is set to \"local\"") + } + + if p.client == nil { + if err := p.initClient(); err.HasError() { + return E.Failure("obtain cert").With(err) + } + } + ne := E.Failure("obtain certificate") client := p.client if p.user.Registration == nil { - reg, err := E.Check(client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) - if err.IsNotNil() { - return ne.With(E.Failure("register account").With(err)) + if err := p.loadRegistration(); err.HasError() { + ne = ne.With(err) + if err := p.registerACME(); err.HasError() { + return ne.With(err) + } } - p.user.Registration = reg } req := certificate.ObtainRequest{ Domains: p.cfg.Domains, Bundle: true, } cert, err := E.Check(client.Certificate.Obtain(req)) - if err.IsNotNil() { + if err.HasError() { return ne.With(err) } err = p.saveCert(cert) - if err.IsNotNil() { + if err.HasError() { return ne.With(E.Failure("save certificate").With(err)) } tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey)) - if err.IsNotNil() { + if err.HasError() { return ne.With(E.Failure("parse obtained certificate").With(err)) } expiries, err := getCertExpiries(&tlsCert) - if err.IsNotNil() { + if err.HasError() { return ne.With(E.Failure("get certificate expiry").With(err)) } p.tlsCert = &tlsCert p.certExpiries = expiries + return E.Nil() } func (p *Provider) LoadCert() E.NestedError { cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)) - if err.IsNotNil() { + if err.HasError() { return err } expiries, err := getCertExpiries(&cert) - if err.IsNotNil() { + if err.HasError() { return err } p.tlsCert = &cert p.certExpiries = expiries - p.renewIfNeeded() - return E.Nil() + + logger.Infof("next renewal in %v", time.Until(p.ShouldRenewOn())) + return p.renewIfNeeded() } func (p *Provider) ShouldRenewOn() time.Time { for _, expiry := range p.certExpiries { - return expiry.AddDate(0, -1, 0) + return expiry.AddDate(0, -1, 0) // 1 month before } // this line should never be reached panic("no certificate available") @@ -120,117 +131,151 @@ func (p *Provider) ScheduleRenewal(ctx context.Context) { return } - logger.Debug("starting renewal scheduler") + logger.Debug("started renewal scheduler") defer logger.Debug("renewal scheduler stopped") - stop := make(chan struct{}) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() for { select { case <-ctx.Done(): return - default: - t := time.Until(p.ShouldRenewOn()) - Logger.Infof("next renewal in %v", t.Round(time.Second)) - go func() { - <-time.After(t) - close(stop) - }() - select { - case <-ctx.Done(): - return - case <-stop: - if err := p.renewIfNeeded(); err.IsNotNil() { - Logger.Fatal(err) - } + case <-ticker.C: // check every 5 seconds + if err := p.renewIfNeeded(); err.HasError() { + logger.Warn(err) } } } } +func (p *Provider) initClient() E.NestedError { + legoClient, err := E.Check(lego.NewClient(p.legoCfg)) + if err.HasError() { + return E.Failure("create lego client").With(err) + } + + legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options) + if err.HasError() { + return E.Failure("create lego provider").With(err) + } + + err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider)) + if err.HasError() { + return E.Failure("set challenge provider").With(err) + } + + p.client = legoClient + return E.Nil() +} + +func (p *Provider) registerACME() E.NestedError { + if p.user.Registration != nil { + return E.Nil() + } + reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})) + if err.HasError() { + return E.Failure("register ACME").With(err) + } + p.user.Registration = reg + + if err := p.saveRegistration(); err.HasError() { + logger.Warn(err) + } + return E.Nil() +} + +func (p *Provider) loadRegistration() E.NestedError { + if p.user.Registration != nil { + return E.Nil() + } + reg := ®istration.Resource{} + err := U.LoadJson(RegistrationFile, reg) + if err.HasError() { + return E.Failure("parse registration file").With(err) + } + p.user.Registration = reg + return E.Nil() +} + +func (p *Provider) saveRegistration() E.NestedError { + return U.SaveJson(RegistrationFile, p.user.Registration, 0o600) +} + func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError { - err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0600) // -rw------- + err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw------- if err != nil { return E.Failure("write key file").With(err) } - err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0644) // -rw-r--r-- + err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r-- if err != nil { return E.Failure("write cert file").With(err) } return E.Nil() } -func (p *Provider) needRenewal() bool { - expired := time.Now().After(p.ShouldRenewOn()) - if expired { - return true +func (p *Provider) certState() CertState { + if time.Now().After(p.ShouldRenewOn()) { + return CertStateExpired } - if len(p.cfg.Domains) != len(p.certExpiries) { - return true - } - wantedDomains := make([]string, len(p.cfg.Domains)) + certDomains := make([]string, len(p.certExpiries)) - copy(wantedDomains, p.cfg.Domains) + wantedDomains := make([]string, len(p.cfg.Domains)) i := 0 for domain := range p.certExpiries { certDomains[i] = domain i++ } - slices.Sort(wantedDomains) - slices.Sort(certDomains) - for i, domain := range certDomains { - if domain != wantedDomains[i] { - return true - } + copy(wantedDomains, p.cfg.Domains) + sort.Strings(wantedDomains) + sort.Strings(certDomains) + + if !reflect.DeepEqual(certDomains, wantedDomains) { + logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains) + return CertStateMismatch } - return false + + return CertStateValid } func (p *Provider) renewIfNeeded() E.NestedError { - if !p.needRenewal() { + switch p.certState() { + case CertStateExpired: + logger.Info("certs expired, renewing") + case CertStateMismatch: + logger.Info("cert domains mismatch with config, renewing") + default: return E.Nil() } - p.mutex.Lock() - defer p.mutex.Unlock() - - if !p.needRenewal() { - return E.Nil() - } - - trials := 0 - for { - err := p.ObtainCert() - if err.IsNotNil() { - return E.Nil() - } - trials++ - if trials > 3 { - return E.Failure("renew certificate").With(err) - } - time.Sleep(5 * time.Second) + if err := p.ObtainCert(); err.HasError() { + return E.Failure("renew certificate").With(err) } + return E.Nil() } func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) { r := make(CertExpiries, len(cert.Certificate)) for _, cert := range cert.Certificate { x509Cert, err := E.Check(x509.ParseCertificate(cert)) - if err.IsNotNil() { + if err.HasError() { return nil, E.Failure("parse certificate").With(err) } if x509Cert.IsCA { continue } r[x509Cert.Subject.CommonName] = x509Cert.NotAfter + for i := range x509Cert.DNSNames { + r[x509Cert.DNSNames[i]] = x509Cert.NotAfter + } } return r, E.Nil() } func setOptions[T interface{}](cfg *T, opt M.AutocertProviderOpt) E.NestedError { for k, v := range opt { - err := utils.SetFieldFromSnake(cfg, k, v) - if err.IsNotNil() { + err := U.SetFieldFromSnake(cfg, k, v) + if err.HasError() { return E.Failure("set autocert option").Subject(k).With(err) } } @@ -241,18 +286,16 @@ func providerGenerator[CT any, PT challenge.Provider]( defaultCfg func() *CT, newProvider func(*CT) (PT, error), ) ProviderGenerator { - return func(opt M.AutocertProviderOpt) (challenge.Provider, error) { + return func(opt M.AutocertProviderOpt) (challenge.Provider, E.NestedError) { cfg := defaultCfg() err := setOptions(cfg, opt) - if err.IsNotNil() { + if err.HasError() { return nil, err } p, err := E.Check(newProvider(cfg)) - if err.IsNotNil() { + if err.HasError() { return nil, err } - return p, nil + return p, E.Nil() } } - -var logger = logrus.WithField("module", "autocert") diff --git a/src/autocert/state.go b/src/autocert/state.go new file mode 100644 index 00000000..ffe308d6 --- /dev/null +++ b/src/autocert/state.go @@ -0,0 +1,9 @@ +package autocert + +type CertState int + +const ( + CertStateValid CertState = 0 + CertStateExpired CertState = iota + CertStateMismatch CertState = iota +) diff --git a/src/common/args.go b/src/common/args.go index 1d96f69d..b91a857b 100644 --- a/src/common/args.go +++ b/src/common/args.go @@ -12,18 +12,19 @@ type Args struct { } const ( - CommandStart = "" - CommandValidate = "validate" - CommandReload = "reload" + CommandStart = "" + CommandValidate = "validate" + CommandListConfigs = "ls-config" + CommandReload = "reload" ) -var ValidCommands = []string{CommandStart, CommandValidate, CommandReload} +var ValidCommands = []string{CommandStart, CommandValidate, CommandListConfigs, CommandReload} func GetArgs() Args { var args Args flag.Parse() args.Command = flag.Arg(0) - if err := validateArgs(args.Command, ValidCommands); err.IsNotNil() { + if err := validateArgs(args.Command, ValidCommands); err.HasError() { logrus.Fatal(err) } return args diff --git a/src/common/env.go b/src/common/env.go index 15326e8c..ac5f8e2c 100644 --- a/src/common/env.go +++ b/src/common/env.go @@ -3,21 +3,16 @@ package common import ( "os" "strings" - - "github.com/sirupsen/logrus" ) var NoSchemaValidation = getEnvBool("GOPROXY_NO_SCHEMA_VALIDATION") var IsDebug = getEnvBool("GOPROXY_DEBUG") -var LogLevel = func() logrus.Level { - if IsDebug { - logrus.SetLevel(logrus.DebugLevel) - } - return logrus.GetLevel() -}() - func getEnvBool(key string) bool { - v := os.Getenv(key) - return v == "1" || strings.ToLower(v) == "true" || strings.ToLower(v) == "yes" || strings.ToLower(v) == "on" + switch strings.ToLower(os.Getenv(key)) { + case "1", "true", "yes", "on": + return true + default: + return false + } } diff --git a/src/config/config.go b/src/config/config.go index f1b73149..3b28f41f 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -37,7 +37,7 @@ func New() (*Config, E.NestedError) { watcher: W.NewFileWatcher(common.ConfigFileName), reloadReq: make(chan struct{}, 1), } - if err := cfg.load(); err.IsNotNil() { + if err := cfg.load(); err.HasError() { return nil, err } cfg.startProviders() @@ -66,7 +66,7 @@ func (cfg *Config) Dispose() { func (cfg *Config) Reload() E.NestedError { cfg.stopProviders() - if err := cfg.load(); err.IsNotNil() { + if err := cfg.load(); err.HasError() { return err } cfg.startProviders() @@ -156,7 +156,7 @@ func (cfg *Config) watchChanges() { case <-cfg.watcherCtx.Done(): return case <-cfg.reloadReq: - if err := cfg.Reload(); err.IsNotNil() { + if err := cfg.Reload(); err.HasError() { cfg.l.Error(err) } } @@ -186,29 +186,29 @@ func (cfg *Config) load() E.NestedError { cfg.l.Debug("loading config") data, err := cfg.reader.Read() - if err.IsNotNil() { + if err.HasError() { return E.Failure("read config").With(err) } model := M.DefaultConfig() - if err := E.From(yaml.Unmarshal(data, model)); err.IsNotNil() { + if err := E.From(yaml.Unmarshal(data, model)); err.HasError() { return E.Failure("parse config").With(err) } if !common.NoSchemaValidation { - if err = Validate(data); err.IsNotNil() { + if err = Validate(data); err.HasError() { return err } } warnings := E.NewBuilder("errors loading config") - cfg.l.Debug("starting autocert") + cfg.l.Debug("initializing autocert") ap, err := autocert.NewConfig(&model.AutoCert).GetProvider() - if err.IsNotNil() { + if err.HasError() { warnings.Add(E.Failure("autocert provider").With(err)) } else { - cfg.l.Debug("started autocert") + cfg.l.Debug("initialized autocert") } cfg.autocertProvider = ap @@ -226,7 +226,7 @@ func (cfg *Config) load() E.NestedError { cfg.value = model - if err := warnings.Build(); err.IsNotNil() { + if err := warnings.Build(); err.HasError() { cfg.l.Warn(err) } @@ -238,12 +238,12 @@ func (cfg *Config) controlProviders(action string, do func(*PR.Provider) E.Neste errors := E.NewBuilder("cannot %s these providers", action) cfg.proxyProviders.EachKVParallel(func(name string, p *PR.Provider) { - if err := do(p); err.IsNotNil() { + if err := do(p); err.HasError() { errors.Add(E.From(err).Subject(p)) } }) - if err := errors.Build(); err.IsNotNil() { + if err := errors.Build(); err.HasError() { cfg.l.Error(err) } } diff --git a/src/docker/client.go b/src/docker/client.go index 4da34739..70ff22d9 100644 --- a/src/docker/client.go +++ b/src/docker/client.go @@ -40,7 +40,7 @@ func ConnectClient(host string) (Client, E.NestedError) { opt = clientOptEnvHost default: helper, err := E.Check(connhelper.GetConnectionHelper(host)) - if err.IsNotNil() { + if err.HasError() { logger.Fatalf("unexpected error: %s", err) } if helper != nil { @@ -65,7 +65,7 @@ func ConnectClient(host string) (Client, E.NestedError) { client, err := E.Check(client.NewClientWithOpts(opt...)) - if err.IsNotNil() { + if err.HasError() { return nil, err } diff --git a/src/docker/client_info.go b/src/docker/client_info.go index c6de93eb..20253a0e 100644 --- a/src/docker/client_info.go +++ b/src/docker/client_info.go @@ -18,7 +18,7 @@ type ClientInfo struct { func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) { dockerClient, err := ConnectClient(clientHost) - if err.IsNotNil() { + if err.HasError() { return nil, E.Failure("create docker client").With(err) } @@ -26,7 +26,7 @@ func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) { defer cancel() containers, err := E.Check(dockerClient.ContainerList(ctx, container.ListOptions{})) - if err.IsNotNil() { + if err.HasError() { return nil, E.Failure("list containers").With(err) } @@ -34,7 +34,7 @@ func GetClientInfo(clientHost string) (*ClientInfo, E.NestedError) { // since the services being proxied to // should have the same IP as the docker client url, err := E.Check(client.ParseHostURL(dockerClient.DaemonHost())) - if err.IsNotNil() { + if err.HasError() { return nil, E.Invalid("host url", dockerClient.DaemonHost()).With(err) } if url.Scheme == "unix" { diff --git a/src/docker/label.go b/src/docker/label.go index 1468226f..1697f99c 100644 --- a/src/docker/label.go +++ b/src/docker/label.go @@ -63,7 +63,7 @@ func ParseLabel(label string, value string) (*Label, E.NestedError) { } // try to parse value v, err := p(value) - if err.IsNotNil() { + if err.HasError() { return nil, err } l.Value = v diff --git a/src/error/error.go b/src/error/error.go index 4acc6067..6cb0fdf0 100644 --- a/src/error/error.go +++ b/src/error/error.go @@ -118,11 +118,11 @@ func (ne NestedError) Subjectf(format string, args ...any) NestedError { return ne } -func (ne NestedError) IsNil() bool { +func (ne NestedError) NoError() bool { return ne.err == nil } -func (ne NestedError) IsNotNil() bool { +func (ne NestedError) HasError() bool { return ne.err != nil } @@ -139,7 +139,7 @@ func (ne *NestedError) writeToSB(sb *strings.Builder, level int, prefix string) ne.writeIndents(sb, level) sb.WriteString(prefix) - if ne.IsNil() { + if ne.NoError() { sb.WriteString("nil") return } diff --git a/src/error/error_test.go b/src/error/error_test.go index 582ff483..e726abee 100644 --- a/src/error/error_test.go +++ b/src/error/error_test.go @@ -22,8 +22,8 @@ func TestErrorIs(t *testing.T) { } func TestNil(t *testing.T) { - ExpectTrue(t, Nil().IsNil()) - ExpectFalse(t, Nil().IsNotNil()) + ExpectTrue(t, Nil().NoError()) + ExpectFalse(t, Nil().HasError()) ExpectEqual(t, Nil().Error(), "nil") } diff --git a/src/main.go b/src/main.go index ee7fd930..ecb1c0cf 100755 --- a/src/main.go +++ b/src/main.go @@ -2,6 +2,8 @@ package main import ( "context" + "encoding/json" + "log" "net/http" "os" "os/signal" @@ -33,14 +35,15 @@ func main() { } logrus.SetFormatter(&logrus.TextFormatter{ - DisableSorting: true, - FullTimestamp: true, - ForceColors: true, - TimestampFormat: "01-02 15:04:05", + DisableSorting: true, + DisableLevelTruncation: true, + FullTimestamp: true, + ForceColors: true, + TimestampFormat: "01-02 15:04:05", }) if args.Command == common.CommandReload { - if err := apiUtils.ReloadServer(); err.IsNotNil() { + if err := apiUtils.ReloadServer(); err.HasError() { l.Fatal(err) } return @@ -52,10 +55,10 @@ func main() { if args.Command == common.CommandValidate { var err E.NestedError data, err := E.Check(os.ReadFile(common.ConfigPath)) - if err.IsNotNil() { + if err.HasError() { l.WithError(err).Fatalf("config error") } - if err = config.Validate(data); err.IsNotNil() { + if err = config.Validate(data); err.HasError() { l.WithError(err).Fatalf("config error") } l.Printf("config OK") @@ -63,10 +66,20 @@ func main() { } cfg, err := config.New() - if err.IsNotNil() { + if err.HasError() { l.Fatalf("config error: %s", err) } + if args.Command == common.CommandListConfigs { + yml, err := E.Check(json.Marshal(cfg.Value())) + if err.HasError() { + panic(err) + } + rawLogger := log.New(os.Stdout, "", 0) + rawLogger.Printf("%s", yml) // raw output for convenience using "jq" + return + } + onShutdown.Add(func() { docker.CloseAllClients() cfg.Dispose() @@ -80,23 +93,27 @@ func main() { autocert := cfg.GetAutoCertProvider() if autocert != nil { - err = autocert.LoadCert() - - if err.IsNotNil() { - l.Error(err) - l.Info("Now attempting to obtain a new certificate...") - if err = autocert.ObtainCert(); err.IsNotNil() { - ctx, certRenewalCancel := context.WithCancel(context.Background()) - go autocert.ScheduleRenewal(ctx) - onShutdown.Add(certRenewalCancel) - } else { + if err = autocert.LoadCert(); err.HasError() { + if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist + l.Error(err) + } + l.Debug("obtaining cert due to error loading cert") + if err = autocert.ObtainCert(); err.HasError() { l.Warn(err) } - } else { - for name, expiry := range autocert.GetExpiries() { - l.Infof("certificate %q: expire on %s", name, expiry) - } } + + if err.NoError() { + ctx, certRenewalCancel := context.WithCancel(context.Background()) + go autocert.ScheduleRenewal(ctx) + onShutdown.Add(certRenewalCancel) + } + + for name, expiry := range autocert.GetExpiries() { + l.Infof("certificate %q: expire on %s", name, expiry) + } + } else { + l.Info("autocert not configured") } proxyServer := server.InitProxyServer(server.Options{ diff --git a/src/proxy/entry.go b/src/proxy/entry.go index 08359c75..dc625cc0 100644 --- a/src/proxy/entry.go +++ b/src/proxy/entry.go @@ -33,7 +33,7 @@ type ( func NewEntry(m *M.ProxyEntry) (any, E.NestedError) { m.SetDefaults() scheme, err := T.NewScheme(m.Scheme) - if err.IsNotNil() { + if err.HasError() { return nil, err } if scheme.IsStream() { @@ -44,23 +44,23 @@ func NewEntry(m *M.ProxyEntry) (any, E.NestedError) { func validateEntry(m *M.ProxyEntry, s T.Scheme) (*Entry, E.NestedError) { host, err := T.NewHost(m.Host) - if err.IsNotNil() { + if err.HasError() { return nil, err } port, err := T.NewPort(m.Port) - if err.IsNotNil() { + if err.HasError() { return nil, err } pathPatterns, err := T.NewPathPatterns(m.PathPatterns) - if err.IsNotNil() { + if err.HasError() { return nil, err } setHeaders, err := T.NewHTTPHeaders(m.SetHeaders) - if err.IsNotNil() { + if err.HasError() { return nil, err } url, err := E.Check(url.Parse(fmt.Sprintf("%s://%s:%d", s, host, port))) - if err.IsNotNil() { + if err.HasError() { return nil, err } return &Entry{ @@ -78,15 +78,15 @@ func validateEntry(m *M.ProxyEntry, s T.Scheme) (*Entry, E.NestedError) { func validateStreamEntry(m *M.ProxyEntry) (*StreamEntry, E.NestedError) { host, err := T.NewHost(m.Host) - if err.IsNotNil() { + if err.HasError() { return nil, err } port, err := T.NewStreamPort(m.Port) - if err.IsNotNil() { + if err.HasError() { return nil, err } scheme, err := T.NewStreamScheme(m.Scheme) - if err.IsNotNil() { + if err.HasError() { return nil, err } return &StreamEntry{ diff --git a/src/proxy/fields/path_pattern.go b/src/proxy/fields/path_pattern.go index 114e9853..4d68ec81 100644 --- a/src/proxy/fields/path_pattern.go +++ b/src/proxy/fields/path_pattern.go @@ -25,7 +25,7 @@ func NewPathPatterns(s []string) (PathPatterns, E.NestedError) { } pp := make(PathPatterns, len(s)) for i, v := range s { - if pattern, err := NewPathPattern(v); err.IsNotNil() { + if pattern, err := NewPathPattern(v); err.HasError() { return nil, err } else { pp[i] = pattern diff --git a/src/proxy/fields/port.go b/src/proxy/fields/port.go index 5da17889..783017e0 100644 --- a/src/proxy/fields/port.go +++ b/src/proxy/fields/port.go @@ -18,7 +18,7 @@ func NewPort(v string) (Port, E.NestedError) { func NewPortInt[Int int | uint16](v Int) (Port, E.NestedError) { pp := Port(v) - if err := pp.boundCheck(); err.IsNotNil() { + if err := pp.boundCheck(); err.HasError() { return ErrPort, err } return pp, E.Nil() diff --git a/src/proxy/fields/stream_port.go b/src/proxy/fields/stream_port.go index 20a58ed5..fdec1f04 100644 --- a/src/proxy/fields/stream_port.go +++ b/src/proxy/fields/stream_port.go @@ -19,21 +19,21 @@ func NewStreamPort(p string) (StreamPort, E.NestedError) { } listeningPort, err := NewPort(split[0]) - if err.IsNotNil() { + if err.HasError() { return StreamPort{}, err } - if err = listeningPort.boundCheck(); err.IsNotNil() { + if err = listeningPort.boundCheck(); err.HasError() { return StreamPort{}, err } proxyPort, err := NewPort(split[1]) - if err.IsNotNil() { + if err.HasError() { proxyPort, err = parseNameToPort(split[1]) - if err.IsNotNil() { + if err.HasError() { return StreamPort{}, err } } - if err = proxyPort.boundCheck(); err.IsNotNil() { + if err = proxyPort.boundCheck(); err.HasError() { return StreamPort{}, err } diff --git a/src/proxy/fields/stream_scheme.go b/src/proxy/fields/stream_scheme.go index adc214d6..6b88a044 100644 --- a/src/proxy/fields/stream_scheme.go +++ b/src/proxy/fields/stream_scheme.go @@ -21,11 +21,11 @@ func NewStreamScheme(s string) (ss *StreamScheme, err E.NestedError) { return nil, E.Invalid("stream scheme", s) } ss.ListeningScheme, err = NewScheme(parts[0]) - if err.IsNotNil() { + if err.HasError() { return nil, err } ss.ProxyScheme, err = NewScheme(parts[1]) - if err.IsNotNil() { + if err.HasError() { return nil, err } return ss, E.Nil() diff --git a/src/proxy/provider/docker_provider.go b/src/proxy/provider/docker_provider.go index a68e6277..8893e658 100755 --- a/src/proxy/provider/docker_provider.go +++ b/src/proxy/provider/docker_provider.go @@ -39,7 +39,7 @@ func (p DockerProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) { entries := M.NewProxyEntries() info, err := D.GetClientInfo(p.dockerHost) - if err.IsNotNil() { + if err.HasError() { return entries, err } @@ -47,7 +47,7 @@ func (p DockerProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) { for _, container := range info.Containers { en, err := p.getEntriesFromLabels(&container, info.Host) - if err.IsNotNil() { + if err.HasError() { errors.Add(err) } // although err is not nil @@ -95,7 +95,7 @@ func (p *DockerProvider) getEntriesFromLabels(container *types.Container, client // find first port, return if no port exposed defaultPort, err := findFirstPort(container) - if err.IsNotNil() { + if err.HasError() { logrus.Debug(mainAlias, " ", err.Error()) } @@ -111,7 +111,7 @@ func (p *DockerProvider) getEntriesFromLabels(container *types.Container, client errors := E.NewBuilder("failed to apply label for %q", mainAlias) for key, val := range container.Labels { lbl, err := D.ParseLabel(key, val) - if err.IsNotNil() { + if err.HasError() { errors.Add(E.From(err).Subject(key)) continue } @@ -121,7 +121,7 @@ func (p *DockerProvider) getEntriesFromLabels(container *types.Container, client if lbl.Target == wildcardAlias { // apply label for all aliases entries.EachKV(func(a string, e *M.ProxyEntry) { - if err = D.ApplyLabel(e, lbl); err.IsNotNil() { + if err = D.ApplyLabel(e, lbl); err.HasError() { errors.Add(E.From(err).Subject(lbl.Target)) } }) @@ -131,7 +131,7 @@ func (p *DockerProvider) getEntriesFromLabels(container *types.Container, client errors.Add(E.NotExists("alias", lbl.Target)) continue } - if err = D.ApplyLabel(config, lbl); err.IsNotNil() { + if err = D.ApplyLabel(config, lbl); err.HasError() { errors.Add(err.Subject(lbl.Target)) } } diff --git a/src/proxy/provider/file_provider.go b/src/proxy/provider/file_provider.go index 0b3f289b..b69df6ed 100644 --- a/src/proxy/provider/file_provider.go +++ b/src/proxy/provider/file_provider.go @@ -34,16 +34,16 @@ func (p *FileProvider) String() string { func (p *FileProvider) GetProxyEntries() (M.ProxyEntries, E.NestedError) { entries := M.NewProxyEntries() data, err := E.Check(os.ReadFile(p.path)) - if err.IsNotNil() { + if err.HasError() { return entries, E.Failure("read file").Subject(p).With(err) } ne := E.Failure("validation").Subject(p) if !common.NoSchemaValidation { - if err = Validate(data); err.IsNotNil() { + if err = Validate(data); err.HasError() { return entries, ne.With(err) } } - if err = entries.UnmarshalFromYAML(data); err.IsNotNil() { + if err = entries.UnmarshalFromYAML(data); err.HasError() { return entries, ne.With(err) } return entries, E.Nil() diff --git a/src/proxy/provider/provider.go b/src/proxy/provider/provider.go index a9ad8f39..52366cfd 100644 --- a/src/proxy/provider/provider.go +++ b/src/proxy/provider/provider.go @@ -92,12 +92,12 @@ func (p *Provider) StartAllRoutes() E.NestedError { nStarted := 0 nFailed := 0 - if err.IsNotNil() { + if err.HasError() { errors.Add(err) } p.routes.EachKVParallel(func(alias string, r R.Route) { - if err := r.Start(); err.IsNotNil() { + if err := r.Start(); err.HasError() { errors.Add(err.Subject(r)) nFailed++ } else { @@ -118,7 +118,7 @@ func (p *Provider) StopAllRoutes() E.NestedError { nStopped := 0 nFailed := 0 p.routes.EachKVParallel(func(alias string, r R.Route) { - if err := r.Stop(); err.IsNotNil() { + if err := r.Stop(); err.HasError() { errors.Add(err.Subject(r)) nFailed++ } else { @@ -195,7 +195,7 @@ func (p *Provider) processReloadRequests() { func (p *Provider) loadRoutes() E.NestedError { entries, err := p.GetProxyEntries() - if err.IsNotNil() { + if err.HasError() { p.l.Warn(err.Subject(p)) } p.routes = R.NewRoutes() @@ -204,7 +204,7 @@ func (p *Provider) loadRoutes() E.NestedError { entries.EachKV(func(a string, e *M.ProxyEntry) { e.Alias = a r, err := R.NewRoute(e) - if err.IsNotNil() { + if err.HasError() { errors.Add(err.Subject(a)) } else { p.routes.Set(a, r) diff --git a/src/route/route.go b/src/route/route.go index bf448aeb..5c2d60da 100755 --- a/src/route/route.go +++ b/src/route/route.go @@ -21,7 +21,7 @@ var NewRoutes = F.NewMap[string, Route] func NewRoute(en *M.ProxyEntry) (Route, E.NestedError) { entry, err := P.NewEntry(en) - if err.IsNotNil() { + if err.HasError() { return nil, err } switch e := entry.(type) { diff --git a/src/route/tcp_route.go b/src/route/tcp_route.go index 0f25bd2d..8bcfba8e 100755 --- a/src/route/tcp_route.go +++ b/src/route/tcp_route.go @@ -78,7 +78,7 @@ func (route *TCPRoute) CloseListeners() { route.listener.Close() route.listener = nil for _, pipe := range route.pipe { - if err := pipe.Stop(); err.IsNotNil() { + if err := pipe.Stop(); err.HasError() { route.l.Error(err) } } diff --git a/src/utils/io.go b/src/utils/io.go index 3df7184f..988a1e21 100644 --- a/src/utils/io.go +++ b/src/utils/io.go @@ -2,6 +2,7 @@ package utils import ( "context" + "encoding/json" "io" "os" "sync/atomic" @@ -135,7 +136,7 @@ func (p *BidirectionalPipe) Start() E.NestedError { errCh <- p.pDstSrc.Start() }() for err := range errCh { - if err.IsNotNil() { + if err.HasError() { return err } } @@ -149,4 +150,20 @@ func (p *BidirectionalPipe) Stop() E.NestedError { func Copy(ctx context.Context, dst io.WriteCloser, src io.ReadCloser) E.NestedError { _, err := io.Copy(dst, StdReadCloser{&ReadCloser{ctx: ctx, r: src}}) return E.From(err) -} \ No newline at end of file +} + +func LoadJson[T any](path string, pointer *T) E.NestedError { + data, err := os.ReadFile(path) + if err != nil { + return E.From(err) + } + return E.From(json.Unmarshal(data, pointer)) +} + +func SaveJson[T any](path string, pointer *T, perm os.FileMode) E.NestedError { + data, err := json.Marshal(pointer) + if err != nil { + return E.From(err) + } + return E.From(os.WriteFile(path, data, perm)) +} diff --git a/src/utils/testing.go b/src/utils/testing.go index b33f4a73..c51a8eef 100644 --- a/src/utils/testing.go +++ b/src/utils/testing.go @@ -9,7 +9,7 @@ import ( func ExpectErrNil(t *testing.T, err E.NestedError) { t.Helper() - if err.IsNotNil() { + if err.HasError() { t.Errorf("expected err=nil, got %s", err.Error()) } } diff --git a/src/watcher/docker_watcher.go b/src/watcher/docker_watcher.go index 3f1484eb..f87ae720 100644 --- a/src/watcher/docker_watcher.go +++ b/src/watcher/docker_watcher.go @@ -31,13 +31,13 @@ func (w *DockerWatcher) Events(ctx context.Context) (<-chan Event, <-chan E.Nest var err E.NestedError for range 3 { cl, err = D.ConnectClient(w.host) - if err.IsNil() { + if err.NoError() { break } errCh <- E.From(err) time.Sleep(1 * time.Second) } - if err.IsNotNil() { + if err.HasError() { errCh <- E.Failure("connecting to docker") return }