v0.5.0-rc4: fixing autocert issue, cache ACME registration, added ls-config option

This commit is contained in:
yusing
2024-09-17 08:41:36 +08:00
parent 04fd6543fd
commit 82f06374f7
33 changed files with 301 additions and 221 deletions

View File

@@ -20,6 +20,9 @@ func NewConfig(cfg *M.AutoCertConfig) *Config {
if cfg.KeyPath == "" {
cfg.KeyPath = KeyFileDefault
}
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
}
return (*Config)(cfg)
}
@@ -36,43 +39,35 @@ func (cfg *Config) GetProvider() (*Provider, E.NestedError) {
if cfg.Email == "" {
errors.Addf("no email specified")
}
// check if provider is implemented
_, ok := providersGenMap[cfg.Provider]
if !ok {
errors.Addf("unknown provider: %q", cfg.Provider)
}
}
gen, ok := providersGenMap[cfg.Provider]
if !ok {
errors.Addf("unknown provider: %q", cfg.Provider)
}
if err := errors.Build(); err.IsNotNil() {
if err := errors.Build(); err.HasError() {
return nil, err
}
privKey, err := E.Check(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
if err.IsNotNil() {
if err.HasError() {
return nil, E.Failure("generate private key").With(err)
}
user := &User{
Email: cfg.Email,
key: privKey,
}
legoCfg := lego.NewConfig(user)
legoCfg.Certificate.KeyType = certcrypto.RSA2048
legoClient, err := E.Check(lego.NewClient(legoCfg))
if err.IsNotNil() {
return nil, E.Failure("create lego client").With(err)
}
base := &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
client: legoClient,
}
legoProvider, err := E.Check(gen(cfg.Options))
if err.IsNotNil() {
return nil, E.Failure("create lego provider").With(err)
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.IsNotNil() {
return nil, E.Failure("set challenge provider").With(err)
}
return base, E.Nil()
}

View File

@@ -8,9 +8,10 @@ import (
)
const (
certBasePath = "certs/"
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
certBasePath = "certs/"
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
RegistrationFile = certBasePath + "registration.json"
)
const (
@@ -21,11 +22,10 @@ const (
)
var providersGenMap = map[string]ProviderGenerator{
"": providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
ProviderLocal: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
ProviderCloudflare: providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig),
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
}
var Logger = logrus.WithField("module", "autocert")
var logger = logrus.WithField("module", "autocert")

View File

@@ -5,18 +5,17 @@ import (
"crypto/tls"
"crypto/x509"
"os"
"slices"
"sync"
"reflect"
"sort"
"time"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
"github.com/sirupsen/logrus"
E "github.com/yusing/go-proxy/error"
M "github.com/yusing/go-proxy/models"
"github.com/yusing/go-proxy/utils"
U "github.com/yusing/go-proxy/utils"
)
type Provider struct {
@@ -27,10 +26,9 @@ type Provider struct {
tlsCert *tls.Certificate
certExpiries CertExpiries
mutex sync.Mutex
}
type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, error)
type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, E.NestedError)
type CertExpiries map[string]time.Time
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -57,59 +55,72 @@ func (p *Provider) GetExpiries() CertExpiries {
}
func (p *Provider) ObtainCert() E.NestedError {
if p.cfg.Provider == ProviderLocal {
return E.FailureWhy("obtain cert", "provider is set to \"local\"")
}
if p.client == nil {
if err := p.initClient(); err.HasError() {
return E.Failure("obtain cert").With(err)
}
}
ne := E.Failure("obtain certificate")
client := p.client
if p.user.Registration == nil {
reg, err := E.Check(client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.IsNotNil() {
return ne.With(E.Failure("register account").With(err))
if err := p.loadRegistration(); err.HasError() {
ne = ne.With(err)
if err := p.registerACME(); err.HasError() {
return ne.With(err)
}
}
p.user.Registration = reg
}
req := certificate.ObtainRequest{
Domains: p.cfg.Domains,
Bundle: true,
}
cert, err := E.Check(client.Certificate.Obtain(req))
if err.IsNotNil() {
if err.HasError() {
return ne.With(err)
}
err = p.saveCert(cert)
if err.IsNotNil() {
if err.HasError() {
return ne.With(E.Failure("save certificate").With(err))
}
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
if err.IsNotNil() {
if err.HasError() {
return ne.With(E.Failure("parse obtained certificate").With(err))
}
expiries, err := getCertExpiries(&tlsCert)
if err.IsNotNil() {
if err.HasError() {
return ne.With(E.Failure("get certificate expiry").With(err))
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
return E.Nil()
}
func (p *Provider) LoadCert() E.NestedError {
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
if err.IsNotNil() {
if err.HasError() {
return err
}
expiries, err := getCertExpiries(&cert)
if err.IsNotNil() {
if err.HasError() {
return err
}
p.tlsCert = &cert
p.certExpiries = expiries
p.renewIfNeeded()
return E.Nil()
logger.Infof("next renewal in %v", time.Until(p.ShouldRenewOn()))
return p.renewIfNeeded()
}
func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0)
return expiry.AddDate(0, -1, 0) // 1 month before
}
// this line should never be reached
panic("no certificate available")
@@ -120,117 +131,151 @@ func (p *Provider) ScheduleRenewal(ctx context.Context) {
return
}
logger.Debug("starting renewal scheduler")
logger.Debug("started renewal scheduler")
defer logger.Debug("renewal scheduler stopped")
stop := make(chan struct{})
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
default:
t := time.Until(p.ShouldRenewOn())
Logger.Infof("next renewal in %v", t.Round(time.Second))
go func() {
<-time.After(t)
close(stop)
}()
select {
case <-ctx.Done():
return
case <-stop:
if err := p.renewIfNeeded(); err.IsNotNil() {
Logger.Fatal(err)
}
case <-ticker.C: // check every 5 seconds
if err := p.renewIfNeeded(); err.HasError() {
logger.Warn(err)
}
}
}
}
func (p *Provider) initClient() E.NestedError {
legoClient, err := E.Check(lego.NewClient(p.legoCfg))
if err.HasError() {
return E.Failure("create lego client").With(err)
}
legoProvider, err := providersGenMap[p.cfg.Provider](p.cfg.Options)
if err.HasError() {
return E.Failure("create lego provider").With(err)
}
err = E.From(legoClient.Challenge.SetDNS01Provider(legoProvider))
if err.HasError() {
return E.Failure("set challenge provider").With(err)
}
p.client = legoClient
return E.Nil()
}
func (p *Provider) registerACME() E.NestedError {
if p.user.Registration != nil {
return E.Nil()
}
reg, err := E.Check(p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
if err.HasError() {
return E.Failure("register ACME").With(err)
}
p.user.Registration = reg
if err := p.saveRegistration(); err.HasError() {
logger.Warn(err)
}
return E.Nil()
}
func (p *Provider) loadRegistration() E.NestedError {
if p.user.Registration != nil {
return E.Nil()
}
reg := &registration.Resource{}
err := U.LoadJson(RegistrationFile, reg)
if err.HasError() {
return E.Failure("parse registration file").With(err)
}
p.user.Registration = reg
return E.Nil()
}
func (p *Provider) saveRegistration() E.NestedError {
return U.SaveJson(RegistrationFile, p.user.Registration, 0o600)
}
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0600) // -rw-------
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600) // -rw-------
if err != nil {
return E.Failure("write key file").With(err)
}
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0644) // -rw-r--r--
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
if err != nil {
return E.Failure("write cert file").With(err)
}
return E.Nil()
}
func (p *Provider) needRenewal() bool {
expired := time.Now().After(p.ShouldRenewOn())
if expired {
return true
func (p *Provider) certState() CertState {
if time.Now().After(p.ShouldRenewOn()) {
return CertStateExpired
}
if len(p.cfg.Domains) != len(p.certExpiries) {
return true
}
wantedDomains := make([]string, len(p.cfg.Domains))
certDomains := make([]string, len(p.certExpiries))
copy(wantedDomains, p.cfg.Domains)
wantedDomains := make([]string, len(p.cfg.Domains))
i := 0
for domain := range p.certExpiries {
certDomains[i] = domain
i++
}
slices.Sort(wantedDomains)
slices.Sort(certDomains)
for i, domain := range certDomains {
if domain != wantedDomains[i] {
return true
}
copy(wantedDomains, p.cfg.Domains)
sort.Strings(wantedDomains)
sort.Strings(certDomains)
if !reflect.DeepEqual(certDomains, wantedDomains) {
logger.Debugf("cert domains mismatch: %v != %v", certDomains, p.cfg.Domains)
return CertStateMismatch
}
return false
return CertStateValid
}
func (p *Provider) renewIfNeeded() E.NestedError {
if !p.needRenewal() {
switch p.certState() {
case CertStateExpired:
logger.Info("certs expired, renewing")
case CertStateMismatch:
logger.Info("cert domains mismatch with config, renewing")
default:
return E.Nil()
}
p.mutex.Lock()
defer p.mutex.Unlock()
if !p.needRenewal() {
return E.Nil()
}
trials := 0
for {
err := p.ObtainCert()
if err.IsNotNil() {
return E.Nil()
}
trials++
if trials > 3 {
return E.Failure("renew certificate").With(err)
}
time.Sleep(5 * time.Second)
if err := p.ObtainCert(); err.HasError() {
return E.Failure("renew certificate").With(err)
}
return E.Nil()
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
r := make(CertExpiries, len(cert.Certificate))
for _, cert := range cert.Certificate {
x509Cert, err := E.Check(x509.ParseCertificate(cert))
if err.IsNotNil() {
if err.HasError() {
return nil, E.Failure("parse certificate").With(err)
}
if x509Cert.IsCA {
continue
}
r[x509Cert.Subject.CommonName] = x509Cert.NotAfter
for i := range x509Cert.DNSNames {
r[x509Cert.DNSNames[i]] = x509Cert.NotAfter
}
}
return r, E.Nil()
}
func setOptions[T interface{}](cfg *T, opt M.AutocertProviderOpt) E.NestedError {
for k, v := range opt {
err := utils.SetFieldFromSnake(cfg, k, v)
if err.IsNotNil() {
err := U.SetFieldFromSnake(cfg, k, v)
if err.HasError() {
return E.Failure("set autocert option").Subject(k).With(err)
}
}
@@ -241,18 +286,16 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT,
newProvider func(*CT) (PT, error),
) ProviderGenerator {
return func(opt M.AutocertProviderOpt) (challenge.Provider, error) {
return func(opt M.AutocertProviderOpt) (challenge.Provider, E.NestedError) {
cfg := defaultCfg()
err := setOptions(cfg, opt)
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
p, err := E.Check(newProvider(cfg))
if err.IsNotNil() {
if err.HasError() {
return nil, err
}
return p, nil
return p, E.Nil()
}
}
var logger = logrus.WithField("module", "autocert")

9
src/autocert/state.go Normal file
View File

@@ -0,0 +1,9 @@
package autocert
type CertState int
const (
CertStateValid CertState = 0
CertStateExpired CertState = iota
CertStateMismatch CertState = iota
)