mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-10 02:43:37 +02:00
259 lines
5.6 KiB
Go
259 lines
5.6 KiB
Go
package autocert
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"os"
|
|
"slices"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-acme/lego/v4/certificate"
|
|
"github.com/go-acme/lego/v4/challenge"
|
|
"github.com/go-acme/lego/v4/lego"
|
|
"github.com/go-acme/lego/v4/registration"
|
|
"github.com/sirupsen/logrus"
|
|
E "github.com/yusing/go-proxy/error"
|
|
M "github.com/yusing/go-proxy/models"
|
|
"github.com/yusing/go-proxy/utils"
|
|
)
|
|
|
|
type Provider struct {
|
|
cfg *Config
|
|
user *User
|
|
legoCfg *lego.Config
|
|
client *lego.Client
|
|
|
|
tlsCert *tls.Certificate
|
|
certExpiries CertExpiries
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
type ProviderGenerator func(M.AutocertProviderOpt) (challenge.Provider, error)
|
|
type CertExpiries map[string]time.Time
|
|
|
|
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
if p.tlsCert == nil {
|
|
return nil, E.Failure("get certificate")
|
|
}
|
|
return p.tlsCert, nil
|
|
}
|
|
|
|
func (p *Provider) GetName() string {
|
|
return p.cfg.Provider
|
|
}
|
|
|
|
func (p *Provider) GetCertPath() string {
|
|
return p.cfg.CertPath
|
|
}
|
|
|
|
func (p *Provider) GetKeyPath() string {
|
|
return p.cfg.KeyPath
|
|
}
|
|
|
|
func (p *Provider) GetExpiries() CertExpiries {
|
|
return p.certExpiries
|
|
}
|
|
|
|
func (p *Provider) ObtainCert() E.NestedError {
|
|
ne := E.Failure("obtain certificate")
|
|
|
|
client := p.client
|
|
if p.user.Registration == nil {
|
|
reg, err := E.Check(client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}))
|
|
if err.IsNotNil() {
|
|
return ne.With(E.Failure("register account").With(err))
|
|
}
|
|
p.user.Registration = reg
|
|
}
|
|
req := certificate.ObtainRequest{
|
|
Domains: p.cfg.Domains,
|
|
Bundle: true,
|
|
}
|
|
cert, err := E.Check(client.Certificate.Obtain(req))
|
|
if err.IsNotNil() {
|
|
return ne.With(err)
|
|
}
|
|
err = p.saveCert(cert)
|
|
if err.IsNotNil() {
|
|
return ne.With(E.Failure("save certificate").With(err))
|
|
}
|
|
tlsCert, err := E.Check(tls.X509KeyPair(cert.Certificate, cert.PrivateKey))
|
|
if err.IsNotNil() {
|
|
return ne.With(E.Failure("parse obtained certificate").With(err))
|
|
}
|
|
expiries, err := getCertExpiries(&tlsCert)
|
|
if err.IsNotNil() {
|
|
return ne.With(E.Failure("get certificate expiry").With(err))
|
|
}
|
|
p.tlsCert = &tlsCert
|
|
p.certExpiries = expiries
|
|
return E.Nil()
|
|
}
|
|
|
|
func (p *Provider) LoadCert() E.NestedError {
|
|
cert, err := E.Check(tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath))
|
|
if err.IsNotNil() {
|
|
return err
|
|
}
|
|
expiries, err := getCertExpiries(&cert)
|
|
if err.IsNotNil() {
|
|
return err
|
|
}
|
|
p.tlsCert = &cert
|
|
p.certExpiries = expiries
|
|
p.renewIfNeeded()
|
|
return E.Nil()
|
|
}
|
|
|
|
func (p *Provider) ShouldRenewOn() time.Time {
|
|
for _, expiry := range p.certExpiries {
|
|
return expiry.AddDate(0, -1, 0)
|
|
}
|
|
// this line should never be reached
|
|
panic("no certificate available")
|
|
}
|
|
|
|
func (p *Provider) ScheduleRenewal(ctx context.Context) {
|
|
if p.GetName() == ProviderLocal {
|
|
return
|
|
}
|
|
|
|
logger.Debug("starting renewal scheduler")
|
|
defer logger.Debug("renewal scheduler stopped")
|
|
|
|
stop := make(chan struct{})
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
t := time.Until(p.ShouldRenewOn())
|
|
Logger.Infof("next renewal in %v", t.Round(time.Second))
|
|
go func() {
|
|
<-time.After(t)
|
|
close(stop)
|
|
}()
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-stop:
|
|
if err := p.renewIfNeeded(); err.IsNotNil() {
|
|
Logger.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Provider) saveCert(cert *certificate.Resource) E.NestedError {
|
|
err := os.WriteFile(p.cfg.KeyPath, cert.PrivateKey, 0600) // -rw-------
|
|
if err != nil {
|
|
return E.Failure("write key file").With(err)
|
|
}
|
|
err = os.WriteFile(p.cfg.CertPath, cert.Certificate, 0644) // -rw-r--r--
|
|
if err != nil {
|
|
return E.Failure("write cert file").With(err)
|
|
}
|
|
return E.Nil()
|
|
}
|
|
|
|
func (p *Provider) needRenewal() bool {
|
|
expired := time.Now().After(p.ShouldRenewOn())
|
|
if expired {
|
|
return true
|
|
}
|
|
if len(p.cfg.Domains) != len(p.certExpiries) {
|
|
return true
|
|
}
|
|
wantedDomains := make([]string, len(p.cfg.Domains))
|
|
certDomains := make([]string, len(p.certExpiries))
|
|
copy(wantedDomains, p.cfg.Domains)
|
|
i := 0
|
|
for domain := range p.certExpiries {
|
|
certDomains[i] = domain
|
|
i++
|
|
}
|
|
slices.Sort(wantedDomains)
|
|
slices.Sort(certDomains)
|
|
for i, domain := range certDomains {
|
|
if domain != wantedDomains[i] {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (p *Provider) renewIfNeeded() E.NestedError {
|
|
if !p.needRenewal() {
|
|
return E.Nil()
|
|
}
|
|
|
|
p.mutex.Lock()
|
|
defer p.mutex.Unlock()
|
|
|
|
if !p.needRenewal() {
|
|
return E.Nil()
|
|
}
|
|
|
|
trials := 0
|
|
for {
|
|
err := p.ObtainCert()
|
|
if err.IsNotNil() {
|
|
return E.Nil()
|
|
}
|
|
trials++
|
|
if trials > 3 {
|
|
return E.Failure("renew certificate").With(err)
|
|
}
|
|
time.Sleep(5 * time.Second)
|
|
}
|
|
}
|
|
|
|
func getCertExpiries(cert *tls.Certificate) (CertExpiries, E.NestedError) {
|
|
r := make(CertExpiries, len(cert.Certificate))
|
|
for _, cert := range cert.Certificate {
|
|
x509Cert, err := E.Check(x509.ParseCertificate(cert))
|
|
if err.IsNotNil() {
|
|
return nil, E.Failure("parse certificate").With(err)
|
|
}
|
|
if x509Cert.IsCA {
|
|
continue
|
|
}
|
|
r[x509Cert.Subject.CommonName] = x509Cert.NotAfter
|
|
}
|
|
return r, E.Nil()
|
|
}
|
|
|
|
func setOptions[T interface{}](cfg *T, opt M.AutocertProviderOpt) E.NestedError {
|
|
for k, v := range opt {
|
|
err := utils.SetFieldFromSnake(cfg, k, v)
|
|
if err.IsNotNil() {
|
|
return E.Failure("set autocert option").Subject(k).With(err)
|
|
}
|
|
}
|
|
return E.Nil()
|
|
}
|
|
|
|
func providerGenerator[CT any, PT challenge.Provider](
|
|
defaultCfg func() *CT,
|
|
newProvider func(*CT) (PT, error),
|
|
) ProviderGenerator {
|
|
return func(opt M.AutocertProviderOpt) (challenge.Provider, error) {
|
|
cfg := defaultCfg()
|
|
err := setOptions(cfg, opt)
|
|
if err.IsNotNil() {
|
|
return nil, err
|
|
}
|
|
p, err := E.Check(newProvider(cfg))
|
|
if err.IsNotNil() {
|
|
return nil, err
|
|
}
|
|
return p, nil
|
|
}
|
|
}
|
|
|
|
// logger is the package logger; every entry carries module=autocert.
var logger = logrus.WithField("module", "autocert")
|