diff --git a/internal/api/v1/cert/renew.go b/internal/api/v1/cert/renew.go index 816e458a..d81e2562 100644 --- a/internal/api/v1/cert/renew.go +++ b/internal/api/v1/cert/renew.go @@ -9,7 +9,6 @@ import ( "github.com/yusing/godoxy/internal/autocert" "github.com/yusing/godoxy/internal/logging/memlogger" apitypes "github.com/yusing/goutils/apitypes" - gperr "github.com/yusing/goutils/errs" "github.com/yusing/goutils/http/websocket" ) @@ -40,33 +39,33 @@ func Renew(c *gin.Context) { logs, cancel := memlogger.Events() defer cancel() - done := make(chan struct{}) - go func() { - defer close(done) + // Stream logs until WebSocket connection closes (renewal runs in background) + for { + select { + case <-manager.Context().Done(): + return + case l := <-logs: + if err != nil { + return + } - err = autocert.ObtainCert() - if err != nil { - gperr.LogError("failed to obtain cert", err) - _ = manager.WriteData(websocket.TextMessage, []byte(err.Error()), 10*time.Second) - } else { - log.Info().Msg("cert obtained successfully") + err = manager.WriteData(websocket.TextMessage, l, 10*time.Second) + if err != nil { + return + } + } } }() - for { - select { - case l := <-logs: - if err != nil { - return - } - - err = manager.WriteData(websocket.TextMessage, l, 10*time.Second) - if err != nil { - return - } - case <-done: - return - } + // renewal happens in background + ok := autocert.ForceExpiryAll() + if !ok { + log.Error().Msg("cert renewal already in progress") + time.Sleep(1 * time.Second) // wait for the log above to be sent + return } + log.Info().Msg("cert force renewal requested") + + autocert.WaitRenewalDone(manager.Context()) } diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 43e2aa70..b1f035e3 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -5,6 +5,7 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/x509" + "fmt" "net/http" "os" "regexp" @@ -19,13 +20,14 @@ import ( strutils "github.com/yusing/goutils/strings" ) +type ConfigExtra Config type Config struct { Email string `json:"email,omitempty"` Domains []string `json:"domains,omitempty"` CertPath string `json:"cert_path,omitempty"` KeyPath string `json:"key_path,omitempty"` - Extra []Config `json:"extra,omitempty"` - ACMEKeyPath string `json:"acme_key_path,omitempty"` + Extra []ConfigExtra `json:"extra,omitempty"` + ACMEKeyPath string `json:"acme_key_path,omitempty"` // shared by all extra providers Provider string `json:"provider,omitempty"` Options map[string]strutils.Redacted `json:"options,omitempty"` @@ -42,15 +44,12 @@ type Config struct { HTTPClient *http.Client `json:"-"` // for tests only challengeProvider challenge.Provider + + idx int // 0: main, 1+: extra[i] } var ( - ErrMissingDomain = gperr.New("missing field 'domains'") - ErrMissingEmail = gperr.New("missing field 'email'") - ErrMissingProvider = gperr.New("missing field 'provider'") - ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'") - ErrMissingCertPath = gperr.New("missing field 'cert_path'") - ErrMissingKeyPath = gperr.New("missing field 'key_path'") + ErrMissingField = gperr.New("missing field") ErrDuplicatedPath = gperr.New("duplicated path") ErrInvalidDomain = gperr.New("invalid domain") ErrUnknownProvider = gperr.New("unknown provider") @@ -66,95 +65,22 @@ var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`) // Validate implements the utils.CustomValidator interface. func (cfg *Config) Validate() gperr.Error { - if cfg == nil { - return nil - } + seenPaths := make(map[string]int) // path -> provider idx (0 for main, 1+ for extras) + return cfg.validate(seenPaths) +} +func (cfg *ConfigExtra) Validate() gperr.Error { + return nil // done by main config's validate +} + +func (cfg *ConfigExtra) AsConfig() *Config { + return (*Config)(cfg) +} + +func (cfg *Config) validate(seenPaths map[string]int) gperr.Error { if cfg.Provider == "" { cfg.Provider = ProviderLocal } - - b := gperr.NewBuilder("autocert errors") - if len(cfg.Extra) > 0 { - seenCertPaths := make(map[string]int, len(cfg.Extra)) - seenKeyPaths := make(map[string]int, len(cfg.Extra)) - for i := range cfg.Extra { - if cfg.Extra[i].CertPath == "" { - b.Add(ErrMissingCertPath.Subjectf("extra[%d].cert_path", i)) - } - if cfg.Extra[i].KeyPath == "" { - b.Add(ErrMissingKeyPath.Subjectf("extra[%d].key_path", i)) - } - if cfg.Extra[i].CertPath != "" { - if first, ok := seenCertPaths[cfg.Extra[i].CertPath]; ok { - b.Add(ErrDuplicatedPath.Subjectf("extra[%d].cert_path", i).Withf("first: %d", first)) - } else { - seenCertPaths[cfg.Extra[i].CertPath] = i - } - } - if cfg.Extra[i].KeyPath != "" { - if first, ok := seenKeyPaths[cfg.Extra[i].KeyPath]; ok { - b.Add(ErrDuplicatedPath.Subjectf("extra[%d].key_path", i).Withf("first: %d", first)) - } else { - seenKeyPaths[cfg.Extra[i].KeyPath] = i - } - } - } - } - - if cfg.Provider == ProviderCustom && cfg.CADirURL == "" { - b.Add(ErrMissingCADirURL) - } - - if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { - if len(cfg.Domains) == 0 { - b.Add(ErrMissingDomain) - } - if cfg.Email == "" { - b.Add(ErrMissingEmail) - } - if cfg.Provider != ProviderCustom { - for i, d := range cfg.Domains { - if !domainOrWildcardRE.MatchString(d) { - b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) - } - } - } - // check if provider is implemented - providerConstructor, ok := Providers[cfg.Provider] - if !ok { - if cfg.Provider != ProviderCustom { - b.Add(ErrUnknownProvider. - Subject(cfg.Provider). - With(gperr.DoYouMeanField(cfg.Provider, Providers))) - } - } else { - provider, err := providerConstructor(cfg.Options) - if err != nil { - b.Add(err) - } else { - cfg.challengeProvider = provider - } - } - } - - if cfg.challengeProvider == nil { - cfg.challengeProvider, _ = Providers[ProviderLocal](nil) - } - return b.Error() -} - -func (cfg *Config) dns01Options() []dns01.ChallengeOption { - return []dns01.ChallengeOption{ - dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)), - } -} - -func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) { - if err := cfg.Validate(); err != nil { - return nil, nil, err - } - if cfg.CertPath == "" { cfg.CertPath = CertFileDefault } @@ -165,6 +91,83 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) { cfg.ACMEKeyPath = ACMEKeyFileDefault } + b := gperr.NewBuilder("certificate error") + + // check if cert_path is unique + if first, ok := seenPaths[cfg.CertPath]; ok { + b.Add(ErrDuplicatedPath.Subjectf("cert_path %s", cfg.CertPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first))) + } else { + seenPaths[cfg.CertPath] = cfg.idx + } + + // check if key_path is unique + if first, ok := seenPaths[cfg.KeyPath]; ok { + b.Add(ErrDuplicatedPath.Subjectf("key_path %s", cfg.KeyPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first))) + } else { + seenPaths[cfg.KeyPath] = cfg.idx + } + + if cfg.Provider == ProviderCustom && cfg.CADirURL == "" { + b.Add(ErrMissingField.Subject("ca_dir_url")) + } + + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { + if len(cfg.Domains) == 0 { + b.Add(ErrMissingField.Subject("domains")) + } + if cfg.Email == "" { + b.Add(ErrMissingField.Subject("email")) + } + if cfg.Provider != ProviderCustom { + for i, d := range cfg.Domains { + if !domainOrWildcardRE.MatchString(d) { + b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) + } + } + } + } + + // check if provider is implemented + providerConstructor, ok := Providers[cfg.Provider] + if !ok { + if cfg.Provider != ProviderCustom { + b.Add(ErrUnknownProvider. + Subject(cfg.Provider). + With(gperr.DoYouMeanField(cfg.Provider, Providers))) + } + } else { + provider, err := providerConstructor(cfg.Options) + if err != nil { + b.Add(err) + } else { + cfg.challengeProvider = provider + } + } + + if cfg.challengeProvider == nil { + cfg.challengeProvider, _ = Providers[ProviderLocal](nil) + } + + if len(cfg.Extra) > 0 { + for i := range cfg.Extra { + cfg.Extra[i] = MergeExtraConfig(cfg, &cfg.Extra[i]) + cfg.Extra[i].AsConfig().idx = i + 1 + err := cfg.Extra[i].AsConfig().validate(seenPaths) + if err != nil { + b.Add(err.Subjectf("extra[%d]", i)) + } + } + } + return b.Error() +} + +func (cfg *Config) dns01Options() []dns01.ChallengeOption { + return []dns01.ChallengeOption{ + dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)), + } +} + +func (cfg *Config) GetLegoConfig() (*User, *lego.Config, error) { var privKey *ecdsa.PrivateKey var err error @@ -208,6 +211,46 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) { return user, legoCfg, nil } +func MergeExtraConfig(mainCfg *Config, extraCfg *ConfigExtra) ConfigExtra { + merged := ConfigExtra(*mainCfg) + merged.Extra = nil + merged.CertPath = extraCfg.CertPath + merged.KeyPath = extraCfg.KeyPath + // NOTE: Using same ACME key as main provider + + if extraCfg.Provider != "" { + merged.Provider = extraCfg.Provider + } + if extraCfg.Email != "" { + merged.Email = extraCfg.Email + } + if len(extraCfg.Domains) > 0 { + merged.Domains = extraCfg.Domains + } + if len(extraCfg.Options) > 0 { + merged.Options = extraCfg.Options + } + if len(extraCfg.Resolvers) > 0 { + merged.Resolvers = extraCfg.Resolvers + } + if extraCfg.CADirURL != "" { + merged.CADirURL = extraCfg.CADirURL + } + if len(extraCfg.CACerts) > 0 { + merged.CACerts = extraCfg.CACerts + } + if extraCfg.EABKid != "" { + merged.EABKid = extraCfg.EABKid + } + if extraCfg.EABHmac != "" { + merged.EABHmac = extraCfg.EABHmac + } + if extraCfg.HTTPClient != nil { + merged.HTTPClient = extraCfg.HTTPClient + } + return merged +} + func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) { if common.IsTest { return nil, os.ErrNotExist diff --git a/internal/autocert/config_test.go b/internal/autocert/config_test.go index 58782366..6bb53de1 100644 --- a/internal/autocert/config_test.go +++ b/internal/autocert/config_test.go @@ -1,27 +1,32 @@ -package autocert +package autocert_test import ( "fmt" "testing" + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" + "github.com/yusing/godoxy/internal/dnsproviders" "github.com/yusing/godoxy/internal/serialization" ) func TestEABConfigRequired(t *testing.T) { + dnsproviders.InitProviders() + tests := []struct { name string - cfg *Config + cfg *autocert.Config wantErr bool }{ - {name: "Missing EABKid", cfg: &Config{EABHmac: "1234567890"}, wantErr: true}, - {name: "Missing EABHmac", cfg: &Config{EABKid: "1234567890"}, wantErr: true}, - {name: "Valid EAB", cfg: &Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false}, + {name: "Missing EABKid", cfg: &autocert.Config{EABHmac: "1234567890"}, wantErr: true}, + {name: "Missing EABHmac", cfg: &autocert.Config{EABKid: "1234567890"}, wantErr: true}, + {name: "Valid EAB", cfg: &autocert.Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { yaml := fmt.Appendf(nil, "eab_kid: %s\neab_hmac: %s", test.cfg.EABKid, test.cfg.EABHmac) - cfg := Config{} + cfg := autocert.Config{} err := serialization.UnmarshalValidateYAML(yaml, &cfg) if (err != nil) != test.wantErr { t.Errorf("Validate() error = %v, wantErr %v", err, test.wantErr) @@ -29,3 +34,27 @@ func TestEABConfigRequired(t *testing.T) { }) } } + +func TestExtraCertKeyPathsUnique(t *testing.T) { + t.Run("duplicate cert_path rejected", func(t *testing.T) { + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + Extra: []autocert.ConfigExtra{ + {CertPath: "a.crt", KeyPath: "a.key"}, + {CertPath: "a.crt", KeyPath: "b.key"}, + }, + } + require.Error(t, cfg.Validate()) + }) + + t.Run("duplicate key_path rejected", func(t *testing.T) { + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + Extra: []autocert.ConfigExtra{ + {CertPath: "a.crt", KeyPath: "a.key"}, + {CertPath: "b.crt", KeyPath: "a.key"}, + }, + } + require.Error(t, cfg.Validate()) + }) +} diff --git a/internal/autocert/paths.go b/internal/autocert/paths.go index c486f061..573d51e1 100644 --- a/internal/autocert/paths.go +++ b/internal/autocert/paths.go @@ -5,5 +5,4 @@ const ( CertFileDefault = certBasePath + "cert.crt" KeyFileDefault = certBasePath + "priv.key" ACMEKeyFileDefault = certBasePath + "acme.key" - LastFailureFile = certBasePath + ".last_failure" ) diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 163d6bbc..bff4e55a 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -1,16 +1,19 @@ package autocert import ( + "context" "crypto/sha256" "crypto/tls" "crypto/x509" "errors" "fmt" + "io/fs" "maps" "os" "path/filepath" "slices" "strings" + "sync" "sync/atomic" "time" @@ -28,6 +31,8 @@ import ( type ( Provider struct { + logger zerolog.Logger + cfg *Config user *User legoCfg *lego.Config @@ -42,12 +47,18 @@ type ( extraProviders []*Provider sniMatcher sniMatcher + + forceRenewalCh chan struct{} + forceRenewalDoneCh atomic.Value // chan struct{} + + scheduleRenewalOnce sync.Once } CertExpiries map[string]time.Time + RenewMode uint8 ) -var ErrGetCertFailure = errors.New("get certificate failed") +var ErrNoCertificate = errors.New("no certificate found") const ( // renew failed for whatever reason, 1 hour cooldown @@ -56,21 +67,36 @@ const ( requestCooldownDuration = 15 * time.Second ) +const ( + renewModeForce = iota + renewModeIfNeeded +) + // could be nil var ActiveProvider atomic.Pointer[Provider] -func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider { - return &Provider{ +func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) (*Provider, error) { + p := &Provider{ cfg: cfg, user: user, legoCfg: legoCfg, lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath), + forceRenewalCh: make(chan struct{}, 1), } + if cfg.idx == 0 { + p.logger = log.With().Str("provider", "main").Logger() + } else { + p.logger = log.With().Str("provider", fmt.Sprintf("extra[%d]", cfg.idx)).Logger() + } + if err := p.setupExtraProviders(); err != nil { + return nil, err + } + return p, nil } func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { if p.tlsCert == nil { - return nil, ErrGetCertFailure + return nil, ErrNoCertificate } if hello == nil || hello.ServerName == "" { return p.tlsCert, nil @@ -82,7 +108,14 @@ func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) } func (p *Provider) GetName() string { - return p.cfg.Provider + if p.cfg.idx == 0 { + return "main" + } + return fmt.Sprintf("extra[%d]", p.cfg.idx) +} + +func (p *Provider) fmtError(err error) error { + return gperr.PrependSubject(fmt.Sprintf("provider: %s", p.GetName()), err) } func (p *Provider) GetCertPath() string { @@ -129,45 +162,88 @@ func (p *Provider) ClearLastFailure() error { return nil } p.lastFailure = time.Time{} - return os.Remove(p.lastFailureFile) + err := os.Remove(p.lastFailureFile) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return err + } + return nil } -func (p *Provider) ObtainCert() error { - if len(p.extraProviders) > 0 { - errs := gperr.NewGroup("autocert errors") - errs.Go(p.obtainCertSelf) - for _, ep := range p.extraProviders { - errs.Go(ep.obtainCertSelf) - } - if err := errs.Wait().Error(); err != nil { - return err - } - p.rebuildSNIMatcher() +// allProviders returns all providers including this provider and all extra providers. +func (p *Provider) allProviders() []*Provider { + return append([]*Provider{p}, p.extraProviders...) +} + +// ObtainCertIfNotExistsAll obtains a new certificate for this provider and all extra providers if they do not exist. +func (p *Provider) ObtainCertIfNotExistsAll() error { + errs := gperr.NewGroup("obtain cert error") + + for _, provider := range p.allProviders() { + errs.Go(func() error { + if err := provider.obtainCertIfNotExists(); err != nil { + return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err) + } + return nil + }) + } + + p.rebuildSNIMatcher() + return errs.Wait().Error() +} + +// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist. +func (p *Provider) obtainCertIfNotExists() error { + err := p.LoadCert() + if err == nil { return nil } - return p.obtainCertSelf() + + if !errors.Is(err, fs.ErrNotExist) { + return err + } + + // check last failure + lastFailure, err := p.GetLastFailure() + if err != nil { + return fmt.Errorf("failed to get last failure: %w", err) + } + if !lastFailure.IsZero() && time.Since(lastFailure) < requestCooldownDuration { + return fmt.Errorf("still in cooldown until %s", strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local())) + } + + p.logger.Info().Msg("cert not found, obtaining new cert") + return p.ObtainCert() } -func (p *Provider) obtainCertSelf() error { +// ObtainCertAll renews existing certificates or obtains new certificates for this provider and all extra providers. +func (p *Provider) ObtainCertAll() error { + errs := gperr.NewGroup("obtain cert error") + for _, provider := range p.allProviders() { + errs.Go(func() error { + if err := provider.obtainCertIfNotExists(); err != nil { + return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err) + } + return nil + }) + } + return errs.Wait().Error() +} + +// ObtainCert renews existing certificate or obtains a new certificate for this provider. +func (p *Provider) ObtainCert() error { if p.cfg.Provider == ProviderLocal { return nil } if p.cfg.Provider == ProviderPseudo { - log.Info().Msg("init client for pseudo provider") + p.logger.Info().Msg("init client for pseudo provider") <-time.After(time.Second) - log.Info().Msg("registering acme for pseudo provider") + p.logger.Info().Msg("registering acme for pseudo provider") <-time.After(time.Second) - log.Info().Msg("obtained cert for pseudo provider") + p.logger.Info().Msg("obtained cert for pseudo provider") return nil } - if lastFailure, err := p.GetLastFailure(); err != nil { - return err - } else if time.Since(lastFailure) < requestCooldownDuration { - return fmt.Errorf("%w: still in cooldown until %s", ErrGetCertFailure, strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local())) - } - if p.client == nil { if err := p.initClient(); err != nil { return err @@ -227,6 +303,7 @@ func (p *Provider) obtainCertSelf() error { } p.tlsCert = &tlsCert p.certExpiries = expiries + p.rebuildSNIMatcher() if err := p.ClearLastFailure(); err != nil { return fmt.Errorf("failed to clear last failure: %w", err) @@ -235,19 +312,37 @@ func (p *Provider) obtainCertSelf() error { } func (p *Provider) LoadCert() error { + var errs gperr.Builder cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath) if err != nil { - return fmt.Errorf("load SSL certificate: %w", err) + errs.Addf("load SSL certificate: %w", p.fmtError(err)) } + expiries, err := getCertExpiries(&cert) if err != nil { - return fmt.Errorf("parse SSL certificate: %w", err) + errs.Addf("parse SSL certificate: %w", p.fmtError(err)) } + p.tlsCert = &cert p.certExpiries = expiries - log.Info().Msgf("next cert renewal in %s", strutils.FormatDuration(time.Until(p.ShouldRenewOn()))) - return p.renewIfNeeded() + for _, ep := range p.extraProviders { + if err := ep.LoadCert(); err != nil { + errs.Add(err) + } + } + + p.rebuildSNIMatcher() + return errs.Error() +} + +// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers. +func (p *Provider) PrintCertExpiriesAll() { + for _, provider := range p.allProviders() { + for domain, expiry := range provider.certExpiries { + p.logger.Info().Str("domain", domain).Msgf("certificate expire on %s", strutils.FormatTime(expiry)) + } + } } // ShouldRenewOn returns the time at which the certificate should be renewed. @@ -255,65 +350,129 @@ func (p *Provider) ShouldRenewOn() time.Time { for _, expiry := range p.certExpiries { return expiry.AddDate(0, -1, 0) // 1 month before } - // this line should never be reached - panic("no certificate available") + // this line should never be reached in production, but will be useful for testing + return time.Now().AddDate(0, 1, 0) // 1 month after } -func (p *Provider) ScheduleRenewal(parent task.Parent) { +// ForceExpiryAll triggers immediate certificate renewal for this provider and all extra providers. +// Returns true if the renewal was triggered, false if the renewal was dropped. +// +// If at least one renewal is triggered, returns true. +func (p *Provider) ForceExpiryAll() (ok bool) { + doneCh := make(chan struct{}) + if swapped := p.forceRenewalDoneCh.CompareAndSwap(nil, doneCh); !swapped { // already in progress + close(doneCh) + return false + } + + select { + case p.forceRenewalCh <- struct{}{}: + ok = true + default: + } + + for _, ep := range p.extraProviders { + if ep.ForceExpiryAll() { + ok = true + } + } + + return ok +} + +// WaitRenewalDone waits for the renewal to complete. +// Returns false if the renewal was dropped. +func (p *Provider) WaitRenewalDone(ctx context.Context) bool { + done, ok := p.forceRenewalDoneCh.Load().(chan struct{}) + if !ok || done == nil { + return false + } + select { + case <-done: + case <-ctx.Done(): + return false + } + + for _, ep := range p.extraProviders { + if !ep.WaitRenewalDone(ctx) { + return false + } + } + return true +} + +// ScheduleRenewalAll schedules the renewal of the certificate for this provider and all extra providers. +func (p *Provider) ScheduleRenewalAll(parent task.Parent) { + p.scheduleRenewalOnce.Do(func() { + p.scheduleRenewal(parent) + }) + for _, ep := range p.extraProviders { + ep.scheduleRenewalOnce.Do(func() { + ep.scheduleRenewal(parent) + }) + } +} + +var emptyForceRenewalDoneCh any = chan struct{}(nil) + +// scheduleRenewal schedules the renewal of the certificate for this provider. +func (p *Provider) scheduleRenewal(parent task.Parent) { if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo { return } - go func() { - renewalTime := p.ShouldRenewOn() - timer := time.NewTimer(time.Until(renewalTime)) - defer timer.Stop() - task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true) + timer := time.NewTimer(time.Until(p.ShouldRenewOn())) + task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true) + + renew := func(renewMode RenewMode) { + defer func() { + if done, ok := p.forceRenewalDoneCh.Swap(emptyForceRenewalDoneCh).(chan struct{}); ok && done != nil { + close(done) + } + }() + + renewed, err := p.renew(renewMode) + if err != nil { + gperr.LogWarn("autocert: cert renew failed", p.fmtError(err)) + notif.Notify(¬if.LogMessage{ + Level: zerolog.ErrorLevel, + Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()), + Body: notif.MessageBody(err.Error()), + }) + return + } + if renewed { + p.rebuildSNIMatcher() + + notif.Notify(¬if.LogMessage{ + Level: zerolog.InfoLevel, + Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()), + Body: notif.ListBody(p.cfg.Domains), + }) + + // Reset on success + if err := p.ClearLastFailure(); err != nil { + gperr.LogWarn("autocert: failed to clear last failure", p.fmtError(err)) + } + timer.Reset(time.Until(p.ShouldRenewOn())) + } + } + + go func() { + defer timer.Stop() defer task.Finish(nil) for { select { case <-task.Context().Done(): return + case <-p.forceRenewalCh: + renew(renewModeForce) case <-timer.C: - // Retry after 1 hour on failure - lastFailure, err := p.GetLastFailure() - if err != nil { - gperr.LogWarn("autocert: failed to get last failure", err) - continue - } - if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration { - continue - } - if err := p.renewIfNeeded(); err != nil { - gperr.LogWarn("autocert: cert renew failed", err) - if err := p.UpdateLastFailure(); err != nil { - gperr.LogWarn("autocert: failed to update last failure", err) - } - notif.Notify(¬if.LogMessage{ - Level: zerolog.ErrorLevel, - Title: "SSL certificate renewal failed", - Body: notif.MessageBody(err.Error()), - }) - continue - } - notif.Notify(¬if.LogMessage{ - Level: zerolog.InfoLevel, - Title: "SSL certificate renewed", - Body: notif.ListBody(p.cfg.Domains), - }) - // Reset on success - if err := p.ClearLastFailure(); err != nil { - gperr.LogWarn("autocert: failed to clear last failure", err) - } - renewalTime = p.ShouldRenewOn() - timer.Reset(time.Until(renewalTime)) + renew(renewModeIfNeeded) } } }() - for _, ep := range p.extraProviders { - ep.ScheduleRenewal(parent) - } } func (p *Provider) initClient() error { @@ -409,21 +568,42 @@ func (p *Provider) certState() CertState { return CertStateValid } -func (p *Provider) renewIfNeeded() error { +func (p *Provider) renew(mode RenewMode) (renewed bool, err error) { if p.cfg.Provider == ProviderLocal { - return nil + return false, nil } - switch p.certState() { - case CertStateExpired: - log.Info().Msg("certs expired, renewing") - case CertStateMismatch: - log.Info().Msg("cert domains mismatch with config, renewing") - default: - return nil + if mode != renewModeForce { + // Retry after 1 hour on failure + lastFailure, err := p.GetLastFailure() + if err != nil { + return false, fmt.Errorf("failed to get last failure: %w", err) + } + if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration { + until := lastFailure.Add(renewalCooldownDuration).Local() + return false, fmt.Errorf("still in cooldown until %s", strutils.FormatTime(until)) + } } - return p.obtainCertSelf() + if mode == renewModeIfNeeded { + switch p.certState() { + case CertStateExpired: + log.Info().Msg("certs expired, renewing") + case CertStateMismatch: + log.Info().Msg("cert domains mismatch with config, renewing") + default: + return false, nil + } + } + + if mode == renewModeForce { + log.Info().Msg("force renewing cert by user request") + } + + if err := p.ObtainCert(); err != nil { + return false, err + } + return true, nil } func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { @@ -445,15 +625,16 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { } func lastFailureFileFor(certPath, keyPath string) string { - if certPath == "" && keyPath == "" { - return LastFailureFile - } dir := filepath.Dir(certPath) sum := sha256.Sum256([]byte(certPath + "|" + keyPath)) return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6])) } func (p *Provider) rebuildSNIMatcher() { + if p.cfg.idx != 0 { // only main provider has extra providers + return + } + p.sniMatcher = sniMatcher{} p.sniMatcher.addProvider(p) for _, ep := range p.extraProviders { diff --git a/internal/autocert/provider_test/custom_test.go b/internal/autocert/provider_test/custom_test.go index d62ad77e..30ae2544 100644 --- a/internal/autocert/provider_test/custom_test.go +++ b/internal/autocert/provider_test/custom_test.go @@ -10,12 +10,15 @@ import ( "encoding/base64" "encoding/json" "encoding/pem" + "fmt" "io" "math/big" "net" "net/http" "net/http/httptest" + "sort" "strings" + "sync" "testing" "time" @@ -24,6 +27,368 @@ import ( "github.com/yusing/godoxy/internal/dnsproviders" ) +// TestACMEServer implements a minimal ACME server for testing with request tracking. +type TestACMEServer struct { + server *httptest.Server + caCert *x509.Certificate + caKey *rsa.PrivateKey + clientCSRs map[string]*x509.CertificateRequest + orderDomains map[string][]string + authzDomains map[string]string + orderSeq int + certRequestCount map[string]int + renewalRequestCount map[string]int + mu sync.Mutex +} + +func newTestACMEServer(t *testing.T) *TestACMEServer { + t.Helper() + + // Generate CA certificate and key + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test CA"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"Test"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + + caCert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + acme := &TestACMEServer{ + caCert: caCert, + caKey: caKey, + clientCSRs: make(map[string]*x509.CertificateRequest), + orderDomains: make(map[string][]string), + authzDomains: make(map[string]string), + orderSeq: 0, + certRequestCount: make(map[string]int), + renewalRequestCount: make(map[string]int), + } + + mux := http.NewServeMux() + acme.setupRoutes(mux) + + acme.server = httptest.NewUnstartedServer(mux) + acme.server.TLS = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{caCert.Raw}, + PrivateKey: caKey, + }, + }, + MinVersion: tls.VersionTLS12, + } + acme.server.StartTLS() + return acme +} + +func (s *TestACMEServer) Close() { + s.server.Close() +} + +func (s *TestACMEServer) URL() string { + return s.server.URL +} + +func (s *TestACMEServer) httpClient() *http.Client { + certPool := x509.NewCertPool() + certPool.AddCert(s.caCert) + + return &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 30 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + MinVersion: tls.VersionTLS12, + }, + }, + } +} + +func (s *TestACMEServer) setupRoutes(mux *http.ServeMux) { + mux.HandleFunc("/acme/acme/directory", s.handleDirectory) + mux.HandleFunc("/acme/new-nonce", s.handleNewNonce) + mux.HandleFunc("/acme/new-account", s.handleNewAccount) + mux.HandleFunc("/acme/new-order", s.handleNewOrder) + mux.HandleFunc("/acme/authz/", s.handleAuthorization) + mux.HandleFunc("/acme/chall/", s.handleChallenge) + mux.HandleFunc("/acme/order/", s.handleOrder) + mux.HandleFunc("/acme/cert/", s.handleCertificate) +} + +func (s *TestACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) { + directory := map[string]any{ + "newNonce": s.server.URL + "/acme/new-nonce", + "newAccount": s.server.URL + "/acme/new-account", + "newOrder": s.server.URL + "/acme/new-order", + "keyChange": s.server.URL + "/acme/key-change", + "meta": map[string]any{ + "termsOfService": s.server.URL + "/terms", + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(directory) +} + +func (s *TestACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "test-nonce-12345") + w.WriteHeader(http.StatusOK) +} + +func (s *TestACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) { + account := map[string]any{ + "status": "valid", + "contact": []string{"mailto:test@example.com"}, + "orders": s.server.URL + "/acme/orders", + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", s.server.URL+"/acme/account/1") + w.Header().Set("Replay-Nonce", "test-nonce-67890") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(account) +} + +func (s *TestACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var jws struct { + Payload string `json:"payload"` + } + json.Unmarshal(body, &jws) + payloadBytes, _ := base64.RawURLEncoding.DecodeString(jws.Payload) + var orderReq struct { + Identifiers []map[string]string `json:"identifiers"` + } + json.Unmarshal(payloadBytes, &orderReq) + + domains := []string{} + for _, id := range orderReq.Identifiers { + domains = append(domains, id["value"]) + } + sort.Strings(domains) + domainKey := strings.Join(domains, ",") + + s.mu.Lock() + s.orderSeq++ + orderID := fmt.Sprintf("test-order-%d", s.orderSeq) + authzID := fmt.Sprintf("test-authz-%d", s.orderSeq) + s.orderDomains[orderID] = domains + if len(domains) > 0 { + s.authzDomains[authzID] = domains[0] + } + s.certRequestCount[domainKey]++ + s.mu.Unlock() + + order := map[string]any{ + "status": "ready", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": orderReq.Identifiers, + "authorizations": []string{s.server.URL + "/acme/authz/" + authzID}, + "finalize": s.server.URL + "/acme/order/" + orderID + "/finalize", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", s.server.URL+"/acme/order/"+orderID) + w.Header().Set("Replay-Nonce", "test-nonce-order") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(order) +} + +func (s *TestACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) { + authzID := strings.TrimPrefix(r.URL.Path, "/acme/authz/") + domain := s.authzDomains[authzID] + if domain == "" { + domain = "test.example.com" + } + authz := map[string]any{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifier": map[string]string{"type": "dns", "value": domain}, + "challenges": []map[string]any{ + { + "type": "dns-01", + "status": "valid", + "url": s.server.URL + "/acme/chall/test-chall-789", + "token": "test-token-abc123", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-authz") + json.NewEncoder(w).Encode(authz) +} + +func (s *TestACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) { + challenge := map[string]any{ + "type": "dns-01", + "status": "valid", + "url": r.URL.String(), + "token": "test-token-abc123", + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-chall") + json.NewEncoder(w).Encode(challenge) +} + +func (s *TestACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/finalize") { + s.handleFinalize(w, r) + return + } + + orderID := strings.TrimPrefix(r.URL.Path, "/acme/order/") + domains := s.orderDomains[orderID] + if len(domains) == 0 { + domains = []string{"test.example.com"} + } + certURL := s.server.URL + "/acme/cert/" + orderID + order := map[string]any{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": func() []map[string]string { + out := make([]map[string]string, 0, len(domains)) + for _, d := range domains { + out = append(out, map[string]string{"type": "dns", "value": d}) + } + return out + }(), + "certificate": certURL, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-order-get") + json.NewEncoder(w).Encode(order) +} + +func (s *TestACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request", http.StatusBadRequest) + return + } + + csr, err := s.extractCSRFromJWS(body) + if err != nil { + http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest) + return + } + + orderID := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/acme/order/"), "/finalize") + s.mu.Lock() + s.clientCSRs[orderID] = csr + + // Detect renewal: if we already have a certificate for these domains, it's a renewal + domains := csr.DNSNames + sort.Strings(domains) + domainKey := strings.Join(domains, ",") + + if s.certRequestCount[domainKey] > 1 { + s.renewalRequestCount[domainKey]++ + } + s.mu.Unlock() + + certURL := s.server.URL + "/acme/cert/" + orderID + order := map[string]any{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": func() []map[string]string { + out := make([]map[string]string, 0, len(domains)) + for _, d := range domains { + out = append(out, map[string]string{"type": "dns", "value": d}) + } + return out + }(), + "certificate": certURL, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize")) + w.Header().Set("Replay-Nonce", "test-nonce-finalize") + json.NewEncoder(w).Encode(order) +} + +func (s *TestACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) { + var jws struct { + Payload string `json:"payload"` + } + if err := json.Unmarshal(jwsData, &jws); err != nil { + return nil, err + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return nil, err + } + var finalizeReq struct { + CSR string `json:"csr"` + } + if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil { + return nil, err + } + csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR) + if err != nil { + return nil, err + } + return x509.ParseCertificateRequest(csrBytes) +} + +func (s *TestACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) { + orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/") + csr, exists := s.clientCSRs[orderID] + if !exists { + http.Error(w, "No CSR found for order", http.StatusBadRequest) + return + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Cert"}, + Country: []string{"US"}, + }, + DNSNames: csr.DNSNames, + NotBefore: time.Now(), + NotAfter: time.Now().Add(90 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw}) + + w.Header().Set("Content-Type", "application/pem-certificate-chain") + w.Header().Set("Replay-Nonce", "test-nonce-cert") + w.Write(append(certPEM, caPEM...)) +} + func TestMain(m *testing.M) { dnsproviders.InitProviders() m.Run() @@ -41,7 +406,7 @@ func TestCustomProvider(t *testing.T) { ACMEKeyPath: "certs/custom-acme.key", } - err := cfg.Validate() + err := error(cfg.Validate()) require.NoError(t, err) user, legoCfg, err := cfg.GetLegoConfig() @@ -62,7 +427,8 @@ func TestCustomProvider(t *testing.T) { err := cfg.Validate() require.Error(t, err) - require.Contains(t, err.Error(), "missing field 'ca_dir_url'") + require.Contains(t, err.Error(), "missing field") + require.Contains(t, err.Error(), "ca_dir_url") }) t.Run("custom provider with step-ca internal CA", func(t *testing.T) { @@ -76,7 +442,7 @@ func TestCustomProvider(t *testing.T) { ACMEKeyPath: "certs/internal-acme.key", } - err := cfg.Validate() + err := error(cfg.Validate()) require.NoError(t, err) user, legoCfg, err := cfg.GetLegoConfig() @@ -86,9 +452,10 @@ func TestCustomProvider(t *testing.T) { require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL) require.Equal(t, "admin@internal.com", user.Email) - provider := autocert.NewProvider(cfg, user, legoCfg) + provider, err := autocert.NewProvider(cfg, user, legoCfg) + require.NoError(t, err) require.NotNil(t, provider) - require.Equal(t, autocert.ProviderCustom, provider.GetName()) + require.Equal(t, "main", provider.GetName()) require.Equal(t, "certs/internal.crt", provider.GetCertPath()) require.Equal(t, "certs/internal.key", provider.GetKeyPath()) }) @@ -119,7 +486,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) { require.NotNil(t, user) require.NotNil(t, legoCfg) - provider := autocert.NewProvider(cfg, user, legoCfg) + provider, err := autocert.NewProvider(cfg, user, legoCfg) + require.NoError(t, err) require.NotNil(t, provider) // Test obtaining certificate @@ -161,7 +529,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) { require.NotNil(t, user) require.NotNil(t, legoCfg) - provider := autocert.NewProvider(cfg, user, legoCfg) + provider, err := autocert.NewProvider(cfg, user, legoCfg) + require.NoError(t, err) require.NotNil(t, provider) err = provider.ObtainCert() @@ -178,330 +547,3 @@ func TestObtainCertFromCustomProvider(t *testing.T) { require.True(t, time.Now().After(x509Cert.NotBefore)) }) } - -// testACMEServer implements a minimal ACME server for testing. -type testACMEServer struct { - server *httptest.Server - caCert *x509.Certificate - caKey *rsa.PrivateKey - clientCSRs map[string]*x509.CertificateRequest - orderID string -} - -func newTestACMEServer(t *testing.T) *testACMEServer { - t.Helper() - - // Generate CA certificate and key - caKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test CA"}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"Test"}, - StreetAddress: []string{""}, - PostalCode: []string{""}, - }, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - NotBefore: time.Now(), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) - require.NoError(t, err) - - caCert, err := x509.ParseCertificate(caCertDER) - require.NoError(t, err) - - acme := &testACMEServer{ - caCert: caCert, - caKey: caKey, - clientCSRs: make(map[string]*x509.CertificateRequest), - orderID: "test-order-123", - } - - mux := http.NewServeMux() - acme.setupRoutes(mux) - - acme.server = httptest.NewUnstartedServer(mux) - acme.server.TLS = &tls.Config{ - Certificates: []tls.Certificate{ - { - Certificate: [][]byte{caCert.Raw}, - PrivateKey: caKey, - }, - }, - MinVersion: tls.VersionTLS12, - } - acme.server.StartTLS() - return acme -} - -func (s *testACMEServer) Close() { - s.server.Close() -} - -func (s *testACMEServer) URL() string { - return s.server.URL -} - -func (s *testACMEServer) httpClient() *http.Client { - certPool := x509.NewCertPool() - certPool.AddCert(s.caCert) - - return &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 30 * time.Second, - ResponseHeaderTimeout: 30 * time.Second, - TLSClientConfig: &tls.Config{ - RootCAs: certPool, - MinVersion: tls.VersionTLS12, - }, - }, - } -} - -func (s *testACMEServer) setupRoutes(mux *http.ServeMux) { - // ACME directory endpoint - mux.HandleFunc("/acme/acme/directory", s.handleDirectory) - - // ACME endpoints - mux.HandleFunc("/acme/new-nonce", s.handleNewNonce) - mux.HandleFunc("/acme/new-account", s.handleNewAccount) - mux.HandleFunc("/acme/new-order", s.handleNewOrder) - mux.HandleFunc("/acme/authz/", s.handleAuthorization) - mux.HandleFunc("/acme/chall/", s.handleChallenge) - mux.HandleFunc("/acme/order/", s.handleOrder) - mux.HandleFunc("/acme/cert/", s.handleCertificate) -} - -func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) { - directory := map[string]interface{}{ - "newNonce": s.server.URL + "/acme/new-nonce", - "newAccount": s.server.URL + "/acme/new-account", - "newOrder": s.server.URL + "/acme/new-order", - "keyChange": s.server.URL + "/acme/key-change", - "meta": map[string]interface{}{ - "termsOfService": s.server.URL + "/terms", - }, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(directory) -} - -func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Replay-Nonce", "test-nonce-12345") - w.WriteHeader(http.StatusOK) -} - -func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) { - account := map[string]interface{}{ - "status": "valid", - "contact": []string{"mailto:test@example.com"}, - "orders": s.server.URL + "/acme/orders", - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Location", s.server.URL+"/acme/account/1") - w.Header().Set("Replay-Nonce", "test-nonce-67890") - w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(account) -} - -func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) { - authzID := "test-authz-456" - - order := map[string]interface{}{ - "status": "ready", // Skip pending state for simplicity - "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), - "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, - "authorizations": []string{s.server.URL + "/acme/authz/" + authzID}, - "finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize", - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID) - w.Header().Set("Replay-Nonce", "test-nonce-order") - w.WriteHeader(http.StatusCreated) - json.NewEncoder(w).Encode(order) -} - -func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) { - authz := map[string]interface{}{ - "status": "valid", // Skip challenge validation for simplicity - "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), - "identifier": map[string]string{"type": "dns", "value": "test.example.com"}, - "challenges": []map[string]interface{}{ - { - "type": "dns-01", - "status": "valid", - "url": s.server.URL + "/acme/chall/test-chall-789", - "token": "test-token-abc123", - }, - }, - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Replay-Nonce", "test-nonce-authz") - json.NewEncoder(w).Encode(authz) -} - -func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) { - challenge := map[string]interface{}{ - "type": "dns-01", - "status": "valid", - "url": r.URL.String(), - "token": "test-token-abc123", - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Replay-Nonce", "test-nonce-chall") - json.NewEncoder(w).Encode(challenge) -} - -func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) { - if strings.HasSuffix(r.URL.Path, "/finalize") { - s.handleFinalize(w, r) - return - } - - certURL := s.server.URL + "/acme/cert/" + s.orderID - order := map[string]interface{}{ - "status": "valid", - "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), - "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, - "certificate": certURL, - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Replay-Nonce", "test-nonce-order-get") - json.NewEncoder(w).Encode(order) -} - -func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) { - // Read the JWS payload - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read request", http.StatusBadRequest) - return - } - - // Extract CSR from JWS payload - csr, err := s.extractCSRFromJWS(body) - if err != nil { - http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest) - return - } - - // Store the CSR for certificate generation - s.clientCSRs[s.orderID] = csr - - certURL := s.server.URL + "/acme/cert/" + s.orderID - order := map[string]interface{}{ - "status": "valid", - "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), - "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, - "certificate": certURL, - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize")) - w.Header().Set("Replay-Nonce", "test-nonce-finalize") - json.NewEncoder(w).Encode(order) -} - -func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) { - // Parse the JWS structure - var jws struct { - Protected string `json:"protected"` - Payload string `json:"payload"` - Signature string `json:"signature"` - } - - if err := json.Unmarshal(jwsData, &jws); err != nil { - return nil, err - } - - // Decode the payload - payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload) - if err != nil { - return nil, err - } - - // Parse the finalize request - var finalizeReq struct { - CSR string `json:"csr"` - } - - if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil { - return nil, err - } - - // Decode the CSR - csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR) - if err != nil { - return nil, err - } - - // Parse the CSR - csr, err := x509.ParseCertificateRequest(csrBytes) - if err != nil { - return nil, err - } - - return csr, nil -} - -func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) { - // Extract order ID from URL - orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/") - - // Get the CSR for this order - csr, exists := s.clientCSRs[orderID] - if !exists { - http.Error(w, "No CSR found for order", http.StatusBadRequest) - return - } - - // Create certificate using the public key from the client's CSR - template := &x509.Certificate{ - SerialNumber: big.NewInt(2), - Subject: pkix.Name{ - Organization: []string{"Test Cert"}, - Country: []string{"US"}, - }, - DNSNames: csr.DNSNames, - NotBefore: time.Now(), - NotAfter: time.Now().Add(90 * 24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - // Use the public key from the CSR and sign with CA key - certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Return certificate chain - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw}) - - w.Header().Set("Content-Type", "application/pem-certificate-chain") - w.Header().Set("Replay-Nonce", "test-nonce-cert") - w.Write(append(certPEM, caPEM...)) -} diff --git a/internal/autocert/provider_test/extra_validation_test.go b/internal/autocert/provider_test/extra_validation_test.go deleted file mode 100644 index 3fbb5174..00000000 --- a/internal/autocert/provider_test/extra_validation_test.go +++ /dev/null @@ -1,32 +0,0 @@ -package provider_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - "github.com/yusing/godoxy/internal/autocert" -) - -func TestExtraCertKeyPathsUnique(t *testing.T) { - t.Run("duplicate cert_path rejected", func(t *testing.T) { - cfg := &autocert.Config{ - Provider: autocert.ProviderLocal, - Extra: []autocert.Config{ - {CertPath: "a.crt", KeyPath: "a.key"}, - {CertPath: "a.crt", KeyPath: "b.key"}, - }, - } - require.Error(t, cfg.Validate()) - }) - - t.Run("duplicate key_path rejected", func(t *testing.T) { - cfg := &autocert.Config{ - Provider: autocert.ProviderLocal, - Extra: []autocert.Config{ - {CertPath: "a.crt", KeyPath: "a.key"}, - {CertPath: "b.crt", KeyPath: "a.key"}, - }, - } - require.Error(t, cfg.Validate()) - }) -} diff --git a/internal/autocert/provider_test/multi_cert_test.go b/internal/autocert/provider_test/multi_cert_test.go new file mode 100644 index 00000000..d77afe1f --- /dev/null +++ b/internal/autocert/provider_test/multi_cert_test.go @@ -0,0 +1,90 @@ +//nolint:errchkjson,errcheck +package provider_test + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" + "github.com/yusing/godoxy/internal/serialization" + "github.com/yusing/goutils/task" +) + +func buildMultiCertYAML(serverURL string) []byte { + return fmt.Appendf(nil, ` +email: main@example.com +domains: [main.example.com] +provider: custom +ca_dir_url: %s/acme/acme/directory +cert_path: certs/main.crt +key_path: certs/main.key +extra: + - email: extra1@example.com + domains: [extra1.example.com] + cert_path: certs/extra1.crt + key_path: certs/extra1.key + - email: extra2@example.com + domains: [extra2.example.com] + cert_path: certs/extra2.crt + key_path: certs/extra2.key +`, serverURL) +} + +func TestMultipleCertificatesLifecycle(t *testing.T) { + acmeServer := newTestACMEServer(t) + defer acmeServer.Close() + + yamlConfig := buildMultiCertYAML(acmeServer.URL()) + var cfg autocert.Config + cfg.HTTPClient = acmeServer.httpClient() + + /* unmarshal yaml config with multiple certs */ + err := error(serialization.UnmarshalValidateYAML(yamlConfig, &cfg)) + require.NoError(t, err) + require.Equal(t, []string{"main.example.com"}, cfg.Domains) + require.Len(t, cfg.Extra, 2) + require.Equal(t, []string{"extra1.example.com"}, cfg.Extra[0].Domains) + require.Equal(t, []string{"extra2.example.com"}, cfg.Extra[1].Domains) + + var provider *autocert.Provider + + /* initialize autocert with multi-cert config */ + user, legoCfg, gerr := cfg.GetLegoConfig() + require.NoError(t, gerr) + provider, err = autocert.NewProvider(&cfg, user, legoCfg) + require.NoError(t, err) + require.NotNil(t, provider) + + // Start renewal scheduler + root := task.RootTask("test", false) + defer root.Finish(nil) + provider.ScheduleRenewalAll(root) + + require.Equal(t, "custom", cfg.Provider) + require.Equal(t, "custom", cfg.Extra[0].Provider) + require.Equal(t, "custom", cfg.Extra[1].Provider) + + /* track cert requests for all configs */ + os.MkdirAll("certs", 0755) + defer os.RemoveAll("certs") + + err = provider.ObtainCertIfNotExistsAll() + require.NoError(t, err) + + require.Equal(t, 1, acmeServer.certRequestCount["main.example.com"]) + require.Equal(t, 1, acmeServer.certRequestCount["extra1.example.com"]) + require.Equal(t, 1, acmeServer.certRequestCount["extra2.example.com"]) + + /* track renewal scheduling and requests */ + + // force renewal for all providers and wait for completion + ok := provider.ForceExpiryAll() + require.True(t, ok) + provider.WaitRenewalDone(t.Context()) + + require.Equal(t, 1, acmeServer.renewalRequestCount["main.example.com"]) + require.Equal(t, 1, acmeServer.renewalRequestCount["extra1.example.com"]) + require.Equal(t, 1, acmeServer.renewalRequestCount["extra2.example.com"]) +} diff --git a/internal/autocert/provider_test/sni_test.go b/internal/autocert/provider_test/sni_test.go index 766593cd..01a07b8f 100644 --- a/internal/autocert/provider_test/sni_test.go +++ b/internal/autocert/provider_test/sni_test.go @@ -71,15 +71,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert, KeyPath: extraKey}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"}) require.NoError(t, err) @@ -100,15 +103,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert, KeyPath: extraKey}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) require.NoError(t, err) @@ -129,15 +135,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert, KeyPath: extraKey}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"}) require.NoError(t, err) @@ -159,8 +168,11 @@ func TestGetCertBySNI(t *testing.T) { require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(nil) require.NoError(t, err) @@ -182,8 +194,11 @@ func TestGetCertBySNI(t *testing.T) { require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""}) require.NoError(t, err) @@ -204,15 +219,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert, KeyPath: extraKey}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"}) require.NoError(t, err) @@ -233,15 +251,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert, KeyPath: extraKey}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "}) require.NoError(t, err) @@ -262,15 +283,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert1, KeyPath: extraKey1}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"}) require.NoError(t, err) @@ -292,8 +316,11 @@ func TestGetCertBySNI(t *testing.T) { require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"}) require.NoError(t, err) @@ -317,7 +344,7 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert1, KeyPath: extraKey1}, {CertPath: extraCert2, KeyPath: extraKey2}, }, @@ -325,8 +352,11 @@ func TestGetCertBySNI(t *testing.T) { require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"}) require.NoError(t, err) @@ -352,15 +382,18 @@ func TestGetCertBySNI(t *testing.T) { Provider: autocert.ProviderLocal, CertPath: mainCert, KeyPath: mainKey, - Extra: []autocert.Config{ + Extra: []autocert.ConfigExtra{ {CertPath: extraCert, KeyPath: extraKey}, }, } require.NoError(t, cfg.Validate()) - p := autocert.NewProvider(cfg, nil, nil) - require.NoError(t, p.Setup()) + p, err := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, err) + + err = p.LoadCert() + require.NoError(t, err) cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) require.NoError(t, err) diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index e114be50..119a8759 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -1,101 +1,30 @@ package autocert import ( - "errors" - "fmt" - "os" - - "github.com/rs/zerolog/log" gperr "github.com/yusing/goutils/errs" - strutils "github.com/yusing/goutils/strings" ) -func (p *Provider) Setup() (err error) { - if err = p.LoadCert(); err != nil { - if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist - return err - } - log.Debug().Msg("obtaining cert due to error loading cert") - if err = p.ObtainCert(); err != nil { - return err - } - } - - if err = p.setupExtraProviders(); err != nil { - return err - } - - for _, expiry := range p.GetExpiries() { - log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) - break - } - - return nil -} - -func (p *Provider) setupExtraProviders() error { - p.extraProviders = nil +func (p *Provider) setupExtraProviders() gperr.Error { p.sniMatcher = sniMatcher{} if len(p.cfg.Extra) == 0 { - p.rebuildSNIMatcher() return nil } - for i := range p.cfg.Extra { - merged := mergeExtraConfig(p.cfg, &p.cfg.Extra[i]) - user, legoCfg, err := merged.GetLegoConfig() + p.extraProviders = make([]*Provider, 0, len(p.cfg.Extra)) + + errs := gperr.NewBuilder("setup extra providers error") + for _, extra := range p.cfg.Extra { + user, legoCfg, err := extra.AsConfig().GetLegoConfig() if err != nil { - return err.Subjectf("extra[%d]", i) + errs.Add(p.fmtError(err)) + continue } - ep := NewProvider(&merged, user, legoCfg) - if err := ep.Setup(); err != nil { - return gperr.PrependSubject(fmt.Sprintf("extra[%d]", i), err) + ep, err := NewProvider(extra.AsConfig(), user, legoCfg) + if err != nil { + errs.Add(p.fmtError(err)) + continue } p.extraProviders = append(p.extraProviders, ep) } - p.rebuildSNIMatcher() - return nil -} - -func mergeExtraConfig(mainCfg *Config, extraCfg *Config) Config { - merged := *mainCfg - merged.Extra = nil - merged.CertPath = extraCfg.CertPath - merged.KeyPath = extraCfg.KeyPath - - if merged.Email == "" { - merged.Email = mainCfg.Email - } - - if len(extraCfg.Domains) > 0 { - merged.Domains = extraCfg.Domains - } - if extraCfg.ACMEKeyPath != "" { - merged.ACMEKeyPath = extraCfg.ACMEKeyPath - } - if extraCfg.Provider != "" { - merged.Provider = extraCfg.Provider - } - if len(extraCfg.Options) > 0 { - merged.Options = extraCfg.Options - } - if len(extraCfg.Resolvers) > 0 { - merged.Resolvers = extraCfg.Resolvers - } - if extraCfg.CADirURL != "" { - merged.CADirURL = extraCfg.CADirURL - } - if len(extraCfg.CACerts) > 0 { - merged.CACerts = extraCfg.CACerts - } - if extraCfg.EABKid != "" { - merged.EABKid = extraCfg.EABKid - } - if extraCfg.EABHmac != "" { - merged.EABHmac = extraCfg.EABHmac - } - if extraCfg.HTTPClient != nil { - merged.HTTPClient = extraCfg.HTTPClient - } - return merged + return errs.Error() } diff --git a/internal/autocert/setup_test.go b/internal/autocert/setup_test.go new file mode 100644 index 00000000..335c2e57 --- /dev/null +++ b/internal/autocert/setup_test.go @@ -0,0 +1,82 @@ +package autocert_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" + "github.com/yusing/godoxy/internal/dnsproviders" + "github.com/yusing/godoxy/internal/serialization" + strutils "github.com/yusing/goutils/strings" +) + +func TestSetupExtraProviders(t *testing.T) { + dnsproviders.InitProviders() + + cfgYAML := ` +email: test@example.com +domains: [example.com] +provider: custom +ca_dir_url: https://ca.example.com:9000/acme/acme/directory +cert_path: certs/test.crt +key_path: certs/test.key +options: {key: value} +resolvers: [8.8.8.8] +ca_certs: [ca.crt] +eab_kid: eabKid +eab_hmac: eabHmac +extra: + - cert_path: certs/extra.crt + key_path: certs/extra.key + - cert_path: certs/extra2.crt + key_path: certs/extra2.key + email: override@example.com + provider: pseudo + domains: [override.com] + ca_dir_url: https://ca2.example.com/directory + options: {opt2: val2} + resolvers: [1.1.1.1] + ca_certs: [ca2.crt] + eab_kid: eabKid2 + eab_hmac: eabHmac2 +` + + var cfg autocert.Config + err := error(serialization.UnmarshalValidateYAML([]byte(cfgYAML), &cfg)) + require.NoError(t, err) + + // Test: extra[0] inherits all fields from main except CertPath and KeyPath. + merged0 := cfg.Extra[0] + require.Equal(t, "certs/extra.crt", merged0.CertPath) + require.Equal(t, "certs/extra.key", merged0.KeyPath) + // Inherited fields from main config: + require.Equal(t, "test@example.com", merged0.Email) // inherited + require.Equal(t, "custom", merged0.Provider) // inherited + require.Equal(t, []string{"example.com"}, merged0.Domains) // inherited + require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", merged0.CADirURL) // inherited + require.Equal(t, map[string]strutils.Redacted{"key": "value"}, merged0.Options) // inherited + require.Equal(t, []string{"8.8.8.8"}, merged0.Resolvers) // inherited + require.Equal(t, []string{"ca.crt"}, merged0.CACerts) // inherited + require.Equal(t, "eabKid", merged0.EABKid) // inherited + require.Equal(t, "eabHmac", merged0.EABHmac) // inherited + require.Equal(t, cfg.HTTPClient, merged0.HTTPClient) // inherited + require.Nil(t, merged0.Extra) + + // Test: extra[1] overrides some fields, and inherits others. + merged1 := cfg.Extra[1] + require.Equal(t, "certs/extra2.crt", merged1.CertPath) + require.Equal(t, "certs/extra2.key", merged1.KeyPath) + // Overridden fields: + require.Equal(t, "override@example.com", merged1.Email) // overridden + require.Equal(t, "pseudo", merged1.Provider) // overridden + require.Equal(t, []string{"override.com"}, merged1.Domains) // overridden + require.Equal(t, "https://ca2.example.com/directory", merged1.CADirURL) // overridden + require.Equal(t, map[string]strutils.Redacted{"opt2": "val2"}, merged1.Options) // overridden + require.Equal(t, []string{"1.1.1.1"}, merged1.Resolvers) // overridden + require.Equal(t, []string{"ca2.crt"}, merged1.CACerts) // overridden + require.Equal(t, "eabKid2", merged1.EABKid) // overridden + require.Equal(t, "eabHmac2", merged1.EABHmac) // overridden + // Inherited field: + require.Equal(t, cfg.HTTPClient, merged1.HTTPClient) // inherited + require.Nil(t, merged1.Extra) +} diff --git a/internal/autocert/types/provider.go b/internal/autocert/types/provider.go index 69fdf918..64b95224 100644 --- a/internal/autocert/types/provider.go +++ b/internal/autocert/types/provider.go @@ -9,6 +9,6 @@ import ( type Provider interface { Setup() error GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) - ScheduleRenewal(task.Parent) - ObtainCert() error + ScheduleRenewalAll(task.Parent) + ObtainCertAll() error } diff --git a/internal/config/state.go b/internal/config/state.go index bf3e027b..0586dca0 100644 --- a/internal/config/state.go +++ b/internal/config/state.go @@ -272,6 +272,7 @@ func (state *state) initAutoCert() error { autocertCfg := state.AutoCert if autocertCfg == nil { autocertCfg = new(autocert.Config) + _ = autocertCfg.Validate() } user, legoCfg, err := autocertCfg.GetLegoConfig() @@ -279,12 +280,19 @@ func (state *state) initAutoCert() error { return err } - state.autocertProvider = autocert.NewProvider(autocertCfg, user, legoCfg) - if err := state.autocertProvider.Setup(); err != nil { - return fmt.Errorf("autocert error: %w", err) - } else { - state.autocertProvider.ScheduleRenewal(state.task) + p, err := autocert.NewProvider(autocertCfg, user, legoCfg) + if err != nil { + return err } + + if err := p.ObtainCertIfNotExistsAll(); err != nil { + return err + } + + p.ScheduleRenewalAll(state.task) + p.PrintCertExpiriesAll() + + state.autocertProvider = p return nil } diff --git a/internal/dnsproviders/dummy.go b/internal/dnsproviders/dummy.go index 42ddb47a..999bf3c7 100644 --- a/internal/dnsproviders/dummy.go +++ b/internal/dnsproviders/dummy.go @@ -1,7 +1,7 @@ package dnsproviders type ( - DummyConfig struct{} + DummyConfig map[string]any DummyProvider struct{} )