diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index df3e56a8..4bb70e0e 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -228,7 +228,7 @@ func (p *Provider) ObtainCertIfNotExistsAll() error { // obtainCertIfNotExists obtains a new certificate for this provider if it does not exist. func (p *Provider) obtainCertIfNotExists() error { - err := p.LoadCert() + err := p.loadCert() if err == nil { return nil } @@ -346,29 +346,32 @@ func (p *Provider) ObtainCert() error { return nil } -func (p *Provider) LoadCert() error { +func (p *Provider) LoadCertAll() error { var errs gperr.Builder + for _, provider := range p.allProviders() { + if err := provider.loadCert(); err != nil { + errs.Add(provider.fmtError(err)) + } + } + p.rebuildSNIMatcher() + return errs.Error() +} + +func (p *Provider) loadCert() error { cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath) if err != nil { - errs.Addf("load SSL certificate: %w", p.fmtError(err)) + return err } expiries, err := getCertExpiries(&cert) if err != nil { - errs.Addf("parse SSL certificate: %w", p.fmtError(err)) + return err } p.tlsCert = &cert p.certExpiries = expiries - for _, ep := range p.extraProviders { - if err := ep.LoadCert(); err != nil { - errs.Add(err) - } - } - - p.rebuildSNIMatcher() - return errs.Error() + return nil } // PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers. diff --git a/internal/autocert/provider_test/sni_test.go b/internal/autocert/provider_test/sni_test.go index 01a07b8f..e813144f 100644 --- a/internal/autocert/provider_test/sni_test.go +++ b/internal/autocert/provider_test/sni_test.go @@ -81,7 +81,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"}) @@ -113,7 +113,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) @@ -145,7 +145,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"}) @@ -171,7 +171,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(nil) @@ -197,7 +197,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""}) @@ -229,7 +229,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"}) @@ -261,7 +261,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "}) @@ -293,7 +293,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"}) @@ -319,7 +319,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"}) @@ -355,7 +355,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"}) @@ -392,7 +392,7 @@ func TestGetCertBySNI(t *testing.T) { p, err := autocert.NewProvider(cfg, nil, nil) require.NoError(t, err) - err = p.LoadCert() + err = p.LoadCertAll() require.NoError(t, err) cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})