mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 16:58:31 +02:00
autocert: refactor and add pseudo provider for testing
This commit is contained in:
@@ -4,17 +4,19 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v4/certificate"
|
||||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
@@ -31,8 +33,10 @@ type (
|
||||
legoCert *certificate.Resource
|
||||
tlsCert *tls.Certificate
|
||||
certExpiries CertExpiries
|
||||
|
||||
obtainMu sync.Mutex
|
||||
}
|
||||
ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error)
|
||||
ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error)
|
||||
|
||||
CertExpiries map[string]time.Time
|
||||
)
|
||||
@@ -62,11 +66,22 @@ func (p *Provider) GetExpiries() CertExpiries {
|
||||
return p.certExpiries
|
||||
}
|
||||
|
||||
func (p *Provider) ObtainCert() E.Error {
|
||||
func (p *Provider) ObtainCert() error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.cfg.Provider == ProviderPseudo {
|
||||
t := time.NewTicker(1000 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
logging.Info().Msg("init client for pseudo provider")
|
||||
<-t.C
|
||||
logging.Info().Msg("registering acme for pseudo provider")
|
||||
<-t.C
|
||||
logging.Info().Msg("obtained cert for pseudo provider")
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.client == nil {
|
||||
if err := p.initClient(); err != nil {
|
||||
return err
|
||||
@@ -75,7 +90,7 @@ func (p *Provider) ObtainCert() E.Error {
|
||||
|
||||
if p.user.Registration == nil {
|
||||
if err := p.registerACME(); err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,22 +115,22 @@ func (p *Provider) ObtainCert() E.Error {
|
||||
Bundle: true,
|
||||
})
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err = p.saveCert(cert); err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
|
||||
expiries, err := getCertExpiries(&tlsCert)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
p.tlsCert = &tlsCert
|
||||
p.certExpiries = expiries
|
||||
@@ -123,14 +138,14 @@ func (p *Provider) ObtainCert() E.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) LoadCert() E.Error {
|
||||
func (p *Provider) LoadCert() error {
|
||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||
if err != nil {
|
||||
return E.Errorf("load SSL certificate: %w", err)
|
||||
return fmt.Errorf("load SSL certificate: %w", err)
|
||||
}
|
||||
expiries, err := getCertExpiries(&cert)
|
||||
if err != nil {
|
||||
return E.Errorf("parse SSL certificate: %w", err)
|
||||
return fmt.Errorf("parse SSL certificate: %w", err)
|
||||
}
|
||||
p.tlsCert = &cert
|
||||
p.certExpiries = expiries
|
||||
@@ -149,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time {
|
||||
}
|
||||
|
||||
func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||
if p.GetName() == ProviderLocal {
|
||||
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
@@ -171,7 +186,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||
continue
|
||||
}
|
||||
if err := p.renewIfNeeded(); err != nil {
|
||||
E.LogWarn("cert renew failed", err)
|
||||
gperr.LogWarn("cert renew failed", err)
|
||||
lastErrOn = time.Now()
|
||||
continue
|
||||
}
|
||||
@@ -184,10 +199,10 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Provider) initClient() E.Error {
|
||||
func (p *Provider) initClient() error {
|
||||
legoClient, err := lego.NewClient(p.legoCfg)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
|
||||
generator := providersGenMap[p.cfg.Provider]
|
||||
@@ -198,7 +213,7 @@ func (p *Provider) initClient() E.Error {
|
||||
|
||||
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
||||
if err != nil {
|
||||
return E.From(err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.client = legoClient
|
||||
@@ -273,7 +288,7 @@ func (p *Provider) certState() CertState {
|
||||
return CertStateValid
|
||||
}
|
||||
|
||||
func (p *Provider) renewIfNeeded() E.Error {
|
||||
func (p *Provider) renewIfNeeded() error {
|
||||
if p.cfg.Provider == ProviderLocal {
|
||||
return nil
|
||||
}
|
||||
@@ -312,13 +327,13 @@ func providerGenerator[CT any, PT challenge.Provider](
|
||||
defaultCfg func() *CT,
|
||||
newProvider func(*CT) (PT, error),
|
||||
) ProviderGenerator {
|
||||
return func(opt ProviderOpt) (challenge.Provider, E.Error) {
|
||||
return func(opt ProviderOpt) (challenge.Provider, gperr.Error) {
|
||||
cfg := defaultCfg()
|
||||
err := U.Deserialize(opt, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p, pErr := newProvider(cfg)
|
||||
return p, E.From(pErr)
|
||||
return p, gperr.Wrap(pErr)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user