mirror of
https://github.com/yusing/godoxy.git
synced 2026-01-11 22:30:47 +01:00
646 lines
16 KiB
Go
646 lines
16 KiB
Go
package autocert
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"maps"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/go-acme/lego/v4/certificate"
|
|
"github.com/go-acme/lego/v4/lego"
|
|
"github.com/go-acme/lego/v4/registration"
|
|
"github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/yusing/godoxy/internal/common"
|
|
"github.com/yusing/godoxy/internal/notif"
|
|
gperr "github.com/yusing/goutils/errs"
|
|
strutils "github.com/yusing/goutils/strings"
|
|
"github.com/yusing/goutils/task"
|
|
)
|
|
|
|
type (
|
|
Provider struct {
|
|
logger zerolog.Logger
|
|
|
|
cfg *Config
|
|
user *User
|
|
legoCfg *lego.Config
|
|
client *lego.Client
|
|
lastFailure time.Time
|
|
|
|
lastFailureFile string
|
|
|
|
legoCert *certificate.Resource
|
|
tlsCert *tls.Certificate
|
|
certExpiries CertExpiries
|
|
|
|
extraProviders []*Provider
|
|
sniMatcher sniMatcher
|
|
|
|
forceRenewalCh chan struct{}
|
|
forceRenewalDoneCh atomic.Value // chan struct{}
|
|
|
|
scheduleRenewalOnce sync.Once
|
|
}
|
|
|
|
CertExpiries map[string]time.Time
|
|
RenewMode uint8
|
|
)
|
|
|
|
var ErrNoCertificate = errors.New("no certificate found")
|
|
|
|
const (
|
|
// renew failed for whatever reason, 1 hour cooldown
|
|
renewalCooldownDuration = 1 * time.Hour
|
|
// prevents cert request docker compose across restarts with `restart: always` (non-zero exit code)
|
|
requestCooldownDuration = 15 * time.Second
|
|
)
|
|
|
|
const (
|
|
renewModeForce = iota
|
|
renewModeIfNeeded
|
|
)
|
|
|
|
// could be nil
|
|
var ActiveProvider atomic.Pointer[Provider]
|
|
|
|
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) (*Provider, error) {
|
|
p := &Provider{
|
|
cfg: cfg,
|
|
user: user,
|
|
legoCfg: legoCfg,
|
|
lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath),
|
|
forceRenewalCh: make(chan struct{}, 1),
|
|
}
|
|
p.forceRenewalDoneCh.Store(emptyForceRenewalDoneCh)
|
|
|
|
if cfg.idx == 0 {
|
|
p.logger = log.With().Str("provider", "main").Logger()
|
|
} else {
|
|
p.logger = log.With().Str("provider", fmt.Sprintf("extra[%d]", cfg.idx)).Logger()
|
|
}
|
|
if err := p.setupExtraProviders(); err != nil {
|
|
return nil, err
|
|
}
|
|
return p, nil
|
|
}
|
|
|
|
// GetCert returns the certificate for a TLS handshake. When the client sent an
// SNI server name that the SNI matcher maps to a provider with a certificate,
// that certificate is returned. Otherwise this provider's own certificate is
// returned. It fails with ErrNoCertificate while this provider has no
// certificate.
func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	fallback := p.tlsCert
	if fallback == nil {
		return nil, ErrNoCertificate
	}
	if hello != nil && hello.ServerName != "" {
		if match := p.sniMatcher.match(hello.ServerName); match != nil && match.tlsCert != nil {
			return match.tlsCert, nil
		}
	}
	return fallback, nil
}
|
|
|
|
// GetName returns the display name of the provider: "main" when cfg.idx is 0,
// otherwise "extra[<idx>]".
func (p *Provider) GetName() string {
	if idx := p.cfg.idx; idx != 0 {
		return fmt.Sprintf("extra[%d]", idx)
	}
	return "main"
}

