fix(autocert): ensure extra certificate registration and renewal scheduling

Extra providers were not being properly initialized during NewProvider(),
causing certificate registration and renewal scheduling to be skipped.

- Add ConfigExtra type with idx field for provider indexing
- Add MergeExtraConfig() for inheriting main provider settings
- Add setupExtraProviders() for recursive extra provider initialization
- Refactor NewProvider to return error and call setupExtraProviders()
- Add provider-scoped logger with "main" or "extra[N]" name
- Add batch operations: ObtainCertIfNotExistsAll(), ObtainCertAll()
- Add ForceExpiryAll() with completion tracking via WaitRenewalDone()
- Add RenewMode (force/ifNeeded) for controlling renewal behavior
- Add PrintCertExpiriesAll() for logging all provider certificate expiries

Summary of staged changes:
- config.go: Added ConfigExtra type, MergeExtraConfig(), recursive validation with path uniqueness checking
- provider.go: Added provider indexing, scoped logger, batch cert operations, force renewal with completion tracking, RenewMode control
- setup.go: New file with setupExtraProviders() for proper extra provider initialization
- setup_test.go: New tests for extra provider setup
- multi_cert_test.go: New tests for multi-certificate functionality
- renew.go: Updated to use new provider API with error handling
- state.go: Updated to handle NewProvider error return
This commit is contained in:
yusing
2026-01-04 20:30:58 +08:00
parent 11d0c61b9c
commit 2835fd5fb0
14 changed files with 1103 additions and 700 deletions

View File

