mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 16:58:31 +02:00
fix(autocert): ensure extra certificate registration and renewal scheduling
Extra providers were not being properly initialized during NewProvider(), causing certificate registration and renewal scheduling to be skipped. - Add ConfigExtra type with idx field for provider indexing - Add MergeExtraConfig() for inheriting main provider settings - Add setupExtraProviders() for recursive extra provider initialization - Refactor NewProvider to return error and call setupExtraProviders() - Add provider-scoped logger with "main" or "extra[N]" name - Add batch operations: ObtainCertIfNotExistsAll(), ObtainCertAll() - Add ForceExpiryAll() with completion tracking via WaitRenewalDone() - Add RenewMode (force/ifNeeded) for controlling renewal behavior - Add PrintCertExpiriesAll() for logging all provider certificate expiries Summary of staged changes: - config.go: Added ConfigExtra type, MergeExtraConfig(), recursive validation with path uniqueness checking - provider.go: Added provider indexing, scoped logger, batch cert operations, force renewal with completion tracking, RenewMode control - setup.go: New file with setupExtraProviders() for proper extra provider initialization - setup_test.go: New tests for extra provider setup - multi_cert_test.go: New tests for multi-certificate functionality - renew.go: Updated to use new provider API with error handling - state.go: Updated to handle NewProvider error return
This commit is contained in:
@@ -1,16 +1,19 @@
|
||||
package autocert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -28,6 +31,8 @@ import (
|
||||
|
||||
type (
|
||||
Provider struct {
|
||||
logger zerolog.Logger
|
||||
|
||||
cfg *Config
|
||||
user *User
|
||||
legoCfg *lego.Config
|
||||
@@ -42,12 +47,18 @@ type (
|
||||
|
||||
extraProviders []*Provider
|
||||
sniMatcher sniMatcher
|
||||
|
||||
forceRenewalCh chan struct{}
|
||||
forceRenewalDoneCh atomic.Value // chan struct{}
|
||||
|
||||
scheduleRenewalOnce sync.Once
|
||||
}
|
||||
|
||||
CertExpiries map[string]time.Time
|
||||
RenewMode uint8
|
||||
)
|
||||
|
||||
var ErrGetCertFailure = errors.New("get certificate failed")
|
||||
var ErrNoCertificate = errors.New("no certificate found")
|
||||
|
||||
const (
|
||||
// renew failed for whatever reason, 1 hour cooldown
|
||||
@@ -56,21 +67,36 @@ const (
|
||||
requestCooldownDuration = 15 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
renewModeForce = iota
|
||||
renewModeIfNeeded
|
||||
)
|
||||
|
||||
// could be nil
|
||||
var ActiveProvider atomic.Pointer[Provider]
|
||||
|
||||
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider {
|
||||
return &Provider{
|
||||
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) (*Provider, error) {
|
||||
p := &Provider{
|
||||
cfg: cfg,
|
||||
user: user,
|
||||
legoCfg: legoCfg,
|
||||
lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath),
|
||||
forceRenewalCh: make(chan struct{}, 1),
|
||||
}
|
||||
if cfg.idx == 0 {
|
||||
p.logger = log.With().Str("provider", "main").Logger()
|
||||
} else {
|
||||
p.logger = log.With().Str("provider", fmt.Sprintf("extra[%d]", cfg.idx)).Logger()
|
||||
}
|
||||
if err := p.setupExtraProviders(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if p.tlsCert == nil {
|
||||
return nil, ErrGetCertFailure
|
||||
return nil, ErrNoCertificate
|
||||
}
|
||||
if hello == nil || hello.ServerName == "" {
|
||||
return p.tlsCert, nil
|
||||
@@ -82,7 +108,14 @@ func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
func (p *Provider) GetName() string {
|
||||
return p.cfg.Provider
|
||||
if p.cfg.idx == 0 {
|
||||
return "main"
|
||||
}
|
||||
return fmt.Sprintf("extra[%d]", p.cfg.idx)
|
||||
}
|
||||
|
||||
func (p *Provider) fmtError(err error) error {
|
||||
return gperr.PrependSubject(fmt.Sprintf("provider: %s", p.GetName()), err)
|
||||
}
|
||||
|
||||
func (p *Provider) GetCertPath() string {
|
||||
@@ -129,45 +162,88 @@ func (p *Provider) ClearLastFailure() error {
|
||||
return nil
|
||||
}
|
||||
p.lastFailure = time.Time{}
|
||||
return os.Remove(p.lastFailureFile)
|
||||
err := os.Remove(p.lastFailureFile)
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) ObtainCert() error {
|
||||
if len(p.extraProviders) > 0 {
|
||||
errs := gperr.NewGroup("autocert errors")
|
||||
errs.Go(p.obtainCertSelf)
|
||||
for _, ep := range p.extraProviders {
|
||||
errs.Go(ep.obtainCertSelf)
|
||||
}
|
||||
if err := errs.Wait().Error(); err != nil {
|
||||
return err
|
||||
}
|
||||
p.rebuildSNIMatcher()
|
||||
// allProviders returns all providers including this provider and all extra providers.
|
||||
func (p *Provider) allProviders() []*Provider {
|
||||
return append([]*Provider{p}, p.extraProviders...)
|
||||
}
|
||||
|
||||
// ObtainCertIfNotExistsAll obtains a new certificate for this provider and all extra providers if they do not exist.
|
||||
func (p *Provider) ObtainCertIfNotExistsAll() error {
|
||||
errs := gperr.NewGroup("obtain cert error")
|
||||
|
||||
for _, provider := range p.allProviders() {
|
||||
errs.Go(func() error {
|
||||
if err := provider.obtainCertIfNotExists(); err != nil {
|
||||
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
p.rebuildSNIMatcher()
|
||||
return errs.Wait().Error()
|
||||
}
|
||||
|
||||
// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist.
|
||||
func (p *Provider) obtainCertIfNotExists() error {
|
||||
err := p.LoadCert()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return p.obtainCertSelf()
|
||||
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
|
||||
// check last failure
|
||||
lastFailure, err := p.GetLastFailure()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get last failure: %w", err)
|
||||
}
|
||||
if !lastFailure.IsZero() && time.Since(lastFailure) < requestCooldownDuration {
|
||||
return fmt.Errorf("still in cooldown until %s", strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
|
||||
}
|
||||
|
||||
p.logger.Info().Msg("cert not found, obtaining new cert")
|
||||
return p.ObtainCert()
|
||||
}
|
||||
|
||||
func (p *Provider) obtainCertSelf() error {
|
||||
// ObtainCertAll renews existing certificates or obtains new certificates for this provider and all extra providers.
|
||||
func (p *Provider) ObtainCertAll() error {
|
||||
errs := gperr.NewGroup("obtain cert error")
|
||||
for _, provider := range p.allProviders() {
|
||||
errs.Go(func() error {
|
||||
if err := provider.obtainCertIfNotExists(); err != nil {
|
||||
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return errs.Wait().Error()
|
||||
}
|
||||
|
||||
// ObtainCert renews existing certificate or obtains a new certificate for this provider.
|
||||
func (p *Provider) ObtainCert() error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.cfg.Provider == ProviderPseudo {
|
||||
log.Info().Msg("init client for pseudo provider")
|
||||
p.logger.Info().Msg("init client for pseudo provider")
|
||||
<-time.After(time.Second)
|
||||
log.Info().Msg("registering acme for pseudo provider")
|
||||
p.logger.Info().Msg("registering acme for pseudo provider")
|
||||
<-time.After(time.Second)
|
||||
log.Info().Msg("obtained cert for pseudo provider")
|
||||
p.logger.Info().Msg("obtained cert for pseudo provider")
|
||||
return nil
|
||||
}
|
||||
|
||||
if lastFailure, err := p.GetLastFailure(); err != nil {
|
||||
return err
|
||||
} else if time.Since(lastFailure) < requestCooldownDuration {
|
||||
return fmt.Errorf("%w: still in cooldown until %s", ErrGetCertFailure, strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
|
||||
}
|
||||
|
||||
if p.client == nil {
|
||||
if err := p.initClient(); err != nil {
|
||||
return err
|
||||
@@ -227,6 +303,7 @@ func (p *Provider) obtainCertSelf() error {
|
||||
}
|
||||
p.tlsCert = &tlsCert
|
||||
p.certExpiries = expiries
|
||||
p.rebuildSNIMatcher()
|
||||
|
||||
if err := p.ClearLastFailure(); err != nil {
|
||||
return fmt.Errorf("failed to clear last failure: %w", err)
|
||||
@@ -235,19 +312,37 @@ func (p *Provider) obtainCertSelf() error {
|
||||
}
|
||||
|
||||
func (p *Provider) LoadCert() error {
|
||||
var errs gperr.Builder
|
||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load SSL certificate: %w", err)
|
||||
errs.Addf("load SSL certificate: %w", p.fmtError(err))
|
||||
}
|
||||
|
||||
expiries, err := getCertExpiries(&cert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse SSL certificate: %w", err)
|
||||
errs.Addf("parse SSL certificate: %w", p.fmtError(err))
|
||||
}
|
||||
|
||||
p.tlsCert = &cert
|
||||
p.certExpiries = expiries
|
||||
|
||||
log.Info().Msgf("next cert renewal in %s", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
||||
return p.renewIfNeeded()
|
||||
for _, ep := range p.extraProviders {
|
||||
if err := ep.LoadCert(); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
}
|
||||
|
||||
p.rebuildSNIMatcher()
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers.
|
||||
func (p *Provider) PrintCertExpiriesAll() {
|
||||
for _, provider := range p.allProviders() {
|
||||
for domain, expiry := range provider.certExpiries {
|
||||
p.logger.Info().Str("domain", domain).Msgf("certificate expire on %s", strutils.FormatTime(expiry))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldRenewOn returns the time at which the certificate should be renewed.
|
||||
@@ -255,65 +350,129 @@ func (p *Provider) ShouldRenewOn() time.Time {
|
||||
for _, expiry := range p.certExpiries {
|
||||
return expiry.AddDate(0, -1, 0) // 1 month before
|
||||
}
|
||||
// this line should never be reached
|
||||
panic("no certificate available")
|
||||
// this line should never be reached in production, but will be useful for testing
|
||||
return time.Now().AddDate(0, 1, 0) // 1 month after
|
||||
}
|
||||
|
||||
func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||
// ForceExpiryAll triggers immediate certificate renewal for this provider and all extra providers.
|
||||
// Returns true if the renewal was triggered, false if the renewal was dropped.
|
||||
//
|
||||
// If at least one renewal is triggered, returns true.
|
||||
func (p *Provider) ForceExpiryAll() (ok bool) {
|
||||
doneCh := make(chan struct{})
|
||||
if swapped := p.forceRenewalDoneCh.CompareAndSwap(nil, doneCh); !swapped { // already in progress
|
||||
close(doneCh)
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case p.forceRenewalCh <- struct{}{}:
|
||||
ok = true
|
||||
default:
|
||||
}
|
||||
|
||||
for _, ep := range p.extraProviders {
|
||||
if ep.ForceExpiryAll() {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// WaitRenewalDone waits for the renewal to complete.
|
||||
// Returns false if the renewal was dropped.
|
||||
func (p *Provider) WaitRenewalDone(ctx context.Context) bool {
|
||||
done, ok := p.forceRenewalDoneCh.Load().(chan struct{})
|
||||
if !ok || done == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
|
||||
for _, ep := range p.extraProviders {
|
||||
if !ep.WaitRenewalDone(ctx) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ScheduleRenewalAll schedules the renewal of the certificate for this provider and all extra providers.
|
||||
func (p *Provider) ScheduleRenewalAll(parent task.Parent) {
|
||||
p.scheduleRenewalOnce.Do(func() {
|
||||
p.scheduleRenewal(parent)
|
||||
})
|
||||
for _, ep := range p.extraProviders {
|
||||
ep.scheduleRenewalOnce.Do(func() {
|
||||
ep.scheduleRenewal(parent)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var emptyForceRenewalDoneCh any = chan struct{}(nil)
|
||||
|
||||
// scheduleRenewal schedules the renewal of the certificate for this provider.
|
||||
func (p *Provider) scheduleRenewal(parent task.Parent) {
|
||||
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
renewalTime := p.ShouldRenewOn()
|
||||
timer := time.NewTimer(time.Until(renewalTime))
|
||||
defer timer.Stop()
|
||||
|
||||
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
|
||||
timer := time.NewTimer(time.Until(p.ShouldRenewOn()))
|
||||
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
|
||||
|
||||
renew := func(renewMode RenewMode) {
|
||||
defer func() {
|
||||
if done, ok := p.forceRenewalDoneCh.Swap(emptyForceRenewalDoneCh).(chan struct{}); ok && done != nil {
|
||||
close(done)
|
||||
}
|
||||
}()
|
||||
|
||||
renewed, err := p.renew(renewMode)
|
||||
if err != nil {
|
||||
gperr.LogWarn("autocert: cert renew failed", p.fmtError(err))
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: zerolog.ErrorLevel,
|
||||
Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()),
|
||||
Body: notif.MessageBody(err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
if renewed {
|
||||
p.rebuildSNIMatcher()
|
||||
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: zerolog.InfoLevel,
|
||||
Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()),
|
||||
Body: notif.ListBody(p.cfg.Domains),
|
||||
})
|
||||
|
||||
// Reset on success
|
||||
if err := p.ClearLastFailure(); err != nil {
|
||||
gperr.LogWarn("autocert: failed to clear last failure", p.fmtError(err))
|
||||
}
|
||||
timer.Reset(time.Until(p.ShouldRenewOn()))
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer timer.Stop()
|
||||
defer task.Finish(nil)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
case <-p.forceRenewalCh:
|
||||
renew(renewModeForce)
|
||||
case <-timer.C:
|
||||
// Retry after 1 hour on failure
|
||||
lastFailure, err := p.GetLastFailure()
|
||||
if err != nil {
|
||||
gperr.LogWarn("autocert: failed to get last failure", err)
|
||||
continue
|
||||
}
|
||||
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
|
||||
continue
|
||||
}
|
||||
if err := p.renewIfNeeded(); err != nil {
|
||||
gperr.LogWarn("autocert: cert renew failed", err)
|
||||
if err := p.UpdateLastFailure(); err != nil {
|
||||
gperr.LogWarn("autocert: failed to update last failure", err)
|
||||
}
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: zerolog.ErrorLevel,
|
||||
Title: "SSL certificate renewal failed",
|
||||
Body: notif.MessageBody(err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
notif.Notify(¬if.LogMessage{
|
||||
Level: zerolog.InfoLevel,
|
||||
Title: "SSL certificate renewed",
|
||||
Body: notif.ListBody(p.cfg.Domains),
|
||||
})
|
||||
// Reset on success
|
||||
if err := p.ClearLastFailure(); err != nil {
|
||||
gperr.LogWarn("autocert: failed to clear last failure", err)
|
||||
}
|
||||
renewalTime = p.ShouldRenewOn()
|
||||
timer.Reset(time.Until(renewalTime))
|
||||
renew(renewModeIfNeeded)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for _, ep := range p.extraProviders {
|
||||
ep.ScheduleRenewal(parent)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) initClient() error {
|
||||
@@ -409,21 +568,42 @@ func (p *Provider) certState() CertState {
|
||||
return CertStateValid
|
||||
}
|
||||
|
||||
func (p *Provider) renewIfNeeded() error {
|
||||
func (p *Provider) renew(mode RenewMode) (renewed bool, err error) {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
switch p.certState() {
|
||||
case CertStateExpired:
|
||||
log.Info().Msg("certs expired, renewing")
|
||||
case CertStateMismatch:
|
||||
log.Info().Msg("cert domains mismatch with config, renewing")
|
||||
default:
|
||||
return nil
|
||||
if mode != renewModeForce {
|
||||
// Retry after 1 hour on failure
|
||||
lastFailure, err := p.GetLastFailure()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get last failure: %w", err)
|
||||
}
|
||||
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
|
||||
until := lastFailure.Add(renewalCooldownDuration).Local()
|
||||
return false, fmt.Errorf("still in cooldown until %s", strutils.FormatTime(until))
|
||||
}
|
||||
}
|
||||
|
||||
return p.obtainCertSelf()
|
||||
if mode == renewModeIfNeeded {
|
||||
switch p.certState() {
|
||||
case CertStateExpired:
|
||||
log.Info().Msg("certs expired, renewing")
|
||||
case CertStateMismatch:
|
||||
log.Info().Msg("cert domains mismatch with config, renewing")
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if mode == renewModeForce {
|
||||
log.Info().Msg("force renewing cert by user request")
|
||||
}
|
||||
|
||||
if err := p.ObtainCert(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
||||
@@ -445,15 +625,16 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
||||
}
|
||||
|
||||
func lastFailureFileFor(certPath, keyPath string) string {
|
||||
if certPath == "" && keyPath == "" {
|
||||
return LastFailureFile
|
||||
}
|
||||
dir := filepath.Dir(certPath)
|
||||
sum := sha256.Sum256([]byte(certPath + "|" + keyPath))
|
||||
return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6]))
|
||||
}
|
||||
|
||||
func (p *Provider) rebuildSNIMatcher() {
|
||||
if p.cfg.idx != 0 { // only main provider has extra providers
|
||||
return
|
||||
}
|
||||
|
||||
p.sniMatcher = sniMatcher{}
|
||||
p.sniMatcher.addProvider(p)
|
||||
for _, ep := range p.extraProviders {
|
||||
|
||||
Reference in New Issue
Block a user