// fmtError prefixes err with a "provider: <name>" subject.
func (p *Provider) fmtError(err error) error {
	subject := "provider: " + p.GetName()
	return gperr.PrependSubject(subject, err)
}
|
|
|
|
func (p *Provider) GetCertPath() string {
|
|
return p.cfg.CertPath
|
|
}
|
|
|
|
func (p *Provider) GetKeyPath() string {
|
|
return p.cfg.KeyPath
|
|
}
|
|
|
|
func (p *Provider) GetExpiries() CertExpiries {
|
|
return p.certExpiries
|
|
}
|
|
|
|
// GetLastFailure returns the time of the last failed certificate request.
// A zero time means "no recorded failure". The value is cached in memory and
// read from lastFailureFile only while the cache is empty. A missing file is
// not an error. An unparsable file content is treated as no failure.
// In tests it always returns the zero time.
func (p *Provider) GetLastFailure() (time.Time, error) {
	if common.IsTest {
		return time.Time{}, nil
	}
	if !p.lastFailure.IsZero() {
		return p.lastFailure, nil
	}

	data, err := os.ReadFile(p.lastFailureFile)
	switch {
	case err == nil:
		// Deliberately ignore parse errors: a corrupt file counts as "no failure".
		p.lastFailure, _ = time.Parse(time.RFC3339, string(data))
	case !os.IsNotExist(err):
		return time.Time{}, err
	}
	return p.lastFailure, nil
}

// UpdateLastFailure records "now" as the last failure, both in memory and as
// an RFC 3339 timestamp in lastFailureFile (mode 0600). No-op in tests.
func (p *Provider) UpdateLastFailure() error {
	if common.IsTest {
		return nil
	}
	now := time.Now()
	p.lastFailure = now
	stamp := now.Format(time.RFC3339)
	return os.WriteFile(p.lastFailureFile, []byte(stamp), 0o600)
}

// ClearLastFailure forgets the last failure in memory and removes
// lastFailureFile. A file that is already missing is not an error.
// No-op in tests.
func (p *Provider) ClearLastFailure() error {
	if common.IsTest {
		return nil
	}
	p.lastFailure = time.Time{}
	if err := os.Remove(p.lastFailureFile); err != nil && !errors.Is(err, fs.ErrNotExist) {
		return err
	}
	return nil
}
|
|
|
|
// allProviders returns all providers including this provider and all extra providers.
|
|
func (p *Provider) allProviders() []*Provider {
|
|
return append([]*Provider{p}, p.extraProviders...)
|
|
}
|
|
|
|
// ObtainCertIfNotExistsAll obtains a new certificate for this provider and
// every extra provider whose certificate/key files do not exist yet. Existing
// certificates are loaded instead. Providers are processed concurrently. The
// SNI matcher is rebuilt only after all of them have finished, so it sees the
// final certificates.
func (p *Provider) ObtainCertIfNotExistsAll() error {
	errs := gperr.NewGroup("obtain cert error")

	for _, provider := range p.allProviders() {
		errs.Go(func() error {
			if err := provider.obtainCertIfNotExists(); err != nil {
				return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
			}
			return nil
		})
	}

	// Wait first: the goroutines above are still installing certificates.
	err := errs.Wait().Error()
	p.rebuildSNIMatcher()
	return err
}