@@ -9,7 +9,6 @@ import (
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/logging/memlogger"
apitypes "github.com/yusing/goutils/apitypes"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/http/websocket"
)
@@ -40,33 +39,33 @@ func Renew(c *gin.Context) {
logs, cancel := memlogger.Events()
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
// Stream logs until WebSocket connection closes (renewal runs in background)
for {
select {
case <-manager.Context().Done():
return
case l := <-logs:
if err != nil {
return
}
err = autocert.ObtainCert()
if err != nil {
gperr.LogError("failed to obtain cert", err)
_ = manager.WriteData(websocket.TextMessage, []byte(err.Error()), 10*time.Second)
} else {
log.Info().Msg("cert obtained successfully")
err = manager.WriteData(websocket.TextMessage, l, 10*time.Second)
if err != nil {
return
}
}
}
}()
for {
select {
case l := <-logs:
if err != nil {
return
}
err = manager.WriteData(websocket.TextMessage, l, 10*time.Second)
if err != nil {
return
}
case <-done:
return
}
// renewal happens in background
ok := autocert.ForceExpiryAll()
if !ok {
log.Error().Msg("cert renewal already in progress")
time.Sleep(1 * time.Second) // wait for the log above to be sent
return
}
log.Info().Msg("cert force renewal requested")
autocert.WaitRenewalDone(manager.Context())
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"fmt"
"net/http"
"os"
"regexp"
@@ -19,13 +20,14 @@ import (
strutils "github.com/yusing/goutils/strings"
)
type ConfigExtra Config
type Config struct {
Email string `json:"email,omitempty"`
Domains []string `json:"domains,omitempty"`
CertPath string `json:"cert_path,omitempty"`
KeyPath string `json:"key_path,omitempty"`
Extra []Config `json:"extra,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"`
Extra []ConfigExtra `json:"extra,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"` // shared by all extra providers
Provider string `json:"provider,omitempty"`
Options map[string]strutils.Redacted `json:"options,omitempty"`
@@ -42,15 +44,12 @@ type Config struct {
HTTPClient *http.Client `json:"-"` // for tests only
challengeProvider challenge.Provider
idx int // 0: main, 1+: extra[i]
}
var (
ErrMissingDomain = gperr.New("missing field 'domains'")
ErrMissingEmail = gperr.New("missing field 'email'")
ErrMissingProvider = gperr.New("missing field 'provider'")
ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
ErrMissingCertPath = gperr.New("missing field 'cert_path'")
ErrMissingKeyPath = gperr.New("missing field 'key_path'")
ErrMissingField = gperr.New("missing field")
ErrDuplicatedPath = gperr.New("duplicated path")
ErrInvalidDomain = gperr.New("invalid domain")
ErrUnknownProvider = gperr.New("unknown provider")
@@ -66,95 +65,22 @@ var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
// Validate implements the utils.CustomValidator interface.
func (cfg *Config) Validate() gperr.Error {
if cfg == nil {
return nil
}
seenPaths := make(map[string]int) // path -> provider idx (0 for main, 1+ for extras)
return cfg.validate(seenPaths)
}
func (cfg *ConfigExtra) Validate() gperr.Error {
return nil // done by main config's validate
}
func (cfg *ConfigExtra) AsConfig() *Config {
return (*Config)(cfg)
}
func (cfg *Config) validate(seenPaths map[string]int) gperr.Error {
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
}
b := gperr.NewBuilder("autocert errors")
if len(cfg.Extra) > 0 {
seenCertPaths := make(map[string]int, len(cfg.Extra))
seenKeyPaths := make(map[string]int, len(cfg.Extra))
for i := range cfg.Extra {
if cfg.Extra[i].CertPath == "" {
b.Add(ErrMissingCertPath.Subjectf("extra[%d].cert_path", i))
}
if cfg.Extra[i].KeyPath == "" {
b.Add(ErrMissingKeyPath.Subjectf("extra[%d].key_path", i))
}
if cfg.Extra[i].CertPath != "" {
if first, ok := seenCertPaths[cfg.Extra[i].CertPath]; ok {
b.Add(ErrDuplicatedPath.Subjectf("extra[%d].cert_path", i).Withf("first: %d", first))
} else {
seenCertPaths[cfg.Extra[i].CertPath] = i
}
}
if cfg.Extra[i].KeyPath != "" {
if first, ok := seenKeyPaths[cfg.Extra[i].KeyPath]; ok {
b.Add(ErrDuplicatedPath.Subjectf("extra[%d].key_path", i).Withf("first: %d", first))
} else {
seenKeyPaths[cfg.Extra[i].KeyPath] = i
}
}
}
}
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
b.Add(ErrMissingCADirURL)
}
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
if cfg.Provider != ProviderCustom {
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
}
// check if provider is implemented
providerConstructor, ok := Providers[cfg.Provider]
if !ok {
if cfg.Provider != ProviderCustom {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
With(gperr.DoYouMeanField(cfg.Provider, Providers)))
}
} else {
provider, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
} else {
cfg.challengeProvider = provider
}
}
}
if cfg.challengeProvider == nil {
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
}
return b.Error()
}
func (cfg *Config) dns01Options() []dns01.ChallengeOption {
return []dns01.ChallengeOption{
dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)),
}
}
func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
if err := cfg.Validate(); err != nil {
return nil, nil, err
}
if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault
}
@@ -165,6 +91,83 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
cfg.ACMEKeyPath = ACMEKeyFileDefault
}
b := gperr.NewBuilder("certificate error")
// check if cert_path is unique
if first, ok := seenPaths[cfg.CertPath]; ok {
b.Add(ErrDuplicatedPath.Subjectf("cert_path %s", cfg.CertPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first)))
} else {
seenPaths[cfg.CertPath] = cfg.idx
}
// check if key_path is unique
if first, ok := seenPaths[cfg.KeyPath]; ok {
b.Add(ErrDuplicatedPath.Subjectf("key_path %s", cfg.KeyPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first)))
} else {
seenPaths[cfg.KeyPath] = cfg.idx
}
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
b.Add(ErrMissingField.Subject("ca_dir_url"))
}
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingField.Subject("domains"))
}
if cfg.Email == "" {
b.Add(ErrMissingField.Subject("email"))
}
if cfg.Provider != ProviderCustom {
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
}
}
// check if provider is implemented
providerConstructor, ok := Providers[cfg.Provider]
if !ok {
if cfg.Provider != ProviderCustom {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
With(gperr.DoYouMeanField(cfg.Provider, Providers)))
}
} else {
provider, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
} else {
cfg.challengeProvider = provider
}
}
if cfg.challengeProvider == nil {
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
}
if len(cfg.Extra) > 0 {
for i := range cfg.Extra {
cfg.Extra[i] = MergeExtraConfig(cfg, &cfg.Extra[i])
cfg.Extra[i].AsConfig().idx = i + 1
err := cfg.Extra[i].AsConfig().validate(seenPaths)
if err != nil {
b.Add(err.Subjectf("extra[%d]", i))
}
}
}
return b.Error()
}
func (cfg *Config) dns01Options() []dns01.ChallengeOption {
return []dns01.ChallengeOption{
dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)),
}
}
func (cfg *Config) GetLegoConfig() (*User, *lego.Config, error) {
var privKey *ecdsa.PrivateKey
var err error
@@ -208,6 +211,46 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
return user, legoCfg, nil
}
func MergeExtraConfig(mainCfg *Config, extraCfg *ConfigExtra) ConfigExtra {
merged := ConfigExtra(*mainCfg)
merged.Extra = nil
merged.CertPath = extraCfg.CertPath
merged.KeyPath = extraCfg.KeyPath
// NOTE: Using same ACME key as main provider
if extraCfg.Provider != "" {
merged.Provider = extraCfg.Provider
}
if extraCfg.Email != "" {
merged.Email = extraCfg.Email
}
if len(extraCfg.Domains) > 0 {
merged.Domains = extraCfg.Domains
}
if len(extraCfg.Options) > 0 {
merged.Options = extraCfg.Options
}
if len(extraCfg.Resolvers) > 0 {
merged.Resolvers = extraCfg.Resolvers
}
if extraCfg.CADirURL != "" {
merged.CADirURL = extraCfg.CADirURL
}
if len(extraCfg.CACerts) > 0 {
merged.CACerts = extraCfg.CACerts
}
if extraCfg.EABKid != "" {
merged.EABKid = extraCfg.EABKid
}
if extraCfg.EABHmac != "" {
merged.EABHmac = extraCfg.EABHmac
}
if extraCfg.HTTPClient != nil {
merged.HTTPClient = extraCfg.HTTPClient
}
return merged
}
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
if common.IsTest {
return nil, os.ErrNotExist

View File

@@ -1,27 +1,32 @@
package autocert
package autocert_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/dnsproviders"
"github.com/yusing/godoxy/internal/serialization"
)
func TestEABConfigRequired(t *testing.T) {
dnsproviders.InitProviders()
tests := []struct {
name string
cfg *Config
cfg *autocert.Config
wantErr bool
}{
{name: "Missing EABKid", cfg: &Config{EABHmac: "1234567890"}, wantErr: true},
{name: "Missing EABHmac", cfg: &Config{EABKid: "1234567890"}, wantErr: true},
{name: "Valid EAB", cfg: &Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false},
{name: "Missing EABKid", cfg: &autocert.Config{EABHmac: "1234567890"}, wantErr: true},
{name: "Missing EABHmac", cfg: &autocert.Config{EABKid: "1234567890"}, wantErr: true},
{name: "Valid EAB", cfg: &autocert.Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
yaml := fmt.Appendf(nil, "eab_kid: %s\neab_hmac: %s", test.cfg.EABKid, test.cfg.EABHmac)
cfg := Config{}
cfg := autocert.Config{}
err := serialization.UnmarshalValidateYAML(yaml, &cfg)
if (err != nil) != test.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, test.wantErr)
@@ -29,3 +34,27 @@ func TestEABConfigRequired(t *testing.T) {
})
}
}
func TestExtraCertKeyPathsUnique(t *testing.T) {
t.Run("duplicate cert_path rejected", func(t *testing.T) {
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
Extra: []autocert.ConfigExtra{
{CertPath: "a.crt", KeyPath: "a.key"},
{CertPath: "a.crt", KeyPath: "b.key"},
},
}
require.Error(t, cfg.Validate())
})
t.Run("duplicate key_path rejected", func(t *testing.T) {
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
Extra: []autocert.ConfigExtra{
{CertPath: "a.crt", KeyPath: "a.key"},
{CertPath: "b.crt", KeyPath: "a.key"},
},
}
require.Error(t, cfg.Validate())
})
}

View File

@@ -5,5 +5,4 @@ const (
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
ACMEKeyFileDefault = certBasePath + "acme.key"
LastFailureFile = certBasePath + ".last_failure"
)

View File

