mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-19 15:01:22 +02:00
fix(autocert): correct ObtainCert error handling
- ObtainCertIfNotExistsAll longer fail on fs.ErrNotExists - Separate public LoadCertAll (loads all providers) from private loadCert - LoadCertAll now uses allProviders() for iteration - Updated tests to use LoadCertAll
This commit is contained in:
@@ -228,7 +228,7 @@ func (p *Provider) ObtainCertIfNotExistsAll() error {
|
|||||||
|
|
||||||
// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist.
|
// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist.
|
||||||
func (p *Provider) obtainCertIfNotExists() error {
|
func (p *Provider) obtainCertIfNotExists() error {
|
||||||
err := p.LoadCert()
|
err := p.loadCert()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -346,29 +346,32 @@ func (p *Provider) ObtainCert() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) LoadCert() error {
|
func (p *Provider) LoadCertAll() error {
|
||||||
var errs gperr.Builder
|
var errs gperr.Builder
|
||||||
|
for _, provider := range p.allProviders() {
|
||||||
|
if err := provider.loadCert(); err != nil {
|
||||||
|
errs.Add(provider.fmtError(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.rebuildSNIMatcher()
|
||||||
|
return errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) loadCert() error {
|
||||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs.Addf("load SSL certificate: %w", p.fmtError(err))
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
expiries, err := getCertExpiries(&cert)
|
expiries, err := getCertExpiries(&cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs.Addf("parse SSL certificate: %w", p.fmtError(err))
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.tlsCert = &cert
|
p.tlsCert = &cert
|
||||||
p.certExpiries = expiries
|
p.certExpiries = expiries
|
||||||
|
|
||||||
for _, ep := range p.extraProviders {
|
return nil
|
||||||
if err := ep.LoadCert(); err != nil {
|
|
||||||
errs.Add(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p.rebuildSNIMatcher()
|
|
||||||
return errs.Error()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers.
|
// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers.
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"})
|
||||||
@@ -113,7 +113,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
||||||
@@ -145,7 +145,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"})
|
||||||
@@ -171,7 +171,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(nil)
|
cert, err := p.GetCert(nil)
|
||||||
@@ -197,7 +197,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""})
|
||||||
@@ -229,7 +229,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"})
|
||||||
@@ -261,7 +261,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "})
|
||||||
@@ -293,7 +293,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"})
|
||||||
@@ -319,7 +319,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
|
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
|
||||||
@@ -355,7 +355,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"})
|
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"})
|
||||||
@@ -392,7 +392,7 @@ func TestGetCertBySNI(t *testing.T) {
|
|||||||
p, err := autocert.NewProvider(cfg, nil, nil)
|
p, err := autocert.NewProvider(cfg, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = p.LoadCert()
|
err = p.LoadCertAll()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
|
||||||
|
|||||||
Reference in New Issue
Block a user