// obtainCertIfNotExists loads this provider's certificate from disk. When the
// certificate or key file is missing, it obtains a new one instead. The
// request is refused while the last failure is within requestCooldownDuration.
// Extra providers are not touched; each one handles its own certificate.
func (p *Provider) obtainCertIfNotExists() error {
	err := p.loadOwnCert()
	if err == nil {
		return nil
	}

	if !errors.Is(err, fs.ErrNotExist) {
		return err
	}

	// check last failure
	lastFailure, err := p.GetLastFailure()
	if err != nil {
		return fmt.Errorf("failed to get last failure: %w", err)
	}
	if !lastFailure.IsZero() && time.Since(lastFailure) < requestCooldownDuration {
		return fmt.Errorf("still in cooldown until %s", strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
	}

	p.logger.Info().Msg("cert not found, obtaining new cert")
	return p.ObtainCert()
}

// loadOwnCert loads only this provider's certificate/key pair from disk. It
// installs the pair only when both loading and parsing succeed. The returned
// error wraps the underlying file error, so a missing file satisfies
// errors.Is(err, fs.ErrNotExist).
func (p *Provider) loadOwnCert() error {
	cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
	if err != nil {
		return fmt.Errorf("load SSL certificate: %w", err)
	}
	expiries, err := getCertExpiries(&cert)
	if err != nil {
		return fmt.Errorf("parse SSL certificate: %w", err)
	}
	p.tlsCert = &cert
	p.certExpiries = expiries
	return nil
}

// ObtainCertAll renews existing certificates or obtains new certificates for
// this provider and all extra providers, concurrently. It does this via
// ObtainCert, which does not check the failure cooldown. The SNI matcher is
// rebuilt once every provider has finished.
func (p *Provider) ObtainCertAll() error {
	errs := gperr.NewGroup("obtain cert error")
	for _, provider := range p.allProviders() {
		errs.Go(func() error {
			if err := provider.ObtainCert(); err != nil {
				return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
			}
			return nil
		})
	}

	err := errs.Wait().Error()
	p.rebuildSNIMatcher()
	return err
}
|
|
|
|
// ObtainCert renews the existing certificate or obtains a new certificate for
// this provider. It then writes the certificate to disk and installs it for
// GetCert.
//
// Local providers are a no-op, and pseudo providers only simulate the flow.
// For ACME providers, a failure timestamp is persisted before any request is
// made and cleared only on success. If the process crashes or a request
// fails midway, the provider is left on cooldown. This avoids rate limiting
// by the ACME server.
func (p *Provider) ObtainCert() error {
	switch p.cfg.Provider {
	case ProviderLocal:
		return nil
	case ProviderPseudo:
		p.logger.Info().Msg("init client for pseudo provider")
		time.Sleep(time.Second)
		p.logger.Info().Msg("registering acme for pseudo provider")
		time.Sleep(time.Second)
		p.logger.Info().Msg("obtained cert for pseudo provider")
		return nil
	}

	if p.client == nil {
		if err := p.initClient(); err != nil {
			return err
		}
	}

	// Mark as failed first; cleared below only when everything succeeded.
	if err := p.UpdateLastFailure(); err != nil {
		return fmt.Errorf("failed to update last failure: %w", err)
	}

	if p.user.Registration == nil {
		if err := p.registerACME(); err != nil {
			return err
		}
	}

	var cert *certificate.Resource

	if p.legoCert != nil {
		renewed, err := p.client.Certificate.RenewWithOptions(*p.legoCert, &certificate.RenewOptions{
			Bundle: true,
		})
		if err != nil {
			// Forget the stale resource and fall back to a fresh order below.
			// Any partial result is discarded so it is never saved.
			p.legoCert = nil
			p.logger.Err(err).Msg("cert renew failed, fallback to obtain")
		} else {
			cert = renewed
			p.legoCert = renewed
		}
	}

	if cert == nil {
		obtained, err := p.client.Certificate.Obtain(certificate.ObtainRequest{
			Domains: p.cfg.Domains,
			Bundle:  true,
		})
		if err != nil {
			return err
		}
		cert = obtained
	}

	if err := p.saveCert(cert); err != nil {
		return err
	}

	tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
	if err != nil {
		return err
	}

	expiries, err := getCertExpiries(&tlsCert)
	if err != nil {
		return err
	}
	p.tlsCert = &tlsCert
	p.certExpiries = expiries
	p.rebuildSNIMatcher()

	if err := p.ClearLastFailure(); err != nil {
		return fmt.Errorf("failed to clear last failure: %w", err)
	}
	return nil
}
|
|
|
|
// LoadCert loads the certificate/key pair of this provider and of every extra
// provider from disk, then rebuilds the SNI matcher.
//
// A provider's certificate and expiries are replaced only when both loading
// and parsing succeed. A failed (re)load therefore never installs an empty
// certificate, and it keeps any previously installed one, so GetCert keeps
// returning ErrNoCertificate until a real certificate is available. All
// failures are collected and returned together.
func (p *Provider) LoadCert() error {
	var errs gperr.Builder

	if cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath); err != nil {
		errs.Addf("load SSL certificate: %w", p.fmtError(err))
	} else if expiries, err := getCertExpiries(&cert); err != nil {
		errs.Addf("parse SSL certificate: %w", p.fmtError(err))
	} else {
		p.tlsCert = &cert
		p.certExpiries = expiries
	}

	for _, ep := range p.extraProviders {
		if err := ep.LoadCert(); err != nil {
			errs.Add(err)
		}
	}

	p.rebuildSNIMatcher()
	return errs.Error()
}
|
|
|
|
// PrintCertExpiriesAll logs the certificate expiry of every name, for this
// provider and all extra providers. Each line is logged with the owning
// provider's logger, so it carries the right provider tag. Names are logged in
// sorted order for stable output.
func (p *Provider) PrintCertExpiriesAll() {
	for _, provider := range p.allProviders() {
		for _, domain := range slices.Sorted(maps.Keys(provider.certExpiries)) {
			expiry := provider.certExpiries[domain]
			provider.logger.Info().Str("domain", domain).Msgf("certificate expire on %s", strutils.FormatTime(expiry))
		}
	}
}
|
|
|
|
// ShouldRenewOn returns the time at which the certificate should be renewed.
|
|
func (p *Provider) ShouldRenewOn() time.Time {
|
|
for _, expiry := range p.certExpiries {
|
|
return expiry.AddDate(0, -1, 0) // 1 month before
|
|
}
|
|
// this line should never be reached in production, but will be useful for testing
|
|
return time.Now().AddDate(0, 1, 0) // 1 month after
|
|
}
|
|
|
|
// ForceExpiryAll triggers immediate certificate renewal for this provider and all extra providers.
|
|
// Returns true if the renewal was triggered, false if the renewal was dropped.
|
|
//
|
|
// If at least one renewal is triggered, returns true.
|
|
func (p *Provider) ForceExpiryAll() (ok bool) {
|
|
doneCh := make(chan struct{})
|
|
if swapped := p.forceRenewalDoneCh.CompareAndSwap(emptyForceRenewalDoneCh, doneCh); !swapped { // already in progress
|
|
close(doneCh)
|
|
return false
|
|
}
|
|
|
|
select {
|
|
case p.forceRenewalCh <- struct{}{}:
|
|
ok = true
|
|
default:
|
|
}
|
|
|
|
for _, ep := range p.extraProviders {
|
|
if ep.ForceExpiryAll() {
|
|
ok = true
|
|
}
|
|
}
|
|
|
|
return ok
|
|
}
|
|
|
|
// WaitRenewalDone blocks until the pending forced renewal of this provider has
// completed, then does the same for each extra provider in order.
//
// It returns false if ctx is done first, or if a provider has no forced
// renewal pending when it is checked. The latter includes a renewal that
// already finished before this call looked at it.
func (p *Provider) WaitRenewalDone(ctx context.Context) bool {
	done, _ := p.forceRenewalDoneCh.Load().(chan struct{})
	if done == nil {
		return false
	}

	select {
	case <-ctx.Done():
		return false
	case <-done:
	}

	for _, extra := range p.extraProviders {
		if !extra.WaitRenewalDone(ctx) {
			return false
		}
	}
	return true
}
|
|
|
|
// ScheduleRenewalAll schedules the renewal of the certificate for this provider and all extra providers.
|
|
func (p *Provider) ScheduleRenewalAll(parent task.Parent) {
|
|
p.scheduleRenewalOnce.Do(func() {
|
|
p.scheduleRenewal(parent)
|
|
})
|
|
for _, ep := range p.extraProviders {
|
|
ep.scheduleRenewalOnce.Do(func() {
|
|
ep.scheduleRenewal(parent)
|
|
})
|
|
}
|
|
}
|
|
|
|
// emptyForceRenewalDoneCh is the sentinel stored in forceRenewalDoneCh while no
// forced renewal is pending. It is a typed nil channel, so every value kept in
// the atomic.Value has the same concrete type (chan struct{}).
var emptyForceRenewalDoneCh any = (chan struct{})(nil)
|
|
|
|
// scheduleRenewal starts this provider's renewal loop as a subtask of parent.
//
// The loop serves forced-renewal requests from forceRenewalCh. For ACME
// providers it also renews automatically when ShouldRenewOn is reached.
// Local and pseudo providers get no expiry timer, because there is nothing to
// renew on a schedule. The loop still runs for them, so forced requests are
// consumed and WaitRenewalDone callers are released.
//
// After an automatic attempt the timer is always re-armed:
//   - for the next due date on success, or when no renewal was needed;
//   - after renewalCooldownDuration on failure.
//
// This way a single transient error does not stop automatic renewal for good.
func (p *Provider) scheduleRenewal(parent task.Parent) {
	// Compare the provider kind (not the display name returned by GetName).
	autoRenew := p.cfg.Provider != ProviderLocal && p.cfg.Provider != ProviderPseudo

	timer := time.NewTimer(time.Until(p.ShouldRenewOn()))
	var timerC <-chan time.Time // stays nil (never ready) when autoRenew is false
	if autoRenew {
		timerC = timer.C
	} else {
		timer.Stop()
	}
	rearm := func(d time.Duration) {
		if autoRenew {
			timer.Reset(d)
		}
	}

	renewTask := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)

	renew := func(renewMode RenewMode) {
		defer func() {
			// Release WaitRenewalDone callers whatever the outcome.
			if done, ok := p.forceRenewalDoneCh.Swap(emptyForceRenewalDoneCh).(chan struct{}); ok && done != nil {
				close(done)
			}
		}()

		renewed, err := p.renew(renewMode)
		if err != nil {
			gperr.LogWarn("autocert: cert renew failed", p.fmtError(err))
			notif.Notify(&notif.LogMessage{
				Level: zerolog.ErrorLevel,
				Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()),
				Body:  notif.MessageBody(err.Error()),
			})
			if renewMode == renewModeIfNeeded {
				// The timer has fired; without re-arming it, no further
				// automatic attempt would ever be made.
				rearm(renewalCooldownDuration)
			}
			return
		}

		if !renewed {
			if renewMode == renewModeIfNeeded {
				rearm(time.Until(p.ShouldRenewOn()))
			}
			return
		}

		p.rebuildSNIMatcher()

		notif.Notify(&notif.LogMessage{
			Level: zerolog.InfoLevel,
			Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()),
			Body:  notif.ListBody(p.cfg.Domains),
		})

		// Reset on success
		if err := p.ClearLastFailure(); err != nil {
			gperr.LogWarn("autocert: failed to clear last failure", p.fmtError(err))
		}
		rearm(time.Until(p.ShouldRenewOn()))
	}

	go func() {
		defer timer.Stop()
		defer renewTask.Finish(nil)

		for {
			select {
			case <-renewTask.Context().Done():
				return
			case <-p.forceRenewalCh:
				renew(renewModeForce)
			case <-timerC:
				renew(renewModeIfNeeded)
			}
		}
	}()
}
|
|
|
|
func (p *Provider) initClient() error {
|
|
legoClient, err := lego.NewClient(p.legoCfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider, p.cfg.dns01Options()...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
p.client = legoClient
|
|
return nil
|
|
}
|
|
|
|
func (p *Provider) registerACME() error {
|
|
if p.user.Registration != nil {
|
|
return nil
|
|
}
|
|
|
|
reg, err := p.client.Registration.ResolveAccountByKey()
|
|
if err == nil {
|
|
p.user.Registration = reg
|
|
log.Info().Msg("reused acme registration from private key")
|
|
return nil
|
|
}
|
|
|
|
if p.cfg.EABKid != "" && p.cfg.EABHmac != "" {
|
|
reg, err = p.client.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
|
|
TermsOfServiceAgreed: true,
|
|
Kid: p.cfg.EABKid,
|
|
HmacEncoded: p.cfg.EABHmac,
|
|
})
|
|
} else {
|
|
reg, err = p.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.user.Registration = reg
|
|
log.Info().Interface("reg", reg).Msg("acme registered")
|
|
return nil
|
|
}
|
|
|
|
// saveCert writes the private key (mode 0600) and the certificate (mode 0644)
// of cert to the configured paths. It creates missing parent directories
// first. The key may live in a different directory than the certificate, so
// both directories are ensured. No-op in tests.
func (p *Provider) saveCert(cert *certificate.Resource) error {
	if common.IsTest {
		return nil
	}

	// This should have been done in setup, but double check is always a good
	// choice. MkdirAll is a no-op for directories that already exist.
	for _, dir := range []string{filepath.Dir(p.cfg.KeyPath), filepath.Dir(p.cfg.CertPath)} {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return err
		}
	}

	if err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0o600); err != nil { // -rw-------
		return err
	}
	return os.WriteFile(p.cfg.CertPath, cert.Certificate, 0o644) // -rw-r--r--
}
|
|
|
|
// certState classifies this provider's certificate:
//   - CertStateExpired when the renewal date (ShouldRenewOn) has passed. The
//     certificate itself may not have expired yet.
//   - CertStateMismatch when the certificate's names differ from cfg.Domains.
//     This is logged, whether it is a size difference or a missing domain.
//   - CertStateValid otherwise.
func (p *Provider) certState() CertState {
	if time.Now().After(p.ShouldRenewOn()) {
		return CertStateExpired
	}

	missing := func(domain string) bool {
		_, ok := p.certExpiries[domain]
		return !ok
	}
	if len(p.certExpiries) != len(p.cfg.Domains) || slices.ContainsFunc(p.cfg.Domains, missing) {
		p.logger.Info().Msgf("autocert domains mismatch: cert: %s, wanted: %s",
			strings.Join(slices.Sorted(maps.Keys(p.certExpiries)), ", "),
			strings.Join(p.cfg.Domains, ", "))
		return CertStateMismatch
	}

	return CertStateValid
}
|
|
|
|
// renew renews this provider's certificate according to mode, and reports
// whether a new certificate was obtained.
//
//   - Local providers are never renewed.
//   - Unless mode is renewModeForce, a recorded failure within the last
//     renewalCooldownDuration aborts with an error.
//   - With renewModeIfNeeded, renewal happens only when certState reports the
//     certificate as due (expired) or as mismatching the configured domains.
//
// Log lines use the provider logger, so they identify which provider renews.
func (p *Provider) renew(mode RenewMode) (renewed bool, err error) {
	if p.cfg.Provider == ProviderLocal {
		return false, nil
	}

	if mode != renewModeForce {
		// Retry after 1 hour on failure
		lastFailure, err := p.GetLastFailure()
		if err != nil {
			return false, fmt.Errorf("failed to get last failure: %w", err)
		}
		if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
			until := lastFailure.Add(renewalCooldownDuration).Local()
			return false, fmt.Errorf("still in cooldown until %s", strutils.FormatTime(until))
		}
	}

	switch mode {
	case renewModeIfNeeded:
		switch p.certState() {
		case CertStateExpired:
			p.logger.Info().Msg("certs expired, renewing")
		case CertStateMismatch:
			p.logger.Info().Msg("cert domains mismatch with config, renewing")
		default:
			return false, nil
		}
	case renewModeForce:
		p.logger.Info().Msg("force renewing cert by user request")
	}

	if err := p.ObtainCert(); err != nil {
		return false, err
	}
	return true, nil
}
|
|
|
|
// getCertExpiries parses every certificate in the chain and maps each name of
// the non-CA certificates to that certificate's NotAfter.
//
// Names are the DNS SANs plus the subject common name when it is set. An
// empty common name, as issued by some modern ACME CAs, is not a name. It is
// skipped, so it cannot show up as a bogus "" domain. Any parse error aborts
// with a nil map.
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
	expiries := make(CertExpiries, len(cert.Certificate))
	for _, der := range cert.Certificate {
		x509Cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, err
		}
		if x509Cert.IsCA {
			continue
		}
		if cn := x509Cert.Subject.CommonName; cn != "" {
			expiries[cn] = x509Cert.NotAfter
		}
		for _, name := range x509Cert.DNSNames {
			expiries[name] = x509Cert.NotAfter
		}
	}
	return expiries, nil
}
|
|
|
|
// lastFailureFileFor returns the path of the hidden file that records the last
// failure for a cert/key pair. The file lives next to the certificate and is
// named ".last_failure-" plus the first 6 bytes (12 hex digits) of
// SHA-256(certPath + "|" + keyPath). Distinct pairs therefore get distinct
// files.
func lastFailureFileFor(certPath, keyPath string) string {
	digest := sha256.Sum256([]byte(certPath + "|" + keyPath))
	name := fmt.Sprintf(".last_failure-%x", digest[:6])
	return filepath.Join(filepath.Dir(certPath), name)
}
|
|
|
|
func (p *Provider) rebuildSNIMatcher() {
|
|
if p.cfg.idx != 0 { // only main provider has extra providers
|
|
return
|
|
}
|
|
|
|
p.sniMatcher = sniMatcher{}
|
|
p.sniMatcher.addProvider(p)
|
|
for _, ep := range p.extraProviders {
|
|
p.sniMatcher.addProvider(ep)
|
|
}
|
|
}
|