@@ -1,16 +1,19 @@
package autocert
import (
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"maps"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -28,6 +31,8 @@ import (
type (
Provider struct {
logger zerolog.Logger
cfg *Config
user *User
legoCfg *lego.Config
@@ -42,12 +47,18 @@ type (
extraProviders []*Provider
sniMatcher sniMatcher
forceRenewalCh chan struct{}
forceRenewalDoneCh atomic.Value // chan struct{}
scheduleRenewalOnce sync.Once
}
CertExpiries map[string]time.Time
RenewMode uint8
)
var ErrGetCertFailure = errors.New("get certificate failed")
var ErrNoCertificate = errors.New("no certificate found")
const (
// renew failed for whatever reason, 1 hour cooldown
@@ -56,21 +67,36 @@ const (
requestCooldownDuration = 15 * time.Second
)
const (
renewModeForce = iota
renewModeIfNeeded
)
// could be nil
var ActiveProvider atomic.Pointer[Provider]
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider {
return &Provider{
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) (*Provider, error) {
p := &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath),
forceRenewalCh: make(chan struct{}, 1),
}
if cfg.idx == 0 {
p.logger = log.With().Str("provider", "main").Logger()
} else {
p.logger = log.With().Str("provider", fmt.Sprintf("extra[%d]", cfg.idx)).Logger()
}
if err := p.setupExtraProviders(); err != nil {
return nil, err
}
return p, nil
}
func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
return nil, ErrGetCertFailure
return nil, ErrNoCertificate
}
if hello == nil || hello.ServerName == "" {
return p.tlsCert, nil
@@ -82,7 +108,14 @@ func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
func (p *Provider) GetName() string {
return p.cfg.Provider
if p.cfg.idx == 0 {
return "main"
}
return fmt.Sprintf("extra[%d]", p.cfg.idx)
}
func (p *Provider) fmtError(err error) error {
return gperr.PrependSubject(fmt.Sprintf("provider: %s", p.GetName()), err)
}
func (p *Provider) GetCertPath() string {
@@ -129,45 +162,88 @@ func (p *Provider) ClearLastFailure() error {
return nil
}
p.lastFailure = time.Time{}
return os.Remove(p.lastFailureFile)
err := os.Remove(p.lastFailureFile)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
}
return nil
}
func (p *Provider) ObtainCert() error {
if len(p.extraProviders) > 0 {
errs := gperr.NewGroup("autocert errors")
errs.Go(p.obtainCertSelf)
for _, ep := range p.extraProviders {
errs.Go(ep.obtainCertSelf)
}
if err := errs.Wait().Error(); err != nil {
return err
}
p.rebuildSNIMatcher()
// allProviders returns all providers including this provider and all extra providers.
func (p *Provider) allProviders() []*Provider {
return append([]*Provider{p}, p.extraProviders...)
}
// ObtainCertIfNotExistsAll obtains a new certificate for this provider and all extra providers if they do not exist.
func (p *Provider) ObtainCertIfNotExistsAll() error {
errs := gperr.NewGroup("obtain cert error")
for _, provider := range p.allProviders() {
errs.Go(func() error {
if err := provider.obtainCertIfNotExists(); err != nil {
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
}
return nil
})
}
p.rebuildSNIMatcher()
return errs.Wait().Error()
}
// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist.
func (p *Provider) obtainCertIfNotExists() error {
err := p.LoadCert()
if err == nil {
return nil
}
return p.obtainCertSelf()
if !errors.Is(err, fs.ErrNotExist) {
return err
}
// check last failure
lastFailure, err := p.GetLastFailure()
if err != nil {
return fmt.Errorf("failed to get last failure: %w", err)
}
if !lastFailure.IsZero() && time.Since(lastFailure) < requestCooldownDuration {
return fmt.Errorf("still in cooldown until %s", strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
}
p.logger.Info().Msg("cert not found, obtaining new cert")
return p.ObtainCert()
}
func (p *Provider) obtainCertSelf() error {
// ObtainCertAll renews existing certificates or obtains new certificates for this provider and all extra providers.
func (p *Provider) ObtainCertAll() error {
errs := gperr.NewGroup("obtain cert error")
for _, provider := range p.allProviders() {
errs.Go(func() error {
if err := provider.obtainCertIfNotExists(); err != nil {
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
}
return nil
})
}
return errs.Wait().Error()
}
// ObtainCert renews existing certificate or obtains a new certificate for this provider.
func (p *Provider) ObtainCert() error {
if p.cfg.Provider == ProviderLocal {
return nil
}
if p.cfg.Provider == ProviderPseudo {
log.Info().Msg("init client for pseudo provider")
p.logger.Info().Msg("init client for pseudo provider")
<-time.After(time.Second)
log.Info().Msg("registering acme for pseudo provider")
p.logger.Info().Msg("registering acme for pseudo provider")
<-time.After(time.Second)
log.Info().Msg("obtained cert for pseudo provider")
p.logger.Info().Msg("obtained cert for pseudo provider")
return nil
}
if lastFailure, err := p.GetLastFailure(); err != nil {
return err
} else if time.Since(lastFailure) < requestCooldownDuration {
return fmt.Errorf("%w: still in cooldown until %s", ErrGetCertFailure, strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
}
if p.client == nil {
if err := p.initClient(); err != nil {
return err
@@ -227,6 +303,7 @@ func (p *Provider) obtainCertSelf() error {
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
p.rebuildSNIMatcher()
if err := p.ClearLastFailure(); err != nil {
return fmt.Errorf("failed to clear last failure: %w", err)
@@ -235,19 +312,37 @@ func (p *Provider) obtainCertSelf() error {
}
func (p *Provider) LoadCert() error {
var errs gperr.Builder
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
if err != nil {
return fmt.Errorf("load SSL certificate: %w", err)
errs.Addf("load SSL certificate: %w", p.fmtError(err))
}
expiries, err := getCertExpiries(&cert)
if err != nil {
return fmt.Errorf("parse SSL certificate: %w", err)
errs.Addf("parse SSL certificate: %w", p.fmtError(err))
}
p.tlsCert = &cert
p.certExpiries = expiries
log.Info().Msgf("next cert renewal in %s", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
for _, ep := range p.extraProviders {
if err := ep.LoadCert(); err != nil {
errs.Add(err)
}
}
p.rebuildSNIMatcher()
return errs.Error()
}
// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers.
func (p *Provider) PrintCertExpiriesAll() {
for _, provider := range p.allProviders() {
for domain, expiry := range provider.certExpiries {
p.logger.Info().Str("domain", domain).Msgf("certificate expire on %s", strutils.FormatTime(expiry))
}
}
}
// ShouldRenewOn returns the time at which the certificate should be renewed.
@@ -255,65 +350,129 @@ func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0) // 1 month before
}
// this line should never be reached
panic("no certificate available")
// this line should never be reached in production, but will be useful for testing
return time.Now().AddDate(0, 1, 0) // 1 month after
}
func (p *Provider) ScheduleRenewal(parent task.Parent) {
// ForceExpiryAll triggers immediate certificate renewal for this provider and all extra providers.
// Returns true if the renewal was triggered, false if the renewal was dropped.
//
// If at least one renewal is triggered, returns true.
func (p *Provider) ForceExpiryAll() (ok bool) {
doneCh := make(chan struct{})
if swapped := p.forceRenewalDoneCh.CompareAndSwap(nil, doneCh); !swapped { // already in progress
close(doneCh)
return false
}
select {
case p.forceRenewalCh <- struct{}{}:
ok = true
default:
}
for _, ep := range p.extraProviders {
if ep.ForceExpiryAll() {
ok = true
}
}
return ok
}
// WaitRenewalDone waits for the renewal to complete.
// Returns false if the renewal was dropped.
func (p *Provider) WaitRenewalDone(ctx context.Context) bool {
done, ok := p.forceRenewalDoneCh.Load().(chan struct{})
if !ok || done == nil {
return false
}
select {
case <-done:
case <-ctx.Done():
return false
}
for _, ep := range p.extraProviders {
if !ep.WaitRenewalDone(ctx) {
return false
}
}
return true
}
// ScheduleRenewalAll schedules the renewal of the certificate for this provider and all extra providers.
func (p *Provider) ScheduleRenewalAll(parent task.Parent) {
p.scheduleRenewalOnce.Do(func() {
p.scheduleRenewal(parent)
})
for _, ep := range p.extraProviders {
ep.scheduleRenewalOnce.Do(func() {
ep.scheduleRenewal(parent)
})
}
}
var emptyForceRenewalDoneCh any = chan struct{}(nil)
// scheduleRenewal schedules the renewal of the certificate for this provider.
func (p *Provider) scheduleRenewal(parent task.Parent) {
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
return
}
go func() {
renewalTime := p.ShouldRenewOn()
timer := time.NewTimer(time.Until(renewalTime))
defer timer.Stop()
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
timer := time.NewTimer(time.Until(p.ShouldRenewOn()))
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
renew := func(renewMode RenewMode) {
defer func() {
if done, ok := p.forceRenewalDoneCh.Swap(emptyForceRenewalDoneCh).(chan struct{}); ok && done != nil {
close(done)
}
}()
renewed, err := p.renew(renewMode)
if err != nil {
gperr.LogWarn("autocert: cert renew failed", p.fmtError(err))
notif.Notify(&notif.LogMessage{
Level: zerolog.ErrorLevel,
Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()),
Body: notif.MessageBody(err.Error()),
})
return
}
if renewed {
p.rebuildSNIMatcher()
notif.Notify(&notif.LogMessage{
Level: zerolog.InfoLevel,
Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()),
Body: notif.ListBody(p.cfg.Domains),
})
// Reset on success
if err := p.ClearLastFailure(); err != nil {
gperr.LogWarn("autocert: failed to clear last failure", p.fmtError(err))
}
timer.Reset(time.Until(p.ShouldRenewOn()))
}
}
go func() {
defer timer.Stop()
defer task.Finish(nil)
for {
select {
case <-task.Context().Done():
return
case <-p.forceRenewalCh:
renew(renewModeForce)
case <-timer.C:
// Retry after 1 hour on failure
lastFailure, err := p.GetLastFailure()
if err != nil {
gperr.LogWarn("autocert: failed to get last failure", err)
continue
}
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
continue
}
if err := p.renewIfNeeded(); err != nil {
gperr.LogWarn("autocert: cert renew failed", err)
if err := p.UpdateLastFailure(); err != nil {
gperr.LogWarn("autocert: failed to update last failure", err)
}
notif.Notify(&notif.LogMessage{
Level: zerolog.ErrorLevel,
Title: "SSL certificate renewal failed",
Body: notif.MessageBody(err.Error()),
})
continue
}
notif.Notify(&notif.LogMessage{
Level: zerolog.InfoLevel,
Title: "SSL certificate renewed",
Body: notif.ListBody(p.cfg.Domains),
})
// Reset on success
if err := p.ClearLastFailure(); err != nil {
gperr.LogWarn("autocert: failed to clear last failure", err)
}
renewalTime = p.ShouldRenewOn()
timer.Reset(time.Until(renewalTime))
renew(renewModeIfNeeded)
}
}
}()
for _, ep := range p.extraProviders {
ep.ScheduleRenewal(parent)
}
}
func (p *Provider) initClient() error {
@@ -409,21 +568,42 @@ func (p *Provider) certState() CertState {
return CertStateValid
}
func (p *Provider) renewIfNeeded() error {
func (p *Provider) renew(mode RenewMode) (renewed bool, err error) {
if p.cfg.Provider == ProviderLocal {
return nil
return false, nil
}
switch p.certState() {
case CertStateExpired:
log.Info().Msg("certs expired, renewing")
case CertStateMismatch:
log.Info().Msg("cert domains mismatch with config, renewing")
default:
return nil
if mode != renewModeForce {
// Retry after 1 hour on failure
lastFailure, err := p.GetLastFailure()
if err != nil {
return false, fmt.Errorf("failed to get last failure: %w", err)
}
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
until := lastFailure.Add(renewalCooldownDuration).Local()
return false, fmt.Errorf("still in cooldown until %s", strutils.FormatTime(until))
}
}
return p.obtainCertSelf()
if mode == renewModeIfNeeded {
switch p.certState() {
case CertStateExpired:
log.Info().Msg("certs expired, renewing")
case CertStateMismatch:
log.Info().Msg("cert domains mismatch with config, renewing")
default:
return false, nil
}
}
if mode == renewModeForce {
log.Info().Msg("force renewing cert by user request")
}
if err := p.ObtainCert(); err != nil {
return false, err
}
return true, nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
@@ -445,15 +625,16 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
}
func lastFailureFileFor(certPath, keyPath string) string {
if certPath == "" && keyPath == "" {
return LastFailureFile
}
dir := filepath.Dir(certPath)
sum := sha256.Sum256([]byte(certPath + "|" + keyPath))
return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6]))
}
func (p *Provider) rebuildSNIMatcher() {
if p.cfg.idx != 0 { // only main provider has extra providers
return
}
p.sniMatcher = sniMatcher{}
p.sniMatcher.addProvider(p)
for _, ep := range p.extraProviders {

View File

@@ -10,12 +10,15 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"sort"
"strings"
"sync"
"testing"
"time"
@@ -24,6 +27,368 @@ import (
"github.com/yusing/godoxy/internal/dnsproviders"
)
// TestACMEServer implements a minimal ACME server for testing with request tracking.
type TestACMEServer struct {
server *httptest.Server
caCert *x509.Certificate
caKey *rsa.PrivateKey
clientCSRs map[string]*x509.CertificateRequest
orderDomains map[string][]string
authzDomains map[string]string
orderSeq int
certRequestCount map[string]int
renewalRequestCount map[string]int
mu sync.Mutex
}
func newTestACMEServer(t *testing.T) *TestACMEServer {
t.Helper()
// Generate CA certificate and key
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caCert, err := x509.ParseCertificate(caCertDER)
require.NoError(t, err)
acme := &TestACMEServer{
caCert: caCert,
caKey: caKey,
clientCSRs: make(map[string]*x509.CertificateRequest),
orderDomains: make(map[string][]string),
authzDomains: make(map[string]string),
orderSeq: 0,
certRequestCount: make(map[string]int),
renewalRequestCount: make(map[string]int),
}
mux := http.NewServeMux()
acme.setupRoutes(mux)
acme.server = httptest.NewUnstartedServer(mux)
acme.server.TLS = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{caCert.Raw},
PrivateKey: caKey,
},
},
MinVersion: tls.VersionTLS12,
}
acme.server.StartTLS()
return acme
}
func (s *TestACMEServer) Close() {
s.server.Close()
}
func (s *TestACMEServer) URL() string {
return s.server.URL
}
func (s *TestACMEServer) httpClient() *http.Client {
certPool := x509.NewCertPool()
certPool.AddCert(s.caCert)
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: certPool,
MinVersion: tls.VersionTLS12,
},
},
}
}
func (s *TestACMEServer) setupRoutes(mux *http.ServeMux) {
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
mux.HandleFunc("/acme/chall/", s.handleChallenge)
mux.HandleFunc("/acme/order/", s.handleOrder)
mux.HandleFunc("/acme/cert/", s.handleCertificate)
}
func (s *TestACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
directory := map[string]any{
"newNonce": s.server.URL + "/acme/new-nonce",
"newAccount": s.server.URL + "/acme/new-account",
"newOrder": s.server.URL + "/acme/new-order",
"keyChange": s.server.URL + "/acme/key-change",
"meta": map[string]any{
"termsOfService": s.server.URL + "/terms",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(directory)
}
func (s *TestACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "test-nonce-12345")
w.WriteHeader(http.StatusOK)
}
func (s *TestACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
account := map[string]any{
"status": "valid",
"contact": []string{"mailto:test@example.com"},
"orders": s.server.URL + "/acme/orders",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/account/1")
w.Header().Set("Replay-Nonce", "test-nonce-67890")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(account)
}
func (s *TestACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
var jws struct {
Payload string `json:"payload"`
}
json.Unmarshal(body, &jws)
payloadBytes, _ := base64.RawURLEncoding.DecodeString(jws.Payload)
var orderReq struct {
Identifiers []map[string]string `json:"identifiers"`
}
json.Unmarshal(payloadBytes, &orderReq)
domains := []string{}
for _, id := range orderReq.Identifiers {
domains = append(domains, id["value"])
}
sort.Strings(domains)
domainKey := strings.Join(domains, ",")
s.mu.Lock()
s.orderSeq++
orderID := fmt.Sprintf("test-order-%d", s.orderSeq)
authzID := fmt.Sprintf("test-authz-%d", s.orderSeq)
s.orderDomains[orderID] = domains
if len(domains) > 0 {
s.authzDomains[authzID] = domains[0]
}
s.certRequestCount[domainKey]++
s.mu.Unlock()
order := map[string]any{
"status": "ready",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": orderReq.Identifiers,
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
"finalize": s.server.URL + "/acme/order/" + orderID + "/finalize",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/order/"+orderID)
w.Header().Set("Replay-Nonce", "test-nonce-order")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(order)
}
func (s *TestACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authzID := strings.TrimPrefix(r.URL.Path, "/acme/authz/")
domain := s.authzDomains[authzID]
if domain == "" {
domain = "test.example.com"
}
authz := map[string]any{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifier": map[string]string{"type": "dns", "value": domain},
"challenges": []map[string]any{
{
"type": "dns-01",
"status": "valid",
"url": s.server.URL + "/acme/chall/test-chall-789",
"token": "test-token-abc123",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-authz")
json.NewEncoder(w).Encode(authz)
}
func (s *TestACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
challenge := map[string]any{
"type": "dns-01",
"status": "valid",
"url": r.URL.String(),
"token": "test-token-abc123",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-chall")
json.NewEncoder(w).Encode(challenge)
}
func (s *TestACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/finalize") {
s.handleFinalize(w, r)
return
}
orderID := strings.TrimPrefix(r.URL.Path, "/acme/order/")
domains := s.orderDomains[orderID]
if len(domains) == 0 {
domains = []string{"test.example.com"}
}
certURL := s.server.URL + "/acme/cert/" + orderID
order := map[string]any{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": func() []map[string]string {
out := make([]map[string]string, 0, len(domains))
for _, d := range domains {
out = append(out, map[string]string{"type": "dns", "value": d})
}
return out
}(),
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
json.NewEncoder(w).Encode(order)
}
func (s *TestACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request", http.StatusBadRequest)
return
}
csr, err := s.extractCSRFromJWS(body)
if err != nil {
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
return
}
orderID := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/acme/order/"), "/finalize")
s.mu.Lock()
s.clientCSRs[orderID] = csr
// Detect renewal: if we already have a certificate for these domains, it's a renewal
domains := csr.DNSNames
sort.Strings(domains)
domainKey := strings.Join(domains, ",")
if s.certRequestCount[domainKey] > 1 {
s.renewalRequestCount[domainKey]++
}
s.mu.Unlock()
certURL := s.server.URL + "/acme/cert/" + orderID
order := map[string]any{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": func() []map[string]string {
out := make([]map[string]string, 0, len(domains))
for _, d := range domains {
out = append(out, map[string]string{"type": "dns", "value": d})
}
return out
}(),
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
json.NewEncoder(w).Encode(order)
}
func (s *TestACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
var jws struct {
Payload string `json:"payload"`
}
if err := json.Unmarshal(jwsData, &jws); err != nil {
return nil, err
}
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, err
}
var finalizeReq struct {
CSR string `json:"csr"`
}
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
return nil, err
}
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
if err != nil {
return nil, err
}
return x509.ParseCertificateRequest(csrBytes)
}
func (s *TestACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
csr, exists := s.clientCSRs[orderID]
if !exists {
http.Error(w, "No CSR found for order", http.StatusBadRequest)
return
}
template := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Cert"},
Country: []string{"US"},
},
DNSNames: csr.DNSNames,
NotBefore: time.Now(),
NotAfter: time.Now().Add(90 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
w.Header().Set("Content-Type", "application/pem-certificate-chain")
w.Header().Set("Replay-Nonce", "test-nonce-cert")
w.Write(append(certPEM, caPEM...))
}
func TestMain(m *testing.M) {
dnsproviders.InitProviders()
m.Run()
@@ -41,7 +406,7 @@ func TestCustomProvider(t *testing.T) {
ACMEKeyPath: "certs/custom-acme.key",
}
err := cfg.Validate()
err := error(cfg.Validate())
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
@@ -62,7 +427,8 @@ func TestCustomProvider(t *testing.T) {
err := cfg.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
require.Contains(t, err.Error(), "missing field")
require.Contains(t, err.Error(), "ca_dir_url")
})
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
@@ -76,7 +442,7 @@ func TestCustomProvider(t *testing.T) {
ACMEKeyPath: "certs/internal-acme.key",
}
err := cfg.Validate()
err := error(cfg.Validate())
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
@@ -86,9 +452,10 @@ func TestCustomProvider(t *testing.T) {
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
require.Equal(t, "admin@internal.com", user.Email)
provider := autocert.NewProvider(cfg, user, legoCfg)
provider, err := autocert.NewProvider(cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
require.Equal(t, autocert.ProviderCustom, provider.GetName())
require.Equal(t, "main", provider.GetName())
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
})
@@ -119,7 +486,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
require.NotNil(t, user)
require.NotNil(t, legoCfg)
provider := autocert.NewProvider(cfg, user, legoCfg)
provider, err := autocert.NewProvider(cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
// Test obtaining certificate
@@ -161,7 +529,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
require.NotNil(t, user)
require.NotNil(t, legoCfg)
provider := autocert.NewProvider(cfg, user, legoCfg)
provider, err := autocert.NewProvider(cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
err = provider.ObtainCert()
@@ -178,330 +547,3 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
require.True(t, time.Now().After(x509Cert.NotBefore))
})
}
// testACMEServer implements a minimal ACME server for testing.
type testACMEServer struct {
server *httptest.Server
caCert *x509.Certificate
caKey *rsa.PrivateKey
clientCSRs map[string]*x509.CertificateRequest
orderID string
}
func newTestACMEServer(t *testing.T) *testACMEServer {
t.Helper()
// Generate CA certificate and key
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caCert, err := x509.ParseCertificate(caCertDER)
require.NoError(t, err)
acme := &testACMEServer{
caCert: caCert,
caKey: caKey,
clientCSRs: make(map[string]*x509.CertificateRequest),
orderID: "test-order-123",
}
mux := http.NewServeMux()
acme.setupRoutes(mux)
acme.server = httptest.NewUnstartedServer(mux)
acme.server.TLS = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{caCert.Raw},
PrivateKey: caKey,
},
},
MinVersion: tls.VersionTLS12,
}
acme.server.StartTLS()
return acme
}
func (s *testACMEServer) Close() {
s.server.Close()
}
func (s *testACMEServer) URL() string {
return s.server.URL
}
func (s *testACMEServer) httpClient() *http.Client {
certPool := x509.NewCertPool()
certPool.AddCert(s.caCert)
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: certPool,
MinVersion: tls.VersionTLS12,
},
},
}
}
func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
// ACME directory endpoint
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
// ACME endpoints
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
mux.HandleFunc("/acme/chall/", s.handleChallenge)
mux.HandleFunc("/acme/order/", s.handleOrder)
mux.HandleFunc("/acme/cert/", s.handleCertificate)
}
func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
directory := map[string]interface{}{
"newNonce": s.server.URL + "/acme/new-nonce",
"newAccount": s.server.URL + "/acme/new-account",
"newOrder": s.server.URL + "/acme/new-order",
"keyChange": s.server.URL + "/acme/key-change",
"meta": map[string]interface{}{
"termsOfService": s.server.URL + "/terms",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(directory)
}
func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "test-nonce-12345")
w.WriteHeader(http.StatusOK)
}
func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
account := map[string]interface{}{
"status": "valid",
"contact": []string{"mailto:test@example.com"},
"orders": s.server.URL + "/acme/orders",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/account/1")
w.Header().Set("Replay-Nonce", "test-nonce-67890")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(account)
}
func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
authzID := "test-authz-456"
order := map[string]interface{}{
"status": "ready", // Skip pending state for simplicity
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
"finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
w.Header().Set("Replay-Nonce", "test-nonce-order")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authz := map[string]interface{}{
"status": "valid", // Skip challenge validation for simplicity
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifier": map[string]string{"type": "dns", "value": "test.example.com"},
"challenges": []map[string]interface{}{
{
"type": "dns-01",
"status": "valid",
"url": s.server.URL + "/acme/chall/test-chall-789",
"token": "test-token-abc123",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-authz")
json.NewEncoder(w).Encode(authz)
}
func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
challenge := map[string]interface{}{
"type": "dns-01",
"status": "valid",
"url": r.URL.String(),
"token": "test-token-abc123",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-chall")
json.NewEncoder(w).Encode(challenge)
}
func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/finalize") {
s.handleFinalize(w, r)
return
}
certURL := s.server.URL + "/acme/cert/" + s.orderID
order := map[string]interface{}{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
// Read the JWS payload
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request", http.StatusBadRequest)
return
}
// Extract CSR from JWS payload
csr, err := s.extractCSRFromJWS(body)
if err != nil {
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
return
}
// Store the CSR for certificate generation
s.clientCSRs[s.orderID] = csr
certURL := s.server.URL + "/acme/cert/" + s.orderID
order := map[string]interface{}{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
// Parse the JWS structure
var jws struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}
if err := json.Unmarshal(jwsData, &jws); err != nil {
return nil, err
}
// Decode the payload
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, err
}
// Parse the finalize request
var finalizeReq struct {
CSR string `json:"csr"`
}
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
return nil, err
}
// Decode the CSR
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
if err != nil {
return nil, err
}
// Parse the CSR
csr, err := x509.ParseCertificateRequest(csrBytes)
if err != nil {
return nil, err
}
return csr, nil
}
func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
// Extract order ID from URL
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
// Get the CSR for this order
csr, exists := s.clientCSRs[orderID]
if !exists {
http.Error(w, "No CSR found for order", http.StatusBadRequest)
return
}
// Create certificate using the public key from the client's CSR
template := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Cert"},
Country: []string{"US"},
},
DNSNames: csr.DNSNames,
NotBefore: time.Now(),
NotAfter: time.Now().Add(90 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
// Use the public key from the CSR and sign with CA key
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Return certificate chain
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
w.Header().Set("Content-Type", "application/pem-certificate-chain")
w.Header().Set("Replay-Nonce", "test-nonce-cert")
w.Write(append(certPEM, caPEM...))
}

