mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
fix(autocert): ensure extra certificate registration and renewal scheduling
Extra providers were not being properly initialized during NewProvider(), causing certificate registration and renewal scheduling to be skipped. - Add ConfigExtra type with idx field for provider indexing - Add MergeExtraConfig() for inheriting main provider settings - Add setupExtraProviders() for recursive extra provider initialization - Refactor NewProvider to return error and call setupExtraProviders() - Add provider-scoped logger with "main" or "extra[N]" name - Add batch operations: ObtainCertIfNotExistsAll(), ObtainCertAll() - Add ForceExpiryAll() with completion tracking via WaitRenewalDone() - Add RenewMode (force/ifNeeded) for controlling renewal behavior - Add PrintCertExpiriesAll() for logging all provider certificate expiries Summary of staged changes: - config.go: Added ConfigExtra type, MergeExtraConfig(), recursive validation with path uniqueness checking - provider.go: Added provider indexing, scoped logger, batch cert operations, force renewal with completion tracking, RenewMode control - setup.go: New file with setupExtraProviders() for proper extra provider initialization - setup_test.go: New tests for extra provider setup - multi_cert_test.go: New tests for multi-certificate functionality - renew.go: Updated to use new provider API with error handling - state.go: Updated to handle NewProvider error return
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/yusing/godoxy/internal/autocert"
|
"github.com/yusing/godoxy/internal/autocert"
|
||||||
"github.com/yusing/godoxy/internal/logging/memlogger"
|
"github.com/yusing/godoxy/internal/logging/memlogger"
|
||||||
apitypes "github.com/yusing/goutils/apitypes"
|
apitypes "github.com/yusing/goutils/apitypes"
|
||||||
gperr "github.com/yusing/goutils/errs"
|
|
||||||
"github.com/yusing/goutils/http/websocket"
|
"github.com/yusing/goutils/http/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,33 +39,33 @@ func Renew(c *gin.Context) {
|
|||||||
logs, cancel := memlogger.Events()
|
logs, cancel := memlogger.Events()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(done)
|
// Stream logs until WebSocket connection closes (renewal runs in background)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-manager.Context().Done():
|
||||||
|
return
|
||||||
|
case l := <-logs:
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err = autocert.ObtainCert()
|
err = manager.WriteData(websocket.TextMessage, l, 10*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
gperr.LogError("failed to obtain cert", err)
|
return
|
||||||
_ = manager.WriteData(websocket.TextMessage, []byte(err.Error()), 10*time.Second)
|
}
|
||||||
} else {
|
}
|
||||||
log.Info().Msg("cert obtained successfully")
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
// renewal happens in background
|
||||||
select {
|
ok := autocert.ForceExpiryAll()
|
||||||
case l := <-logs:
|
if !ok {
|
||||||
if err != nil {
|
log.Error().Msg("cert renewal already in progress")
|
||||||
return
|
time.Sleep(1 * time.Second) // wait for the log above to be sent
|
||||||
}
|
return
|
||||||
|
|
||||||
err = manager.WriteData(websocket.TextMessage, l, 10*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
log.Info().Msg("cert force renewal requested")
|
||||||
|
|
||||||
|
autocert.WaitRenewalDone(manager.Context())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -19,13 +20,14 @@ import (
|
|||||||
strutils "github.com/yusing/goutils/strings"
|
strutils "github.com/yusing/goutils/strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ConfigExtra Config
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
Domains []string `json:"domains,omitempty"`
|
Domains []string `json:"domains,omitempty"`
|
||||||
CertPath string `json:"cert_path,omitempty"`
|
CertPath string `json:"cert_path,omitempty"`
|
||||||
KeyPath string `json:"key_path,omitempty"`
|
KeyPath string `json:"key_path,omitempty"`
|
||||||
Extra []Config `json:"extra,omitempty"`
|
Extra []ConfigExtra `json:"extra,omitempty"`
|
||||||
ACMEKeyPath string `json:"acme_key_path,omitempty"`
|
ACMEKeyPath string `json:"acme_key_path,omitempty"` // shared by all extra providers
|
||||||
Provider string `json:"provider,omitempty"`
|
Provider string `json:"provider,omitempty"`
|
||||||
Options map[string]strutils.Redacted `json:"options,omitempty"`
|
Options map[string]strutils.Redacted `json:"options,omitempty"`
|
||||||
|
|
||||||
@@ -42,15 +44,12 @@ type Config struct {
|
|||||||
HTTPClient *http.Client `json:"-"` // for tests only
|
HTTPClient *http.Client `json:"-"` // for tests only
|
||||||
|
|
||||||
challengeProvider challenge.Provider
|
challengeProvider challenge.Provider
|
||||||
|
|
||||||
|
idx int // 0: main, 1+: extra[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingDomain = gperr.New("missing field 'domains'")
|
ErrMissingField = gperr.New("missing field")
|
||||||
ErrMissingEmail = gperr.New("missing field 'email'")
|
|
||||||
ErrMissingProvider = gperr.New("missing field 'provider'")
|
|
||||||
ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
|
|
||||||
ErrMissingCertPath = gperr.New("missing field 'cert_path'")
|
|
||||||
ErrMissingKeyPath = gperr.New("missing field 'key_path'")
|
|
||||||
ErrDuplicatedPath = gperr.New("duplicated path")
|
ErrDuplicatedPath = gperr.New("duplicated path")
|
||||||
ErrInvalidDomain = gperr.New("invalid domain")
|
ErrInvalidDomain = gperr.New("invalid domain")
|
||||||
ErrUnknownProvider = gperr.New("unknown provider")
|
ErrUnknownProvider = gperr.New("unknown provider")
|
||||||
@@ -66,95 +65,22 @@ var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
|||||||
|
|
||||||
// Validate implements the utils.CustomValidator interface.
|
// Validate implements the utils.CustomValidator interface.
|
||||||
func (cfg *Config) Validate() gperr.Error {
|
func (cfg *Config) Validate() gperr.Error {
|
||||||
if cfg == nil {
|
seenPaths := make(map[string]int) // path -> provider idx (0 for main, 1+ for extras)
|
||||||
return nil
|
return cfg.validate(seenPaths)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cfg *ConfigExtra) Validate() gperr.Error {
|
||||||
|
return nil // done by main config's validate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *ConfigExtra) AsConfig() *Config {
|
||||||
|
return (*Config)(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) validate(seenPaths map[string]int) gperr.Error {
|
||||||
if cfg.Provider == "" {
|
if cfg.Provider == "" {
|
||||||
cfg.Provider = ProviderLocal
|
cfg.Provider = ProviderLocal
|
||||||
}
|
}
|
||||||
|
|
||||||
b := gperr.NewBuilder("autocert errors")
|
|
||||||
if len(cfg.Extra) > 0 {
|
|
||||||
seenCertPaths := make(map[string]int, len(cfg.Extra))
|
|
||||||
seenKeyPaths := make(map[string]int, len(cfg.Extra))
|
|
||||||
for i := range cfg.Extra {
|
|
||||||
if cfg.Extra[i].CertPath == "" {
|
|
||||||
b.Add(ErrMissingCertPath.Subjectf("extra[%d].cert_path", i))
|
|
||||||
}
|
|
||||||
if cfg.Extra[i].KeyPath == "" {
|
|
||||||
b.Add(ErrMissingKeyPath.Subjectf("extra[%d].key_path", i))
|
|
||||||
}
|
|
||||||
if cfg.Extra[i].CertPath != "" {
|
|
||||||
if first, ok := seenCertPaths[cfg.Extra[i].CertPath]; ok {
|
|
||||||
b.Add(ErrDuplicatedPath.Subjectf("extra[%d].cert_path", i).Withf("first: %d", first))
|
|
||||||
} else {
|
|
||||||
seenCertPaths[cfg.Extra[i].CertPath] = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if cfg.Extra[i].KeyPath != "" {
|
|
||||||
if first, ok := seenKeyPaths[cfg.Extra[i].KeyPath]; ok {
|
|
||||||
b.Add(ErrDuplicatedPath.Subjectf("extra[%d].key_path", i).Withf("first: %d", first))
|
|
||||||
} else {
|
|
||||||
seenKeyPaths[cfg.Extra[i].KeyPath] = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
|
|
||||||
b.Add(ErrMissingCADirURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
|
||||||
if len(cfg.Domains) == 0 {
|
|
||||||
b.Add(ErrMissingDomain)
|
|
||||||
}
|
|
||||||
if cfg.Email == "" {
|
|
||||||
b.Add(ErrMissingEmail)
|
|
||||||
}
|
|
||||||
if cfg.Provider != ProviderCustom {
|
|
||||||
for i, d := range cfg.Domains {
|
|
||||||
if !domainOrWildcardRE.MatchString(d) {
|
|
||||||
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// check if provider is implemented
|
|
||||||
providerConstructor, ok := Providers[cfg.Provider]
|
|
||||||
if !ok {
|
|
||||||
if cfg.Provider != ProviderCustom {
|
|
||||||
b.Add(ErrUnknownProvider.
|
|
||||||
Subject(cfg.Provider).
|
|
||||||
With(gperr.DoYouMeanField(cfg.Provider, Providers)))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
provider, err := providerConstructor(cfg.Options)
|
|
||||||
if err != nil {
|
|
||||||
b.Add(err)
|
|
||||||
} else {
|
|
||||||
cfg.challengeProvider = provider
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.challengeProvider == nil {
|
|
||||||
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
|
|
||||||
}
|
|
||||||
return b.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *Config) dns01Options() []dns01.ChallengeOption {
|
|
||||||
return []dns01.ChallengeOption{
|
|
||||||
dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.CertPath == "" {
|
if cfg.CertPath == "" {
|
||||||
cfg.CertPath = CertFileDefault
|
cfg.CertPath = CertFileDefault
|
||||||
}
|
}
|
||||||
@@ -165,6 +91,83 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
|||||||
cfg.ACMEKeyPath = ACMEKeyFileDefault
|
cfg.ACMEKeyPath = ACMEKeyFileDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b := gperr.NewBuilder("certificate error")
|
||||||
|
|
||||||
|
// check if cert_path is unique
|
||||||
|
if first, ok := seenPaths[cfg.CertPath]; ok {
|
||||||
|
b.Add(ErrDuplicatedPath.Subjectf("cert_path %s", cfg.CertPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first)))
|
||||||
|
} else {
|
||||||
|
seenPaths[cfg.CertPath] = cfg.idx
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if key_path is unique
|
||||||
|
if first, ok := seenPaths[cfg.KeyPath]; ok {
|
||||||
|
b.Add(ErrDuplicatedPath.Subjectf("key_path %s", cfg.KeyPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first)))
|
||||||
|
} else {
|
||||||
|
seenPaths[cfg.KeyPath] = cfg.idx
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
|
||||||
|
b.Add(ErrMissingField.Subject("ca_dir_url"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||||
|
if len(cfg.Domains) == 0 {
|
||||||
|
b.Add(ErrMissingField.Subject("domains"))
|
||||||
|
}
|
||||||
|
if cfg.Email == "" {
|
||||||
|
b.Add(ErrMissingField.Subject("email"))
|
||||||
|
}
|
||||||
|
if cfg.Provider != ProviderCustom {
|
||||||
|
for i, d := range cfg.Domains {
|
||||||
|
if !domainOrWildcardRE.MatchString(d) {
|
||||||
|
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if provider is implemented
|
||||||
|
providerConstructor, ok := Providers[cfg.Provider]
|
||||||
|
if !ok {
|
||||||
|
if cfg.Provider != ProviderCustom {
|
||||||
|
b.Add(ErrUnknownProvider.
|
||||||
|
Subject(cfg.Provider).
|
||||||
|
With(gperr.DoYouMeanField(cfg.Provider, Providers)))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
provider, err := providerConstructor(cfg.Options)
|
||||||
|
if err != nil {
|
||||||
|
b.Add(err)
|
||||||
|
} else {
|
||||||
|
cfg.challengeProvider = provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.challengeProvider == nil {
|
||||||
|
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Extra) > 0 {
|
||||||
|
for i := range cfg.Extra {
|
||||||
|
cfg.Extra[i] = MergeExtraConfig(cfg, &cfg.Extra[i])
|
||||||
|
cfg.Extra[i].AsConfig().idx = i + 1
|
||||||
|
err := cfg.Extra[i].AsConfig().validate(seenPaths)
|
||||||
|
if err != nil {
|
||||||
|
b.Add(err.Subjectf("extra[%d]", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) dns01Options() []dns01.ChallengeOption {
|
||||||
|
return []dns01.ChallengeOption{
|
||||||
|
dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cfg *Config) GetLegoConfig() (*User, *lego.Config, error) {
|
||||||
var privKey *ecdsa.PrivateKey
|
var privKey *ecdsa.PrivateKey
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -208,6 +211,46 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
|
|||||||
return user, legoCfg, nil
|
return user, legoCfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MergeExtraConfig(mainCfg *Config, extraCfg *ConfigExtra) ConfigExtra {
|
||||||
|
merged := ConfigExtra(*mainCfg)
|
||||||
|
merged.Extra = nil
|
||||||
|
merged.CertPath = extraCfg.CertPath
|
||||||
|
merged.KeyPath = extraCfg.KeyPath
|
||||||
|
// NOTE: Using same ACME key as main provider
|
||||||
|
|
||||||
|
if extraCfg.Provider != "" {
|
||||||
|
merged.Provider = extraCfg.Provider
|
||||||
|
}
|
||||||
|
if extraCfg.Email != "" {
|
||||||
|
merged.Email = extraCfg.Email
|
||||||
|
}
|
||||||
|
if len(extraCfg.Domains) > 0 {
|
||||||
|
merged.Domains = extraCfg.Domains
|
||||||
|
}
|
||||||
|
if len(extraCfg.Options) > 0 {
|
||||||
|
merged.Options = extraCfg.Options
|
||||||
|
}
|
||||||
|
if len(extraCfg.Resolvers) > 0 {
|
||||||
|
merged.Resolvers = extraCfg.Resolvers
|
||||||
|
}
|
||||||
|
if extraCfg.CADirURL != "" {
|
||||||
|
merged.CADirURL = extraCfg.CADirURL
|
||||||
|
}
|
||||||
|
if len(extraCfg.CACerts) > 0 {
|
||||||
|
merged.CACerts = extraCfg.CACerts
|
||||||
|
}
|
||||||
|
if extraCfg.EABKid != "" {
|
||||||
|
merged.EABKid = extraCfg.EABKid
|
||||||
|
}
|
||||||
|
if extraCfg.EABHmac != "" {
|
||||||
|
merged.EABHmac = extraCfg.EABHmac
|
||||||
|
}
|
||||||
|
if extraCfg.HTTPClient != nil {
|
||||||
|
merged.HTTPClient = extraCfg.HTTPClient
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
|
||||||
if common.IsTest {
|
if common.IsTest {
|
||||||
return nil, os.ErrNotExist
|
return nil, os.ErrNotExist
|
||||||
|
|||||||
@@ -1,27 +1,32 @@
|
|||||||
package autocert
|
package autocert_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/godoxy/internal/autocert"
|
||||||
|
"github.com/yusing/godoxy/internal/dnsproviders"
|
||||||
"github.com/yusing/godoxy/internal/serialization"
|
"github.com/yusing/godoxy/internal/serialization"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEABConfigRequired(t *testing.T) {
|
func TestEABConfigRequired(t *testing.T) {
|
||||||
|
dnsproviders.InitProviders()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cfg *Config
|
cfg *autocert.Config
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{name: "Missing EABKid", cfg: &Config{EABHmac: "1234567890"}, wantErr: true},
|
{name: "Missing EABKid", cfg: &autocert.Config{EABHmac: "1234567890"}, wantErr: true},
|
||||||
{name: "Missing EABHmac", cfg: &Config{EABKid: "1234567890"}, wantErr: true},
|
{name: "Missing EABHmac", cfg: &autocert.Config{EABKid: "1234567890"}, wantErr: true},
|
||||||
{name: "Valid EAB", cfg: &Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false},
|
{name: "Valid EAB", cfg: &autocert.Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
yaml := fmt.Appendf(nil, "eab_kid: %s\neab_hmac: %s", test.cfg.EABKid, test.cfg.EABHmac)
|
yaml := fmt.Appendf(nil, "eab_kid: %s\neab_hmac: %s", test.cfg.EABKid, test.cfg.EABHmac)
|
||||||
cfg := Config{}
|
cfg := autocert.Config{}
|
||||||
err := serialization.UnmarshalValidateYAML(yaml, &cfg)
|
err := serialization.UnmarshalValidateYAML(yaml, &cfg)
|
||||||
if (err != nil) != test.wantErr {
|
if (err != nil) != test.wantErr {
|
||||||
t.Errorf("Validate() error = %v, wantErr %v", err, test.wantErr)
|
t.Errorf("Validate() error = %v, wantErr %v", err, test.wantErr)
|
||||||
@@ -29,3 +34,27 @@ func TestEABConfigRequired(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtraCertKeyPathsUnique(t *testing.T) {
|
||||||
|
t.Run("duplicate cert_path rejected", func(t *testing.T) {
|
||||||
|
cfg := &autocert.Config{
|
||||||
|
Provider: autocert.ProviderLocal,
|
||||||
|
Extra: []autocert.ConfigExtra{
|
||||||
|
{CertPath: "a.crt", KeyPath: "a.key"},
|
||||||
|
{CertPath: "a.crt", KeyPath: "b.key"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Error(t, cfg.Validate())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("duplicate key_path rejected", func(t *testing.T) {
|
||||||
|
cfg := &autocert.Config{
|
||||||
|
Provider: autocert.ProviderLocal,
|
||||||
|
Extra: []autocert.ConfigExtra{
|
||||||
|
{CertPath: "a.crt", KeyPath: "a.key"},
|
||||||
|
{CertPath: "b.crt", KeyPath: "a.key"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.Error(t, cfg.Validate())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,5 +5,4 @@ const (
|
|||||||
CertFileDefault = certBasePath + "cert.crt"
|
CertFileDefault = certBasePath + "cert.crt"
|
||||||
KeyFileDefault = certBasePath + "priv.key"
|
KeyFileDefault = certBasePath + "priv.key"
|
||||||
ACMEKeyFileDefault = certBasePath + "acme.key"
|
ACMEKeyFileDefault = certBasePath + "acme.key"
|
||||||
LastFailureFile = certBasePath + ".last_failure"
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
package autocert
|
package autocert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
"maps"
|
"maps"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,6 +31,8 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
Provider struct {
|
Provider struct {
|
||||||
|
logger zerolog.Logger
|
||||||
|
|
||||||
cfg *Config
|
cfg *Config
|
||||||
user *User
|
user *User
|
||||||
legoCfg *lego.Config
|
legoCfg *lego.Config
|
||||||
@@ -42,12 +47,18 @@ type (
|
|||||||
|
|
||||||
extraProviders []*Provider
|
extraProviders []*Provider
|
||||||
sniMatcher sniMatcher
|
sniMatcher sniMatcher
|
||||||
|
|
||||||
|
forceRenewalCh chan struct{}
|
||||||
|
forceRenewalDoneCh atomic.Value // chan struct{}
|
||||||
|
|
||||||
|
scheduleRenewalOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
CertExpiries map[string]time.Time
|
CertExpiries map[string]time.Time
|
||||||
|
RenewMode uint8
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrGetCertFailure = errors.New("get certificate failed")
|
var ErrNoCertificate = errors.New("no certificate found")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// renew failed for whatever reason, 1 hour cooldown
|
// renew failed for whatever reason, 1 hour cooldown
|
||||||
@@ -56,21 +67,36 @@ const (
|
|||||||
requestCooldownDuration = 15 * time.Second
|
requestCooldownDuration = 15 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
renewModeForce = iota
|
||||||
|
renewModeIfNeeded
|
||||||
|
)
|
||||||
|
|
||||||
// could be nil
|
// could be nil
|
||||||
var ActiveProvider atomic.Pointer[Provider]
|
var ActiveProvider atomic.Pointer[Provider]
|
||||||
|
|
||||||
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider {
|
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) (*Provider, error) {
|
||||||
return &Provider{
|
p := &Provider{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
user: user,
|
user: user,
|
||||||
legoCfg: legoCfg,
|
legoCfg: legoCfg,
|
||||||
lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath),
|
lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath),
|
||||||
|
forceRenewalCh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
if cfg.idx == 0 {
|
||||||
|
p.logger = log.With().Str("provider", "main").Logger()
|
||||||
|
} else {
|
||||||
|
p.logger = log.With().Str("provider", fmt.Sprintf("extra[%d]", cfg.idx)).Logger()
|
||||||
|
}
|
||||||
|
if err := p.setupExtraProviders(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
if p.tlsCert == nil {
|
if p.tlsCert == nil {
|
||||||
return nil, ErrGetCertFailure
|
return nil, ErrNoCertificate
|
||||||
}
|
}
|
||||||
if hello == nil || hello.ServerName == "" {
|
if hello == nil || hello.ServerName == "" {
|
||||||
return p.tlsCert, nil
|
return p.tlsCert, nil
|
||||||
@@ -82,7 +108,14 @@ func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) GetName() string {
|
func (p *Provider) GetName() string {
|
||||||
return p.cfg.Provider
|
if p.cfg.idx == 0 {
|
||||||
|
return "main"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("extra[%d]", p.cfg.idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) fmtError(err error) error {
|
||||||
|
return gperr.PrependSubject(fmt.Sprintf("provider: %s", p.GetName()), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) GetCertPath() string {
|
func (p *Provider) GetCertPath() string {
|
||||||
@@ -129,45 +162,88 @@ func (p *Provider) ClearLastFailure() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p.lastFailure = time.Time{}
|
p.lastFailure = time.Time{}
|
||||||
return os.Remove(p.lastFailureFile)
|
err := os.Remove(p.lastFailureFile)
|
||||||
|
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) ObtainCert() error {
|
// allProviders returns all providers including this provider and all extra providers.
|
||||||
if len(p.extraProviders) > 0 {
|
func (p *Provider) allProviders() []*Provider {
|
||||||
errs := gperr.NewGroup("autocert errors")
|
return append([]*Provider{p}, p.extraProviders...)
|
||||||
errs.Go(p.obtainCertSelf)
|
}
|
||||||
for _, ep := range p.extraProviders {
|
|
||||||
errs.Go(ep.obtainCertSelf)
|
// ObtainCertIfNotExistsAll obtains a new certificate for this provider and all extra providers if they do not exist.
|
||||||
}
|
func (p *Provider) ObtainCertIfNotExistsAll() error {
|
||||||
if err := errs.Wait().Error(); err != nil {
|
errs := gperr.NewGroup("obtain cert error")
|
||||||
return err
|
|
||||||
}
|
for _, provider := range p.allProviders() {
|
||||||
p.rebuildSNIMatcher()
|
errs.Go(func() error {
|
||||||
|
if err := provider.obtainCertIfNotExists(); err != nil {
|
||||||
|
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
p.rebuildSNIMatcher()
|
||||||
|
return errs.Wait().Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist.
|
||||||
|
func (p *Provider) obtainCertIfNotExists() error {
|
||||||
|
err := p.LoadCert()
|
||||||
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return p.obtainCertSelf()
|
|
||||||
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// check last failure
|
||||||
|
lastFailure, err := p.GetLastFailure()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get last failure: %w", err)
|
||||||
|
}
|
||||||
|
if !lastFailure.IsZero() && time.Since(lastFailure) < requestCooldownDuration {
|
||||||
|
return fmt.Errorf("still in cooldown until %s", strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info().Msg("cert not found, obtaining new cert")
|
||||||
|
return p.ObtainCert()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) obtainCertSelf() error {
|
// ObtainCertAll renews existing certificates or obtains new certificates for this provider and all extra providers.
|
||||||
|
func (p *Provider) ObtainCertAll() error {
|
||||||
|
errs := gperr.NewGroup("obtain cert error")
|
||||||
|
for _, provider := range p.allProviders() {
|
||||||
|
errs.Go(func() error {
|
||||||
|
if err := provider.obtainCertIfNotExists(); err != nil {
|
||||||
|
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return errs.Wait().Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ObtainCert renews existing certificate or obtains a new certificate for this provider.
|
||||||
|
func (p *Provider) ObtainCert() error {
|
||||||
if p.cfg.Provider == ProviderLocal {
|
if p.cfg.Provider == ProviderLocal {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.cfg.Provider == ProviderPseudo {
|
if p.cfg.Provider == ProviderPseudo {
|
||||||
log.Info().Msg("init client for pseudo provider")
|
p.logger.Info().Msg("init client for pseudo provider")
|
||||||
<-time.After(time.Second)
|
<-time.After(time.Second)
|
||||||
log.Info().Msg("registering acme for pseudo provider")
|
p.logger.Info().Msg("registering acme for pseudo provider")
|
||||||
<-time.After(time.Second)
|
<-time.After(time.Second)
|
||||||
log.Info().Msg("obtained cert for pseudo provider")
|
p.logger.Info().Msg("obtained cert for pseudo provider")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastFailure, err := p.GetLastFailure(); err != nil {
|
|
||||||
return err
|
|
||||||
} else if time.Since(lastFailure) < requestCooldownDuration {
|
|
||||||
return fmt.Errorf("%w: still in cooldown until %s", ErrGetCertFailure, strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.client == nil {
|
if p.client == nil {
|
||||||
if err := p.initClient(); err != nil {
|
if err := p.initClient(); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -227,6 +303,7 @@ func (p *Provider) obtainCertSelf() error {
|
|||||||
}
|
}
|
||||||
p.tlsCert = &tlsCert
|
p.tlsCert = &tlsCert
|
||||||
p.certExpiries = expiries
|
p.certExpiries = expiries
|
||||||
|
p.rebuildSNIMatcher()
|
||||||
|
|
||||||
if err := p.ClearLastFailure(); err != nil {
|
if err := p.ClearLastFailure(); err != nil {
|
||||||
return fmt.Errorf("failed to clear last failure: %w", err)
|
return fmt.Errorf("failed to clear last failure: %w", err)
|
||||||
@@ -235,19 +312,37 @@ func (p *Provider) obtainCertSelf() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) LoadCert() error {
|
func (p *Provider) LoadCert() error {
|
||||||
|
var errs gperr.Builder
|
||||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("load SSL certificate: %w", err)
|
errs.Addf("load SSL certificate: %w", p.fmtError(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
expiries, err := getCertExpiries(&cert)
|
expiries, err := getCertExpiries(&cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse SSL certificate: %w", err)
|
errs.Addf("parse SSL certificate: %w", p.fmtError(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
p.tlsCert = &cert
|
p.tlsCert = &cert
|
||||||
p.certExpiries = expiries
|
p.certExpiries = expiries
|
||||||
|
|
||||||
log.Info().Msgf("next cert renewal in %s", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
|
for _, ep := range p.extraProviders {
|
||||||
return p.renewIfNeeded()
|
if err := ep.LoadCert(); err != nil {
|
||||||
|
errs.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.rebuildSNIMatcher()
|
||||||
|
return errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers.
|
||||||
|
func (p *Provider) PrintCertExpiriesAll() {
|
||||||
|
for _, provider := range p.allProviders() {
|
||||||
|
for domain, expiry := range provider.certExpiries {
|
||||||
|
p.logger.Info().Str("domain", domain).Msgf("certificate expire on %s", strutils.FormatTime(expiry))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShouldRenewOn returns the time at which the certificate should be renewed.
|
// ShouldRenewOn returns the time at which the certificate should be renewed.
|
||||||
@@ -255,65 +350,129 @@ func (p *Provider) ShouldRenewOn() time.Time {
|
|||||||
for _, expiry := range p.certExpiries {
|
for _, expiry := range p.certExpiries {
|
||||||
return expiry.AddDate(0, -1, 0) // 1 month before
|
return expiry.AddDate(0, -1, 0) // 1 month before
|
||||||
}
|
}
|
||||||
// this line should never be reached
|
// this line should never be reached in production, but will be useful for testing
|
||||||
panic("no certificate available")
|
return time.Now().AddDate(0, 1, 0) // 1 month after
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
// ForceExpiryAll triggers immediate certificate renewal for this provider and all extra providers.
|
||||||
|
// Returns true if the renewal was triggered, false if the renewal was dropped.
|
||||||
|
//
|
||||||
|
// If at least one renewal is triggered, returns true.
|
||||||
|
func (p *Provider) ForceExpiryAll() (ok bool) {
|
||||||
|
doneCh := make(chan struct{})
|
||||||
|
if swapped := p.forceRenewalDoneCh.CompareAndSwap(nil, doneCh); !swapped { // already in progress
|
||||||
|
close(doneCh)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.forceRenewalCh <- struct{}{}:
|
||||||
|
ok = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ep := range p.extraProviders {
|
||||||
|
if ep.ForceExpiryAll() {
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitRenewalDone waits for the renewal to complete.
|
||||||
|
// Returns false if the renewal was dropped.
|
||||||
|
func (p *Provider) WaitRenewalDone(ctx context.Context) bool {
|
||||||
|
done, ok := p.forceRenewalDoneCh.Load().(chan struct{})
|
||||||
|
if !ok || done == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ep := range p.extraProviders {
|
||||||
|
if !ep.WaitRenewalDone(ctx) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduleRenewalAll schedules the renewal of the certificate for this provider and all extra providers.
|
||||||
|
func (p *Provider) ScheduleRenewalAll(parent task.Parent) {
|
||||||
|
p.scheduleRenewalOnce.Do(func() {
|
||||||
|
p.scheduleRenewal(parent)
|
||||||
|
})
|
||||||
|
for _, ep := range p.extraProviders {
|
||||||
|
ep.scheduleRenewalOnce.Do(func() {
|
||||||
|
ep.scheduleRenewal(parent)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var emptyForceRenewalDoneCh any = chan struct{}(nil)
|
||||||
|
|
||||||
|
// scheduleRenewal schedules the renewal of the certificate for this provider.
|
||||||
|
func (p *Provider) scheduleRenewal(parent task.Parent) {
|
||||||
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
|
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
renewalTime := p.ShouldRenewOn()
|
|
||||||
timer := time.NewTimer(time.Until(renewalTime))
|
|
||||||
defer timer.Stop()
|
|
||||||
|
|
||||||
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
|
timer := time.NewTimer(time.Until(p.ShouldRenewOn()))
|
||||||
|
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
|
||||||
|
|
||||||
|
renew := func(renewMode RenewMode) {
|
||||||
|
defer func() {
|
||||||
|
if done, ok := p.forceRenewalDoneCh.Swap(emptyForceRenewalDoneCh).(chan struct{}); ok && done != nil {
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
renewed, err := p.renew(renewMode)
|
||||||
|
if err != nil {
|
||||||
|
gperr.LogWarn("autocert: cert renew failed", p.fmtError(err))
|
||||||
|
notif.Notify(¬if.LogMessage{
|
||||||
|
Level: zerolog.ErrorLevel,
|
||||||
|
Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()),
|
||||||
|
Body: notif.MessageBody(err.Error()),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if renewed {
|
||||||
|
p.rebuildSNIMatcher()
|
||||||
|
|
||||||
|
notif.Notify(¬if.LogMessage{
|
||||||
|
Level: zerolog.InfoLevel,
|
||||||
|
Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()),
|
||||||
|
Body: notif.ListBody(p.cfg.Domains),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reset on success
|
||||||
|
if err := p.ClearLastFailure(); err != nil {
|
||||||
|
gperr.LogWarn("autocert: failed to clear last failure", p.fmtError(err))
|
||||||
|
}
|
||||||
|
timer.Reset(time.Until(p.ShouldRenewOn()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer timer.Stop()
|
||||||
defer task.Finish(nil)
|
defer task.Finish(nil)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-task.Context().Done():
|
case <-task.Context().Done():
|
||||||
return
|
return
|
||||||
|
case <-p.forceRenewalCh:
|
||||||
|
renew(renewModeForce)
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
// Retry after 1 hour on failure
|
renew(renewModeIfNeeded)
|
||||||
lastFailure, err := p.GetLastFailure()
|
|
||||||
if err != nil {
|
|
||||||
gperr.LogWarn("autocert: failed to get last failure", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := p.renewIfNeeded(); err != nil {
|
|
||||||
gperr.LogWarn("autocert: cert renew failed", err)
|
|
||||||
if err := p.UpdateLastFailure(); err != nil {
|
|
||||||
gperr.LogWarn("autocert: failed to update last failure", err)
|
|
||||||
}
|
|
||||||
notif.Notify(¬if.LogMessage{
|
|
||||||
Level: zerolog.ErrorLevel,
|
|
||||||
Title: "SSL certificate renewal failed",
|
|
||||||
Body: notif.MessageBody(err.Error()),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
notif.Notify(¬if.LogMessage{
|
|
||||||
Level: zerolog.InfoLevel,
|
|
||||||
Title: "SSL certificate renewed",
|
|
||||||
Body: notif.ListBody(p.cfg.Domains),
|
|
||||||
})
|
|
||||||
// Reset on success
|
|
||||||
if err := p.ClearLastFailure(); err != nil {
|
|
||||||
gperr.LogWarn("autocert: failed to clear last failure", err)
|
|
||||||
}
|
|
||||||
renewalTime = p.ShouldRenewOn()
|
|
||||||
timer.Reset(time.Until(renewalTime))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for _, ep := range p.extraProviders {
|
|
||||||
ep.ScheduleRenewal(parent)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) initClient() error {
|
func (p *Provider) initClient() error {
|
||||||
@@ -409,21 +568,42 @@ func (p *Provider) certState() CertState {
|
|||||||
return CertStateValid
|
return CertStateValid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) renewIfNeeded() error {
|
func (p *Provider) renew(mode RenewMode) (renewed bool, err error) {
|
||||||
if p.cfg.Provider == ProviderLocal {
|
if p.cfg.Provider == ProviderLocal {
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch p.certState() {
|
if mode != renewModeForce {
|
||||||
case CertStateExpired:
|
// Retry after 1 hour on failure
|
||||||
log.Info().Msg("certs expired, renewing")
|
lastFailure, err := p.GetLastFailure()
|
||||||
case CertStateMismatch:
|
if err != nil {
|
||||||
log.Info().Msg("cert domains mismatch with config, renewing")
|
return false, fmt.Errorf("failed to get last failure: %w", err)
|
||||||
default:
|
}
|
||||||
return nil
|
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
|
||||||
|
until := lastFailure.Add(renewalCooldownDuration).Local()
|
||||||
|
return false, fmt.Errorf("still in cooldown until %s", strutils.FormatTime(until))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.obtainCertSelf()
|
if mode == renewModeIfNeeded {
|
||||||
|
switch p.certState() {
|
||||||
|
case CertStateExpired:
|
||||||
|
log.Info().Msg("certs expired, renewing")
|
||||||
|
case CertStateMismatch:
|
||||||
|
log.Info().Msg("cert domains mismatch with config, renewing")
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mode == renewModeForce {
|
||||||
|
log.Info().Msg("force renewing cert by user request")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.ObtainCert(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
||||||
@@ -445,15 +625,16 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func lastFailureFileFor(certPath, keyPath string) string {
|
func lastFailureFileFor(certPath, keyPath string) string {
|
||||||
if certPath == "" && keyPath == "" {
|
|
||||||
return LastFailureFile
|
|
||||||
}
|
|
||||||
dir := filepath.Dir(certPath)
|
dir := filepath.Dir(certPath)
|
||||||
sum := sha256.Sum256([]byte(certPath + "|" + keyPath))
|
sum := sha256.Sum256([]byte(certPath + "|" + keyPath))
|
||||||
return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6]))
|
return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) rebuildSNIMatcher() {
|
func (p *Provider) rebuildSNIMatcher() {
|
||||||
|
if p.cfg.idx != 0 { // only main provider has extra providers
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.sniMatcher = sniMatcher{}
|
p.sniMatcher = sniMatcher{}
|
||||||
p.sniMatcher.addProvider(p)
|
p.sniMatcher.addProvider(p)
|
||||||
for _, ep := range p.extraProviders {
|
for _, ep := range p.extraProviders {
|
||||||
|
|||||||
@@ -10,12 +10,15 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -24,6 +27,368 @@ import (
|
|||||||
"github.com/yusing/godoxy/internal/dnsproviders"
|
"github.com/yusing/godoxy/internal/dnsproviders"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestACMEServer implements a minimal ACME server for testing with request tracking.
|
||||||
|
type TestACMEServer struct {
|
||||||
|
server *httptest.Server
|
||||||
|
caCert *x509.Certificate
|
||||||
|
caKey *rsa.PrivateKey
|
||||||
|
clientCSRs map[string]*x509.CertificateRequest
|
||||||
|
orderDomains map[string][]string
|
||||||
|
authzDomains map[string]string
|
||||||
|
orderSeq int
|
||||||
|
certRequestCount map[string]int
|
||||||
|
renewalRequestCount map[string]int
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestACMEServer(t *testing.T) *TestACMEServer {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate CA certificate and key
|
||||||
|
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
caTemplate := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Test CA"},
|
||||||
|
Country: []string{"US"},
|
||||||
|
Province: []string{""},
|
||||||
|
Locality: []string{"Test"},
|
||||||
|
StreetAddress: []string{""},
|
||||||
|
PostalCode: []string{""},
|
||||||
|
},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
caCert, err := x509.ParseCertificate(caCertDER)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
acme := &TestACMEServer{
|
||||||
|
caCert: caCert,
|
||||||
|
caKey: caKey,
|
||||||
|
clientCSRs: make(map[string]*x509.CertificateRequest),
|
||||||
|
orderDomains: make(map[string][]string),
|
||||||
|
authzDomains: make(map[string]string),
|
||||||
|
orderSeq: 0,
|
||||||
|
certRequestCount: make(map[string]int),
|
||||||
|
renewalRequestCount: make(map[string]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
acme.setupRoutes(mux)
|
||||||
|
|
||||||
|
acme.server = httptest.NewUnstartedServer(mux)
|
||||||
|
acme.server.TLS = &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{
|
||||||
|
{
|
||||||
|
Certificate: [][]byte{caCert.Raw},
|
||||||
|
PrivateKey: caKey,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
acme.server.StartTLS()
|
||||||
|
return acme
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) Close() {
|
||||||
|
s.server.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) URL() string {
|
||||||
|
return s.server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) httpClient() *http.Client {
|
||||||
|
certPool := x509.NewCertPool()
|
||||||
|
certPool.AddCert(s.caCert)
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: 30 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 30 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: certPool,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) setupRoutes(mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
|
||||||
|
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
|
||||||
|
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
|
||||||
|
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
|
||||||
|
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
|
||||||
|
mux.HandleFunc("/acme/chall/", s.handleChallenge)
|
||||||
|
mux.HandleFunc("/acme/order/", s.handleOrder)
|
||||||
|
mux.HandleFunc("/acme/cert/", s.handleCertificate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
|
||||||
|
directory := map[string]any{
|
||||||
|
"newNonce": s.server.URL + "/acme/new-nonce",
|
||||||
|
"newAccount": s.server.URL + "/acme/new-account",
|
||||||
|
"newOrder": s.server.URL + "/acme/new-order",
|
||||||
|
"keyChange": s.server.URL + "/acme/key-change",
|
||||||
|
"meta": map[string]any{
|
||||||
|
"termsOfService": s.server.URL + "/terms",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(directory)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-12345")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account := map[string]any{
|
||||||
|
"status": "valid",
|
||||||
|
"contact": []string{"mailto:test@example.com"},
|
||||||
|
"orders": s.server.URL + "/acme/orders",
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Location", s.server.URL+"/acme/account/1")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-67890")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
var jws struct {
|
||||||
|
Payload string `json:"payload"`
|
||||||
|
}
|
||||||
|
json.Unmarshal(body, &jws)
|
||||||
|
payloadBytes, _ := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
|
var orderReq struct {
|
||||||
|
Identifiers []map[string]string `json:"identifiers"`
|
||||||
|
}
|
||||||
|
json.Unmarshal(payloadBytes, &orderReq)
|
||||||
|
|
||||||
|
domains := []string{}
|
||||||
|
for _, id := range orderReq.Identifiers {
|
||||||
|
domains = append(domains, id["value"])
|
||||||
|
}
|
||||||
|
sort.Strings(domains)
|
||||||
|
domainKey := strings.Join(domains, ",")
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.orderSeq++
|
||||||
|
orderID := fmt.Sprintf("test-order-%d", s.orderSeq)
|
||||||
|
authzID := fmt.Sprintf("test-authz-%d", s.orderSeq)
|
||||||
|
s.orderDomains[orderID] = domains
|
||||||
|
if len(domains) > 0 {
|
||||||
|
s.authzDomains[authzID] = domains[0]
|
||||||
|
}
|
||||||
|
s.certRequestCount[domainKey]++
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
order := map[string]any{
|
||||||
|
"status": "ready",
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifiers": orderReq.Identifiers,
|
||||||
|
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
|
||||||
|
"finalize": s.server.URL + "/acme/order/" + orderID + "/finalize",
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Location", s.server.URL+"/acme/order/"+orderID)
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-order")
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
json.NewEncoder(w).Encode(order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authzID := strings.TrimPrefix(r.URL.Path, "/acme/authz/")
|
||||||
|
domain := s.authzDomains[authzID]
|
||||||
|
if domain == "" {
|
||||||
|
domain = "test.example.com"
|
||||||
|
}
|
||||||
|
authz := map[string]any{
|
||||||
|
"status": "valid",
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifier": map[string]string{"type": "dns", "value": domain},
|
||||||
|
"challenges": []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "dns-01",
|
||||||
|
"status": "valid",
|
||||||
|
"url": s.server.URL + "/acme/chall/test-chall-789",
|
||||||
|
"token": "test-token-abc123",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-authz")
|
||||||
|
json.NewEncoder(w).Encode(authz)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
|
||||||
|
challenge := map[string]any{
|
||||||
|
"type": "dns-01",
|
||||||
|
"status": "valid",
|
||||||
|
"url": r.URL.String(),
|
||||||
|
"token": "test-token-abc123",
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-chall")
|
||||||
|
json.NewEncoder(w).Encode(challenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.HasSuffix(r.URL.Path, "/finalize") {
|
||||||
|
s.handleFinalize(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
orderID := strings.TrimPrefix(r.URL.Path, "/acme/order/")
|
||||||
|
domains := s.orderDomains[orderID]
|
||||||
|
if len(domains) == 0 {
|
||||||
|
domains = []string{"test.example.com"}
|
||||||
|
}
|
||||||
|
certURL := s.server.URL + "/acme/cert/" + orderID
|
||||||
|
order := map[string]any{
|
||||||
|
"status": "valid",
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifiers": func() []map[string]string {
|
||||||
|
out := make([]map[string]string, 0, len(domains))
|
||||||
|
for _, d := range domains {
|
||||||
|
out = append(out, map[string]string{"type": "dns", "value": d})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}(),
|
||||||
|
"certificate": certURL,
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
|
||||||
|
json.NewEncoder(w).Encode(order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Failed to read request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
csr, err := s.extractCSRFromJWS(body)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
orderID := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/acme/order/"), "/finalize")
|
||||||
|
s.mu.Lock()
|
||||||
|
s.clientCSRs[orderID] = csr
|
||||||
|
|
||||||
|
// Detect renewal: if we already have a certificate for these domains, it's a renewal
|
||||||
|
domains := csr.DNSNames
|
||||||
|
sort.Strings(domains)
|
||||||
|
domainKey := strings.Join(domains, ",")
|
||||||
|
|
||||||
|
if s.certRequestCount[domainKey] > 1 {
|
||||||
|
s.renewalRequestCount[domainKey]++
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
certURL := s.server.URL + "/acme/cert/" + orderID
|
||||||
|
order := map[string]any{
|
||||||
|
"status": "valid",
|
||||||
|
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
||||||
|
"identifiers": func() []map[string]string {
|
||||||
|
out := make([]map[string]string, 0, len(domains))
|
||||||
|
for _, d := range domains {
|
||||||
|
out = append(out, map[string]string{"type": "dns", "value": d})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}(),
|
||||||
|
"certificate": certURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
|
||||||
|
json.NewEncoder(w).Encode(order)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
|
||||||
|
var jws struct {
|
||||||
|
Payload string `json:"payload"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jwsData, &jws); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var finalizeReq struct {
|
||||||
|
CSR string `json:"csr"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return x509.ParseCertificateRequest(csrBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TestACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
|
||||||
|
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
|
||||||
|
csr, exists := s.clientCSRs[orderID]
|
||||||
|
if !exists {
|
||||||
|
http.Error(w, "No CSR found for order", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Test Cert"},
|
||||||
|
Country: []string{"US"},
|
||||||
|
},
|
||||||
|
DNSNames: csr.DNSNames,
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(90 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||||
|
w.Header().Set("Replay-Nonce", "test-nonce-cert")
|
||||||
|
w.Write(append(certPEM, caPEM...))
|
||||||
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
dnsproviders.InitProviders()
|
dnsproviders.InitProviders()
|
||||||
m.Run()
|
m.Run()
|
||||||
@@ -41,7 +406,7 @@ func TestCustomProvider(t *testing.T) {
|
|||||||
ACMEKeyPath: "certs/custom-acme.key",
|
ACMEKeyPath: "certs/custom-acme.key",
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cfg.Validate()
|
err := error(cfg.Validate())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, legoCfg, err := cfg.GetLegoConfig()
|
user, legoCfg, err := cfg.GetLegoConfig()
|
||||||
@@ -62,7 +427,8 @@ func TestCustomProvider(t *testing.T) {
|
|||||||
|
|
||||||
err := cfg.Validate()
|
err := cfg.Validate()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
|
require.Contains(t, err.Error(), "missing field")
|
||||||
|
require.Contains(t, err.Error(), "ca_dir_url")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
|
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
|
||||||
@@ -76,7 +442,7 @@ func TestCustomProvider(t *testing.T) {
|
|||||||
ACMEKeyPath: "certs/internal-acme.key",
|
ACMEKeyPath: "certs/internal-acme.key",
|
||||||
}
|
}
|
||||||
|
|
||||||
err := cfg.Validate()
|
err := error(cfg.Validate())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user, legoCfg, err := cfg.GetLegoConfig()
|
user, legoCfg, err := cfg.GetLegoConfig()
|
||||||
@@ -86,9 +452,10 @@ func TestCustomProvider(t *testing.T) {
|
|||||||
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
|
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
|
||||||
require.Equal(t, "admin@internal.com", user.Email)
|
require.Equal(t, "admin@internal.com", user.Email)
|
||||||
|
|
||||||
provider := autocert.NewProvider(cfg, user, legoCfg)
|
provider, err := autocert.NewProvider(cfg, user, legoCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
require.NotNil(t, provider)
|
require.NotNil(t, provider)
|
||||||
require.Equal(t, autocert.ProviderCustom, provider.GetName())
|
require.Equal(t, "main", provider.GetName())
|
||||||
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
|
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
|
||||||
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
|
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
|
||||||
})
|
})
|
||||||
@@ -119,7 +486,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
|
|||||||
require.NotNil(t, user)
|
require.NotNil(t, user)
|
||||||
require.NotNil(t, legoCfg)
|
require.NotNil(t, legoCfg)
|
||||||
|
|
||||||
provider := autocert.NewProvider(cfg, user, legoCfg)
|
provider, err := autocert.NewProvider(cfg, user, legoCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
require.NotNil(t, provider)
|
require.NotNil(t, provider)
|
||||||
|
|
||||||
// Test obtaining certificate
|
// Test obtaining certificate
|
||||||
@@ -161,7 +529,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
|
|||||||
require.NotNil(t, user)
|
require.NotNil(t, user)
|
||||||
require.NotNil(t, legoCfg)
|
require.NotNil(t, legoCfg)
|
||||||
|
|
||||||
provider := autocert.NewProvider(cfg, user, legoCfg)
|
provider, err := autocert.NewProvider(cfg, user, legoCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
require.NotNil(t, provider)
|
require.NotNil(t, provider)
|
||||||
|
|
||||||
err = provider.ObtainCert()
|
err = provider.ObtainCert()
|
||||||
@@ -178,330 +547,3 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
|
|||||||
require.True(t, time.Now().After(x509Cert.NotBefore))
|
require.True(t, time.Now().After(x509Cert.NotBefore))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// testACMEServer implements a minimal ACME server for testing.
|
|
||||||
type testACMEServer struct {
|
|
||||||
server *httptest.Server
|
|
||||||
caCert *x509.Certificate
|
|
||||||
caKey *rsa.PrivateKey
|
|
||||||
clientCSRs map[string]*x509.CertificateRequest
|
|
||||||
orderID string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestACMEServer(t *testing.T) *testACMEServer {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
// Generate CA certificate and key
|
|
||||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
caTemplate := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Test CA"},
|
|
||||||
Country: []string{"US"},
|
|
||||||
Province: []string{""},
|
|
||||||
Locality: []string{"Test"},
|
|
||||||
StreetAddress: []string{""},
|
|
||||||
PostalCode: []string{""},
|
|
||||||
},
|
|
||||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
|
||||||
IsCA: true,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
||||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
caCert, err := x509.ParseCertificate(caCertDER)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
acme := &testACMEServer{
|
|
||||||
caCert: caCert,
|
|
||||||
caKey: caKey,
|
|
||||||
clientCSRs: make(map[string]*x509.CertificateRequest),
|
|
||||||
orderID: "test-order-123",
|
|
||||||
}
|
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
acme.setupRoutes(mux)
|
|
||||||
|
|
||||||
acme.server = httptest.NewUnstartedServer(mux)
|
|
||||||
acme.server.TLS = &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{
|
|
||||||
{
|
|
||||||
Certificate: [][]byte{caCert.Raw},
|
|
||||||
PrivateKey: caKey,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
}
|
|
||||||
acme.server.StartTLS()
|
|
||||||
return acme
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) Close() {
|
|
||||||
s.server.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) URL() string {
|
|
||||||
return s.server.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) httpClient() *http.Client {
|
|
||||||
certPool := x509.NewCertPool()
|
|
||||||
certPool.AddCert(s.caCert)
|
|
||||||
|
|
||||||
return &http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
DialContext: (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).DialContext,
|
|
||||||
TLSHandshakeTimeout: 30 * time.Second,
|
|
||||||
ResponseHeaderTimeout: 30 * time.Second,
|
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
RootCAs: certPool,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
|
|
||||||
// ACME directory endpoint
|
|
||||||
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
|
|
||||||
|
|
||||||
// ACME endpoints
|
|
||||||
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
|
|
||||||
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
|
|
||||||
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
|
|
||||||
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
|
|
||||||
mux.HandleFunc("/acme/chall/", s.handleChallenge)
|
|
||||||
mux.HandleFunc("/acme/order/", s.handleOrder)
|
|
||||||
mux.HandleFunc("/acme/cert/", s.handleCertificate)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
|
|
||||||
directory := map[string]interface{}{
|
|
||||||
"newNonce": s.server.URL + "/acme/new-nonce",
|
|
||||||
"newAccount": s.server.URL + "/acme/new-account",
|
|
||||||
"newOrder": s.server.URL + "/acme/new-order",
|
|
||||||
"keyChange": s.server.URL + "/acme/key-change",
|
|
||||||
"meta": map[string]interface{}{
|
|
||||||
"termsOfService": s.server.URL + "/terms",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(directory)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-12345")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
|
|
||||||
account := map[string]interface{}{
|
|
||||||
"status": "valid",
|
|
||||||
"contact": []string{"mailto:test@example.com"},
|
|
||||||
"orders": s.server.URL + "/acme/orders",
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Location", s.server.URL+"/acme/account/1")
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-67890")
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
|
||||||
json.NewEncoder(w).Encode(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
|
|
||||||
authzID := "test-authz-456"
|
|
||||||
|
|
||||||
order := map[string]interface{}{
|
|
||||||
"status": "ready", // Skip pending state for simplicity
|
|
||||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
|
||||||
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
|
||||||
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
|
|
||||||
"finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-order")
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
|
||||||
json.NewEncoder(w).Encode(order)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
|
||||||
authz := map[string]interface{}{
|
|
||||||
"status": "valid", // Skip challenge validation for simplicity
|
|
||||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
|
||||||
"identifier": map[string]string{"type": "dns", "value": "test.example.com"},
|
|
||||||
"challenges": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"type": "dns-01",
|
|
||||||
"status": "valid",
|
|
||||||
"url": s.server.URL + "/acme/chall/test-chall-789",
|
|
||||||
"token": "test-token-abc123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-authz")
|
|
||||||
json.NewEncoder(w).Encode(authz)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
|
|
||||||
challenge := map[string]interface{}{
|
|
||||||
"type": "dns-01",
|
|
||||||
"status": "valid",
|
|
||||||
"url": r.URL.String(),
|
|
||||||
"token": "test-token-abc123",
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-chall")
|
|
||||||
json.NewEncoder(w).Encode(challenge)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if strings.HasSuffix(r.URL.Path, "/finalize") {
|
|
||||||
s.handleFinalize(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
certURL := s.server.URL + "/acme/cert/" + s.orderID
|
|
||||||
order := map[string]interface{}{
|
|
||||||
"status": "valid",
|
|
||||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
|
||||||
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
|
||||||
"certificate": certURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
|
|
||||||
json.NewEncoder(w).Encode(order)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Read the JWS payload
|
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to read request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract CSR from JWS payload
|
|
||||||
csr, err := s.extractCSRFromJWS(body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store the CSR for certificate generation
|
|
||||||
s.clientCSRs[s.orderID] = csr
|
|
||||||
|
|
||||||
certURL := s.server.URL + "/acme/cert/" + s.orderID
|
|
||||||
order := map[string]interface{}{
|
|
||||||
"status": "valid",
|
|
||||||
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
|
|
||||||
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
|
|
||||||
"certificate": certURL,
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
|
|
||||||
json.NewEncoder(w).Encode(order)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
|
|
||||||
// Parse the JWS structure
|
|
||||||
var jws struct {
|
|
||||||
Protected string `json:"protected"`
|
|
||||||
Payload string `json:"payload"`
|
|
||||||
Signature string `json:"signature"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jwsData, &jws); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode the payload
|
|
||||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the finalize request
|
|
||||||
var finalizeReq struct {
|
|
||||||
CSR string `json:"csr"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode the CSR
|
|
||||||
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the CSR
|
|
||||||
csr, err := x509.ParseCertificateRequest(csrBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return csr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Extract order ID from URL
|
|
||||||
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
|
|
||||||
|
|
||||||
// Get the CSR for this order
|
|
||||||
csr, exists := s.clientCSRs[orderID]
|
|
||||||
if !exists {
|
|
||||||
http.Error(w, "No CSR found for order", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create certificate using the public key from the client's CSR
|
|
||||||
template := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(2),
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Test Cert"},
|
|
||||||
Country: []string{"US"},
|
|
||||||
},
|
|
||||||
DNSNames: csr.DNSNames,
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(90 * 24 * time.Hour),
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the public key from the CSR and sign with CA key
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return certificate chain
|
|
||||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
|
||||||
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
|
||||||
w.Header().Set("Replay-Nonce", "test-nonce-cert")
|
|
||||||
w.Write(append(certPEM, caPEM...))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
package provider_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/yusing/godoxy/internal/autocert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestExtraCertKeyPathsUnique(t *testing.T) {
|
|
||||||
t.Run("duplicate cert_path rejected", func(t *testing.T) {
|
|
||||||
cfg := &autocert.Config{
|
|
||||||
Provider: autocert.ProviderLocal,
|
|
||||||
Extra: []autocert.Config{
|
|
||||||
{CertPath: "a.crt", KeyPath: "a.key"},
|
|
||||||
{CertPath: "a.crt", KeyPath: "b.key"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.Error(t, cfg.Validate())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("duplicate key_path rejected", func(t *testing.T) {
|
|
||||||
cfg := &autocert.Config{
|
|
||||||
Provider: autocert.ProviderLocal,
|
|
||||||
Extra: []autocert.Config{
|
|
||||||
{CertPath: "a.crt", KeyPath: "a.key"},
|
|
||||||
{CertPath: "b.crt", KeyPath: "a.key"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
require.Error(t, cfg.Validate())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
90
internal/autocert/provider_test/multi_cert_test.go
Normal file
90
internal/autocert/provider_test/multi_cert_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
//nolint:errchkjson,errcheck
|
||||||
|
package provider_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/godoxy/internal/autocert"
|
||||||
|
"github.com/yusing/godoxy/internal/serialization"
|
||||||
|
"github.com/yusing/goutils/task"
|
||||||
|
)
|
||||||
|
|
||||||
|
func buildMultiCertYAML(serverURL string) []byte {
|
||||||
|
return fmt.Appendf(nil, `
|
||||||
|
email: main@example.com
|
||||||
|
domains: [main.example.com]
|
||||||
|
provider: custom
|
||||||
|
ca_dir_url: %s/acme/acme/directory
|
||||||
|
cert_path: certs/main.crt
|
||||||
|
key_path: certs/main.key
|
||||||
|
extra:
|
||||||
|
- email: extra1@example.com
|
||||||
|
domains: [extra1.example.com]
|
||||||
|
cert_path: certs/extra1.crt
|
||||||
|
key_path: certs/extra1.key
|
||||||
|
- email: extra2@example.com
|
||||||
|
domains: [extra2.example.com]
|
||||||
|
cert_path: certs/extra2.crt
|
||||||
|
key_path: certs/extra2.key
|
||||||
|
`, serverURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleCertificatesLifecycle(t *testing.T) {
|
||||||
|
acmeServer := newTestACMEServer(t)
|
||||||
|
defer acmeServer.Close()
|
||||||
|
|
||||||
|
yamlConfig := buildMultiCertYAML(acmeServer.URL())
|
||||||
|
var cfg autocert.Config
|
||||||
|
cfg.HTTPClient = acmeServer.httpClient()
|
||||||
|
|
||||||
|
/* unmarshal yaml config with multiple certs */
|
||||||
|
err := error(serialization.UnmarshalValidateYAML(yamlConfig, &cfg))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"main.example.com"}, cfg.Domains)
|
||||||
|
require.Len(t, cfg.Extra, 2)
|
||||||
|
require.Equal(t, []string{"extra1.example.com"}, cfg.Extra[0].Domains)
|
||||||
|
require.Equal(t, []string{"extra2.example.com"}, cfg.Extra[1].Domains)
|
||||||
|
|
||||||
|
var provider *autocert.Provider
|
||||||
|
|
||||||
|
/* initialize autocert with multi-cert config */
|
||||||
|
user, legoCfg, gerr := cfg.GetLegoConfig()
|
||||||
|
require.NoError(t, gerr)
|
||||||
|
provider, err = autocert.NewProvider(&cfg, user, legoCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, provider)
|
||||||
|
|
||||||
|
// Start renewal scheduler
|
||||||
|
root := task.RootTask("test", false)
|
||||||
|
defer root.Finish(nil)
|
||||||
|
provider.ScheduleRenewalAll(root)
|
||||||
|
|
||||||
|
require.Equal(t, "custom", cfg.Provider)
|
||||||
|
require.Equal(t, "custom", cfg.Extra[0].Provider)
|
||||||
|
require.Equal(t, "custom", cfg.Extra[1].Provider)
|
||||||
|
|
||||||
|
/* track cert requests for all configs */
|
||||||
|
os.MkdirAll("certs", 0755)
|
||||||
|
defer os.RemoveAll("certs")
|
||||||
|
|
||||||
|
err = provider.ObtainCertIfNotExistsAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, 1, acmeServer.certRequestCount["main.example.com"])
|
||||||
|
require.Equal(t, 1, acmeServer.certRequestCount["extra1.example.com"])
|
||||||
|
require.Equal(t, 1, acmeServer.certRequestCount["extra2.example.com"])
|
||||||
|
|
||||||
|
/* track renewal scheduling and requests */
|
||||||
|
|
||||||
|
// force renewal for all providers and wait for completion
|
||||||
|
ok := provider.ForceExpiryAll()
|
||||||
|
require.True(t, ok)
|
||||||
|
provider.WaitRenewalDone(t.Context())
|
||||||
|
|
||||||
|
require.Equal(t, 1, acmeServer.renewalRequestCount["main.example.com"])
|
||||||
|
require.Equal(t, 1, acmeServer.renewalRequestCount["extra1.example.com"])
|
||||||
|
require.Equal(t, 1, acmeServer.renewalRequestCount["extra2.example.com"])
|
||||||
|
}
|
||||||
@@ -71,15 +71,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert, KeyPath: extraKey},
|
{CertPath: extraCert, KeyPath: extraKey},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -100,15 +103,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert, KeyPath: extraKey},
|
{CertPath: extraCert, KeyPath: extraKey},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -129,15 +135,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert, KeyPath: extraKey},
|
{CertPath: extraCert, KeyPath: extraKey},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -159,8 +168,11 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(nil)
|
cert, err := p.GetCert(nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -182,8 +194,11 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -204,15 +219,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert, KeyPath: extraKey},
|
{CertPath: extraCert, KeyPath: extraKey},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -233,15 +251,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert, KeyPath: extraKey},
|
{CertPath: extraCert, KeyPath: extraKey},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -262,15 +283,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert1, KeyPath: extraKey1},
|
{CertPath: extraCert1, KeyPath: extraKey1},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -292,8 +316,11 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -317,7 +344,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert1, KeyPath: extraKey1},
|
{CertPath: extraCert1, KeyPath: extraKey1},
|
||||||
{CertPath: extraCert2, KeyPath: extraKey2},
|
{CertPath: extraCert2, KeyPath: extraKey2},
|
||||||
},
|
},
|
||||||
@@ -325,8 +352,11 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"})
|
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -352,15 +382,18 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
Provider: autocert.ProviderLocal,
|
Provider: autocert.ProviderLocal,
|
||||||
CertPath: mainCert,
|
CertPath: mainCert,
|
||||||
KeyPath: mainKey,
|
KeyPath: mainKey,
|
||||||
Extra: []autocert.Config{
|
Extra: []autocert.ConfigExtra{
|
||||||
{CertPath: extraCert, KeyPath: extraKey},
|
{CertPath: extraCert, KeyPath: extraKey},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, cfg.Validate())
|
require.NoError(t, cfg.Validate())
|
||||||
|
|
||||||
p := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, p.Setup())
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = p.LoadCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -1,101 +1,30 @@
|
|||||||
package autocert
|
package autocert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
gperr "github.com/yusing/goutils/errs"
|
gperr "github.com/yusing/goutils/errs"
|
||||||
strutils "github.com/yusing/goutils/strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Provider) Setup() (err error) {
|
func (p *Provider) setupExtraProviders() gperr.Error {
|
||||||
if err = p.LoadCert(); err != nil {
|
|
||||||
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Debug().Msg("obtaining cert due to error loading cert")
|
|
||||||
if err = p.ObtainCert(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = p.setupExtraProviders(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, expiry := range p.GetExpiries() {
|
|
||||||
log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) setupExtraProviders() error {
|
|
||||||
p.extraProviders = nil
|
|
||||||
p.sniMatcher = sniMatcher{}
|
p.sniMatcher = sniMatcher{}
|
||||||
if len(p.cfg.Extra) == 0 {
|
if len(p.cfg.Extra) == 0 {
|
||||||
p.rebuildSNIMatcher()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range p.cfg.Extra {
|
p.extraProviders = make([]*Provider, 0, len(p.cfg.Extra))
|
||||||
merged := mergeExtraConfig(p.cfg, &p.cfg.Extra[i])
|
|
||||||
user, legoCfg, err := merged.GetLegoConfig()
|
errs := gperr.NewBuilder("setup extra providers error")
|
||||||
|
for _, extra := range p.cfg.Extra {
|
||||||
|
user, legoCfg, err := extra.AsConfig().GetLegoConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Subjectf("extra[%d]", i)
|
errs.Add(p.fmtError(err))
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
ep := NewProvider(&merged, user, legoCfg)
|
ep, err := NewProvider(extra.AsConfig(), user, legoCfg)
|
||||||
if err := ep.Setup(); err != nil {
|
if err != nil {
|
||||||
return gperr.PrependSubject(fmt.Sprintf("extra[%d]", i), err)
|
errs.Add(p.fmtError(err))
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
p.extraProviders = append(p.extraProviders, ep)
|
p.extraProviders = append(p.extraProviders, ep)
|
||||||
}
|
}
|
||||||
p.rebuildSNIMatcher()
|
return errs.Error()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func mergeExtraConfig(mainCfg *Config, extraCfg *Config) Config {
|
|
||||||
merged := *mainCfg
|
|
||||||
merged.Extra = nil
|
|
||||||
merged.CertPath = extraCfg.CertPath
|
|
||||||
merged.KeyPath = extraCfg.KeyPath
|
|
||||||
|
|
||||||
if merged.Email == "" {
|
|
||||||
merged.Email = mainCfg.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(extraCfg.Domains) > 0 {
|
|
||||||
merged.Domains = extraCfg.Domains
|
|
||||||
}
|
|
||||||
if extraCfg.ACMEKeyPath != "" {
|
|
||||||
merged.ACMEKeyPath = extraCfg.ACMEKeyPath
|
|
||||||
}
|
|
||||||
if extraCfg.Provider != "" {
|
|
||||||
merged.Provider = extraCfg.Provider
|
|
||||||
}
|
|
||||||
if len(extraCfg.Options) > 0 {
|
|
||||||
merged.Options = extraCfg.Options
|
|
||||||
}
|
|
||||||
if len(extraCfg.Resolvers) > 0 {
|
|
||||||
merged.Resolvers = extraCfg.Resolvers
|
|
||||||
}
|
|
||||||
if extraCfg.CADirURL != "" {
|
|
||||||
merged.CADirURL = extraCfg.CADirURL
|
|
||||||
}
|
|
||||||
if len(extraCfg.CACerts) > 0 {
|
|
||||||
merged.CACerts = extraCfg.CACerts
|
|
||||||
}
|
|
||||||
if extraCfg.EABKid != "" {
|
|
||||||
merged.EABKid = extraCfg.EABKid
|
|
||||||
}
|
|
||||||
if extraCfg.EABHmac != "" {
|
|
||||||
merged.EABHmac = extraCfg.EABHmac
|
|
||||||
}
|
|
||||||
if extraCfg.HTTPClient != nil {
|
|
||||||
merged.HTTPClient = extraCfg.HTTPClient
|
|
||||||
}
|
|
||||||
return merged
|
|
||||||
}
|
}
|
||||||
|
|||||||
82
internal/autocert/setup_test.go
Normal file
82
internal/autocert/setup_test.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package autocert_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/yusing/godoxy/internal/autocert"
|
||||||
|
"github.com/yusing/godoxy/internal/dnsproviders"
|
||||||
|
"github.com/yusing/godoxy/internal/serialization"
|
||||||
|
strutils "github.com/yusing/goutils/strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetupExtraProviders(t *testing.T) {
|
||||||
|
dnsproviders.InitProviders()
|
||||||
|
|
||||||
|
cfgYAML := `
|
||||||
|
email: test@example.com
|
||||||
|
domains: [example.com]
|
||||||
|
provider: custom
|
||||||
|
ca_dir_url: https://ca.example.com:9000/acme/acme/directory
|
||||||
|
cert_path: certs/test.crt
|
||||||
|
key_path: certs/test.key
|
||||||
|
options: {key: value}
|
||||||
|
resolvers: [8.8.8.8]
|
||||||
|
ca_certs: [ca.crt]
|
||||||
|
eab_kid: eabKid
|
||||||
|
eab_hmac: eabHmac
|
||||||
|
extra:
|
||||||
|
- cert_path: certs/extra.crt
|
||||||
|
key_path: certs/extra.key
|
||||||
|
- cert_path: certs/extra2.crt
|
||||||
|
key_path: certs/extra2.key
|
||||||
|
email: override@example.com
|
||||||
|
provider: pseudo
|
||||||
|
domains: [override.com]
|
||||||
|
ca_dir_url: https://ca2.example.com/directory
|
||||||
|
options: {opt2: val2}
|
||||||
|
resolvers: [1.1.1.1]
|
||||||
|
ca_certs: [ca2.crt]
|
||||||
|
eab_kid: eabKid2
|
||||||
|
eab_hmac: eabHmac2
|
||||||
|
`
|
||||||
|
|
||||||
|
var cfg autocert.Config
|
||||||
|
err := error(serialization.UnmarshalValidateYAML([]byte(cfgYAML), &cfg))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test: extra[0] inherits all fields from main except CertPath and KeyPath.
|
||||||
|
merged0 := cfg.Extra[0]
|
||||||
|
require.Equal(t, "certs/extra.crt", merged0.CertPath)
|
||||||
|
require.Equal(t, "certs/extra.key", merged0.KeyPath)
|
||||||
|
// Inherited fields from main config:
|
||||||
|
require.Equal(t, "test@example.com", merged0.Email) // inherited
|
||||||
|
require.Equal(t, "custom", merged0.Provider) // inherited
|
||||||
|
require.Equal(t, []string{"example.com"}, merged0.Domains) // inherited
|
||||||
|
require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", merged0.CADirURL) // inherited
|
||||||
|
require.Equal(t, map[string]strutils.Redacted{"key": "value"}, merged0.Options) // inherited
|
||||||
|
require.Equal(t, []string{"8.8.8.8"}, merged0.Resolvers) // inherited
|
||||||
|
require.Equal(t, []string{"ca.crt"}, merged0.CACerts) // inherited
|
||||||
|
require.Equal(t, "eabKid", merged0.EABKid) // inherited
|
||||||
|
require.Equal(t, "eabHmac", merged0.EABHmac) // inherited
|
||||||
|
require.Equal(t, cfg.HTTPClient, merged0.HTTPClient) // inherited
|
||||||
|
require.Nil(t, merged0.Extra)
|
||||||
|
|
||||||
|
// Test: extra[1] overrides some fields, and inherits others.
|
||||||
|
merged1 := cfg.Extra[1]
|
||||||
|
require.Equal(t, "certs/extra2.crt", merged1.CertPath)
|
||||||
|
require.Equal(t, "certs/extra2.key", merged1.KeyPath)
|
||||||
|
// Overridden fields:
|
||||||
|
require.Equal(t, "override@example.com", merged1.Email) // overridden
|
||||||
|
require.Equal(t, "pseudo", merged1.Provider) // overridden
|
||||||
|
require.Equal(t, []string{"override.com"}, merged1.Domains) // overridden
|
||||||
|
require.Equal(t, "https://ca2.example.com/directory", merged1.CADirURL) // overridden
|
||||||
|
require.Equal(t, map[string]strutils.Redacted{"opt2": "val2"}, merged1.Options) // overridden
|
||||||
|
require.Equal(t, []string{"1.1.1.1"}, merged1.Resolvers) // overridden
|
||||||
|
require.Equal(t, []string{"ca2.crt"}, merged1.CACerts) // overridden
|
||||||
|
require.Equal(t, "eabKid2", merged1.EABKid) // overridden
|
||||||
|
require.Equal(t, "eabHmac2", merged1.EABHmac) // overridden
|
||||||
|
// Inherited field:
|
||||||
|
require.Equal(t, cfg.HTTPClient, merged1.HTTPClient) // inherited
|
||||||
|
require.Nil(t, merged1.Extra)
|
||||||
|
}
|
||||||
@@ -9,6 +9,6 @@ import (
|
|||||||
type Provider interface {
|
type Provider interface {
|
||||||
Setup() error
|
Setup() error
|
||||||
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
ScheduleRenewal(task.Parent)
|
ScheduleRenewalAll(task.Parent)
|
||||||
ObtainCert() error
|
ObtainCertAll() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -272,6 +272,7 @@ func (state *state) initAutoCert() error {
|
|||||||
autocertCfg := state.AutoCert
|
autocertCfg := state.AutoCert
|
||||||
if autocertCfg == nil {
|
if autocertCfg == nil {
|
||||||
autocertCfg = new(autocert.Config)
|
autocertCfg = new(autocert.Config)
|
||||||
|
_ = autocertCfg.Validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
user, legoCfg, err := autocertCfg.GetLegoConfig()
|
user, legoCfg, err := autocertCfg.GetLegoConfig()
|
||||||
@@ -279,12 +280,19 @@ func (state *state) initAutoCert() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
state.autocertProvider = autocert.NewProvider(autocertCfg, user, legoCfg)
|
p, err := autocert.NewProvider(autocertCfg, user, legoCfg)
|
||||||
if err := state.autocertProvider.Setup(); err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("autocert error: %w", err)
|
return err
|
||||||
} else {
|
|
||||||
state.autocertProvider.ScheduleRenewal(state.task)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := p.ObtainCertIfNotExistsAll(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.ScheduleRenewalAll(state.task)
|
||||||
|
p.PrintCertExpiriesAll()
|
||||||
|
|
||||||
|
state.autocertProvider = p
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package dnsproviders
|
package dnsproviders
|
||||||
|
|
||||||
type (
|
type (
|
||||||
DummyConfig struct{}
|
DummyConfig map[string]any
|
||||||
DummyProvider struct{}
|
DummyProvider struct{}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user