View File

@@ -1,32 +0,0 @@
package provider_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
)
func TestExtraCertKeyPathsUnique(t *testing.T) {
t.Run("duplicate cert_path rejected", func(t *testing.T) {
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
Extra: []autocert.Config{
{CertPath: "a.crt", KeyPath: "a.key"},
{CertPath: "a.crt", KeyPath: "b.key"},
},
}
require.Error(t, cfg.Validate())
})
t.Run("duplicate key_path rejected", func(t *testing.T) {
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
Extra: []autocert.Config{
{CertPath: "a.crt", KeyPath: "a.key"},
{CertPath: "b.crt", KeyPath: "a.key"},
},
}
require.Error(t, cfg.Validate())
})
}

View File

@@ -0,0 +1,90 @@
//nolint:errchkjson,errcheck
package provider_test
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/serialization"
"github.com/yusing/goutils/task"
)
func buildMultiCertYAML(serverURL string) []byte {
return fmt.Appendf(nil, `
email: main@example.com
domains: [main.example.com]
provider: custom
ca_dir_url: %s/acme/acme/directory
cert_path: certs/main.crt
key_path: certs/main.key
extra:
- email: extra1@example.com
domains: [extra1.example.com]
cert_path: certs/extra1.crt
key_path: certs/extra1.key
- email: extra2@example.com
domains: [extra2.example.com]
cert_path: certs/extra2.crt
key_path: certs/extra2.key
`, serverURL)
}
func TestMultipleCertificatesLifecycle(t *testing.T) {
acmeServer := newTestACMEServer(t)
defer acmeServer.Close()
yamlConfig := buildMultiCertYAML(acmeServer.URL())
var cfg autocert.Config
cfg.HTTPClient = acmeServer.httpClient()
/* unmarshal yaml config with multiple certs */
err := error(serialization.UnmarshalValidateYAML(yamlConfig, &cfg))
require.NoError(t, err)
require.Equal(t, []string{"main.example.com"}, cfg.Domains)
require.Len(t, cfg.Extra, 2)
require.Equal(t, []string{"extra1.example.com"}, cfg.Extra[0].Domains)
require.Equal(t, []string{"extra2.example.com"}, cfg.Extra[1].Domains)
var provider *autocert.Provider
/* initialize autocert with multi-cert config */
user, legoCfg, gerr := cfg.GetLegoConfig()
require.NoError(t, gerr)
provider, err = autocert.NewProvider(&cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
// Start renewal scheduler
root := task.RootTask("test", false)
defer root.Finish(nil)
provider.ScheduleRenewalAll(root)
require.Equal(t, "custom", cfg.Provider)
require.Equal(t, "custom", cfg.Extra[0].Provider)
require.Equal(t, "custom", cfg.Extra[1].Provider)
/* track cert requests for all configs */
os.MkdirAll("certs", 0755)
defer os.RemoveAll("certs")
err = provider.ObtainCertIfNotExistsAll()
require.NoError(t, err)
require.Equal(t, 1, acmeServer.certRequestCount["main.example.com"])
require.Equal(t, 1, acmeServer.certRequestCount["extra1.example.com"])
require.Equal(t, 1, acmeServer.certRequestCount["extra2.example.com"])
/* track renewal scheduling and requests */
// force renewal for all providers and wait for completion
ok := provider.ForceExpiryAll()
require.True(t, ok)
provider.WaitRenewalDone(t.Context())
require.Equal(t, 1, acmeServer.renewalRequestCount["main.example.com"])
require.Equal(t, 1, acmeServer.renewalRequestCount["extra1.example.com"])
require.Equal(t, 1, acmeServer.renewalRequestCount["extra2.example.com"])
}

View File

@@ -71,15 +71,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"})
require.NoError(t, err)
@@ -100,15 +103,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
require.NoError(t, err)
@@ -129,15 +135,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"})
require.NoError(t, err)
@@ -159,8 +168,11 @@ func TestGetCertBySNI(t *testing.T) {
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(nil)
require.NoError(t, err)
@@ -182,8 +194,11 @@ func TestGetCertBySNI(t *testing.T) {
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""})
require.NoError(t, err)
@@ -204,15 +219,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"})
require.NoError(t, err)
@@ -233,15 +251,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "})
require.NoError(t, err)
@@ -262,15 +283,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert1, KeyPath: extraKey1},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"})
require.NoError(t, err)
@@ -292,8 +316,11 @@ func TestGetCertBySNI(t *testing.T) {
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
require.NoError(t, err)
@@ -317,7 +344,7 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert1, KeyPath: extraKey1},
{CertPath: extraCert2, KeyPath: extraKey2},
},
@@ -325,8 +352,11 @@ func TestGetCertBySNI(t *testing.T) {
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"})
require.NoError(t, err)
@@ -352,15 +382,18 @@ func TestGetCertBySNI(t *testing.T) {
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.Config{
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, p.Setup())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
require.NoError(t, err)

View File

@@ -1,101 +1,30 @@
package autocert
import (
"errors"
"fmt"
"os"
"github.com/rs/zerolog/log"
gperr "github.com/yusing/goutils/errs"
strutils "github.com/yusing/goutils/strings"
)
func (p *Provider) Setup() (err error) {
if err = p.LoadCert(); err != nil {
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
log.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
}
if err = p.setupExtraProviders(); err != nil {
return err
}
for _, expiry := range p.GetExpiries() {
log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
break
}
return nil
}
func (p *Provider) setupExtraProviders() error {
p.extraProviders = nil
func (p *Provider) setupExtraProviders() gperr.Error {
p.sniMatcher = sniMatcher{}
if len(p.cfg.Extra) == 0 {
p.rebuildSNIMatcher()
return nil
}
for i := range p.cfg.Extra {
merged := mergeExtraConfig(p.cfg, &p.cfg.Extra[i])
user, legoCfg, err := merged.GetLegoConfig()
p.extraProviders = make([]*Provider, 0, len(p.cfg.Extra))
errs := gperr.NewBuilder("setup extra providers error")
for _, extra := range p.cfg.Extra {
user, legoCfg, err := extra.AsConfig().GetLegoConfig()
if err != nil {
return err.Subjectf("extra[%d]", i)
errs.Add(p.fmtError(err))
continue
}
ep := NewProvider(&merged, user, legoCfg)
if err := ep.Setup(); err != nil {
return gperr.PrependSubject(fmt.Sprintf("extra[%d]", i), err)
ep, err := NewProvider(extra.AsConfig(), user, legoCfg)
if err != nil {
errs.Add(p.fmtError(err))
continue
}
p.extraProviders = append(p.extraProviders, ep)
}
p.rebuildSNIMatcher()
return nil
}
func mergeExtraConfig(mainCfg *Config, extraCfg *Config) Config {
merged := *mainCfg
merged.Extra = nil
merged.CertPath = extraCfg.CertPath
merged.KeyPath = extraCfg.KeyPath
if merged.Email == "" {
merged.Email = mainCfg.Email
}
if len(extraCfg.Domains) > 0 {
merged.Domains = extraCfg.Domains
}
if extraCfg.ACMEKeyPath != "" {
merged.ACMEKeyPath = extraCfg.ACMEKeyPath
}
if extraCfg.Provider != "" {
merged.Provider = extraCfg.Provider
}
if len(extraCfg.Options) > 0 {
merged.Options = extraCfg.Options
}
if len(extraCfg.Resolvers) > 0 {
merged.Resolvers = extraCfg.Resolvers
}
if extraCfg.CADirURL != "" {
merged.CADirURL = extraCfg.CADirURL
}
if len(extraCfg.CACerts) > 0 {
merged.CACerts = extraCfg.CACerts
}
if extraCfg.EABKid != "" {
merged.EABKid = extraCfg.EABKid
}
if extraCfg.EABHmac != "" {
merged.EABHmac = extraCfg.EABHmac
}
if extraCfg.HTTPClient != nil {
merged.HTTPClient = extraCfg.HTTPClient
}
return merged
return errs.Error()
}

View File

@@ -0,0 +1,82 @@
package autocert_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/dnsproviders"
"github.com/yusing/godoxy/internal/serialization"
strutils "github.com/yusing/goutils/strings"
)
func TestSetupExtraProviders(t *testing.T) {
dnsproviders.InitProviders()
cfgYAML := `
email: test@example.com
domains: [example.com]
provider: custom
ca_dir_url: https://ca.example.com:9000/acme/acme/directory
cert_path: certs/test.crt
key_path: certs/test.key
options: {key: value}
resolvers: [8.8.8.8]
ca_certs: [ca.crt]
eab_kid: eabKid
eab_hmac: eabHmac
extra:
- cert_path: certs/extra.crt
key_path: certs/extra.key
- cert_path: certs/extra2.crt
key_path: certs/extra2.key
email: override@example.com
provider: pseudo
domains: [override.com]
ca_dir_url: https://ca2.example.com/directory
options: {opt2: val2}
resolvers: [1.1.1.1]
ca_certs: [ca2.crt]
eab_kid: eabKid2
eab_hmac: eabHmac2
`
var cfg autocert.Config
err := error(serialization.UnmarshalValidateYAML([]byte(cfgYAML), &cfg))
require.NoError(t, err)
// Test: extra[0] inherits all fields from main except CertPath and KeyPath.
merged0 := cfg.Extra[0]
require.Equal(t, "certs/extra.crt", merged0.CertPath)
require.Equal(t, "certs/extra.key", merged0.KeyPath)
// Inherited fields from main config:
require.Equal(t, "test@example.com", merged0.Email) // inherited
require.Equal(t, "custom", merged0.Provider) // inherited
require.Equal(t, []string{"example.com"}, merged0.Domains) // inherited
require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", merged0.CADirURL) // inherited
require.Equal(t, map[string]strutils.Redacted{"key": "value"}, merged0.Options) // inherited
require.Equal(t, []string{"8.8.8.8"}, merged0.Resolvers) // inherited
require.Equal(t, []string{"ca.crt"}, merged0.CACerts) // inherited
require.Equal(t, "eabKid", merged0.EABKid) // inherited
require.Equal(t, "eabHmac", merged0.EABHmac) // inherited
require.Equal(t, cfg.HTTPClient, merged0.HTTPClient) // inherited
require.Nil(t, merged0.Extra)
// Test: extra[1] overrides some fields, and inherits others.
merged1 := cfg.Extra[1]
require.Equal(t, "certs/extra2.crt", merged1.CertPath)
require.Equal(t, "certs/extra2.key", merged1.KeyPath)
// Overridden fields:
require.Equal(t, "override@example.com", merged1.Email) // overridden
require.Equal(t, "pseudo", merged1.Provider) // overridden
require.Equal(t, []string{"override.com"}, merged1.Domains) // overridden
require.Equal(t, "https://ca2.example.com/directory", merged1.CADirURL) // overridden
require.Equal(t, map[string]strutils.Redacted{"opt2": "val2"}, merged1.Options) // overridden
require.Equal(t, []string{"1.1.1.1"}, merged1.Resolvers) // overridden
require.Equal(t, []string{"ca2.crt"}, merged1.CACerts) // overridden
require.Equal(t, "eabKid2", merged1.EABKid) // overridden
require.Equal(t, "eabHmac2", merged1.EABHmac) // overridden
// Inherited field:
require.Equal(t, cfg.HTTPClient, merged1.HTTPClient) // inherited
require.Nil(t, merged1.Extra)
}

View File

@@ -9,6 +9,6 @@ import (
type Provider interface {
Setup() error
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
ScheduleRenewal(task.Parent)
ObtainCert() error
ScheduleRenewalAll(task.Parent)
ObtainCertAll() error
}

View File

@@ -272,6 +272,7 @@ func (state *state) initAutoCert() error {
autocertCfg := state.AutoCert
if autocertCfg == nil {
autocertCfg = new(autocert.Config)
_ = autocertCfg.Validate()
}
user, legoCfg, err := autocertCfg.GetLegoConfig()
@@ -279,12 +280,19 @@ func (state *state) initAutoCert() error {
return err
}
state.autocertProvider = autocert.NewProvider(autocertCfg, user, legoCfg)
if err := state.autocertProvider.Setup(); err != nil {
return fmt.Errorf("autocert error: %w", err)
} else {
state.autocertProvider.ScheduleRenewal(state.task)
p, err := autocert.NewProvider(autocertCfg, user, legoCfg)
if err != nil {
return err
}
if err := p.ObtainCertIfNotExistsAll(); err != nil {
return err
}
p.ScheduleRenewalAll(state.task)
p.PrintCertExpiriesAll()
state.autocertProvider = p
return nil
}

View File

@@ -1,7 +1,7 @@
package dnsproviders
type (
DummyConfig struct{}
DummyConfig map[string]any
DummyProvider struct{}
)