diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 2d9f0a42..43e2aa70 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -24,6 +24,7 @@ type Config struct { Domains []string `json:"domains,omitempty"` CertPath string `json:"cert_path,omitempty"` KeyPath string `json:"key_path,omitempty"` + Extra []Config `json:"extra,omitempty"` ACMEKeyPath string `json:"acme_key_path,omitempty"` Provider string `json:"provider,omitempty"` Options map[string]strutils.Redacted `json:"options,omitempty"` @@ -48,6 +49,9 @@ var ( ErrMissingEmail = gperr.New("missing field 'email'") ErrMissingProvider = gperr.New("missing field 'provider'") ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'") + ErrMissingCertPath = gperr.New("missing field 'cert_path'") + ErrMissingKeyPath = gperr.New("missing field 'key_path'") + ErrDuplicatedPath = gperr.New("duplicated path") ErrInvalidDomain = gperr.New("invalid domain") ErrUnknownProvider = gperr.New("unknown provider") ) @@ -68,10 +72,36 @@ func (cfg *Config) Validate() gperr.Error { if cfg.Provider == "" { cfg.Provider = ProviderLocal - return nil } b := gperr.NewBuilder("autocert errors") + if len(cfg.Extra) > 0 { + seenCertPaths := make(map[string]int, len(cfg.Extra)) + seenKeyPaths := make(map[string]int, len(cfg.Extra)) + for i := range cfg.Extra { + if cfg.Extra[i].CertPath == "" { + b.Add(ErrMissingCertPath.Subjectf("extra[%d].cert_path", i)) + } + if cfg.Extra[i].KeyPath == "" { + b.Add(ErrMissingKeyPath.Subjectf("extra[%d].key_path", i)) + } + if cfg.Extra[i].CertPath != "" { + if first, ok := seenCertPaths[cfg.Extra[i].CertPath]; ok { + b.Add(ErrDuplicatedPath.Subjectf("extra[%d].cert_path", i).Withf("first: %d", first)) + } else { + seenCertPaths[cfg.Extra[i].CertPath] = i + } + } + if cfg.Extra[i].KeyPath != "" { + if first, ok := seenKeyPaths[cfg.Extra[i].KeyPath]; ok { + b.Add(ErrDuplicatedPath.Subjectf("extra[%d].key_path", i).Withf("first: %d", first)) + } else { + seenKeyPaths[cfg.Extra[i].KeyPath] = i + } + } + } + } + if cfg.Provider == ProviderCustom && cfg.CADirURL == "" { b.Add(ErrMissingCADirURL) } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index d7032f8e..163d6bbc 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -1,13 +1,14 @@ package autocert import ( + "crypto/sha256" "crypto/tls" "crypto/x509" "errors" "fmt" "maps" "os" - "path" + "path/filepath" "slices" "strings" "sync/atomic" @@ -33,9 +34,14 @@ type ( client *lego.Client lastFailure time.Time + lastFailureFile string + legoCert *certificate.Resource tlsCert *tls.Certificate certExpiries CertExpiries + + extraProviders []*Provider + sniMatcher sniMatcher } CertExpiries map[string]time.Time @@ -55,16 +61,23 @@ var ActiveProvider atomic.Pointer[Provider] func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider { return &Provider{ - cfg: cfg, - user: user, - legoCfg: legoCfg, + cfg: cfg, + user: user, + legoCfg: legoCfg, + lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath), } } -func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { if p.tlsCert == nil { return nil, ErrGetCertFailure } + if hello == nil || hello.ServerName == "" { + return p.tlsCert, nil + } + if prov := p.sniMatcher.match(hello.ServerName); prov != nil && prov.tlsCert != nil { + return prov.tlsCert, nil + } return p.tlsCert, nil } @@ -90,7 +103,7 @@ func (p *Provider) GetLastFailure() (time.Time, error) { } if p.lastFailure.IsZero() { - data, err := os.ReadFile(LastFailureFile) + data, err := os.ReadFile(p.lastFailureFile) if err != nil { if !os.IsNotExist(err) { return time.Time{}, err @@ -108,7 +121,7 @@ func (p *Provider) UpdateLastFailure() error { } t := time.Now() p.lastFailure = t - return os.WriteFile(LastFailureFile, t.AppendFormat(nil, time.RFC3339), 0o600) + return os.WriteFile(p.lastFailureFile, t.AppendFormat(nil, time.RFC3339), 0o600) } func (p *Provider) ClearLastFailure() error { @@ -116,10 +129,26 @@ func (p *Provider) ClearLastFailure() error { return nil } p.lastFailure = time.Time{} - return os.Remove(LastFailureFile) + return os.Remove(p.lastFailureFile) } func (p *Provider) ObtainCert() error { + if len(p.extraProviders) > 0 { + errs := gperr.NewGroup("autocert errors") + errs.Go(p.obtainCertSelf) + for _, ep := range p.extraProviders { + errs.Go(ep.obtainCertSelf) + } + if err := errs.Wait().Error(); err != nil { + return err + } + p.rebuildSNIMatcher() + return nil + } + return p.obtainCertSelf() +} + +func (p *Provider) obtainCertSelf() error { if p.cfg.Provider == ProviderLocal { return nil } @@ -239,7 +268,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { timer := time.NewTimer(time.Until(renewalTime)) defer timer.Stop() - task := parent.Subtask("cert-renew-scheduler", true) + task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true) defer task.Finish(nil) for { @@ -282,6 +311,9 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { } } }() + for _, ep := range p.extraProviders { + ep.ScheduleRenewal(parent) + } } func (p *Provider) initClient() error { @@ -334,10 +366,10 @@ func (p *Provider) saveCert(cert *certificate.Resource) error { } /* This should have been done in setup but double check is always a good choice.*/ - _, err := os.Stat(path.Dir(p.cfg.CertPath)) + _, err := os.Stat(filepath.Dir(p.cfg.CertPath)) if err != nil { if os.IsNotExist(err) { - if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil { + if err = os.MkdirAll(filepath.Dir(p.cfg.CertPath), 0o755); err != nil { return err } } else { @@ -391,7 +423,7 @@ func (p *Provider) renewIfNeeded() error { return nil } - return p.ObtainCert() + return p.obtainCertSelf() } func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { @@ -411,3 +443,20 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { } return r, nil } + +func lastFailureFileFor(certPath, keyPath string) string { + if certPath == "" && keyPath == "" { + return LastFailureFile + } + dir := filepath.Dir(certPath) + sum := sha256.Sum256([]byte(certPath + "|" + keyPath)) + return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6])) +} + +func (p *Provider) rebuildSNIMatcher() { + p.sniMatcher = sniMatcher{} + p.sniMatcher.addProvider(p) + for _, ep := range p.extraProviders { + p.sniMatcher.addProvider(ep) + } +} diff --git a/internal/autocert/provider_test/extra_validation_test.go b/internal/autocert/provider_test/extra_validation_test.go new file mode 100644 index 00000000..3fbb5174 --- /dev/null +++ b/internal/autocert/provider_test/extra_validation_test.go @@ -0,0 +1,32 @@ +package provider_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" +) + +func TestExtraCertKeyPathsUnique(t *testing.T) { + t.Run("duplicate cert_path rejected", func(t *testing.T) { + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + Extra: []autocert.Config{ + {CertPath: "a.crt", KeyPath: "a.key"}, + {CertPath: "a.crt", KeyPath: "b.key"}, + }, + } + require.Error(t, cfg.Validate()) + }) + + t.Run("duplicate key_path rejected", func(t *testing.T) { + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + Extra: []autocert.Config{ + {CertPath: "a.crt", KeyPath: "a.key"}, + {CertPath: "b.crt", KeyPath: "a.key"}, + }, + } + require.Error(t, cfg.Validate()) + }) +} diff --git a/internal/autocert/provider_test/sni_test.go b/internal/autocert/provider_test/sni_test.go new file mode 100644 index 00000000..766593cd --- /dev/null +++ b/internal/autocert/provider_test/sni_test.go @@ -0,0 +1,383 @@ +package provider_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" +) + +func writeSelfSignedCert(t *testing.T, dir string, dnsNames []string) (string, string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + require.NoError(t, err) + + cn := "" + if len(dnsNames) > 0 { + cn = dnsNames[0] + } + + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: dnsNames, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + require.NoError(t, os.WriteFile(certPath, certPEM, 0o644)) + require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600)) + + return certPath, keyPath +} + +func TestGetCertBySNI(t *testing.T) { + t.Run("extra cert used when main does not match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"*.internal.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.internal.example.com") + }) + + t.Run("exact match wins over wildcard match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "foo.example.com") + }) + + t.Run("main cert fallback when no match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"*.test.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("nil ServerName returns main cert", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(nil) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("empty ServerName returns main cert", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("case insensitive matching", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"Foo.Example.COM"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "Foo.Example.COM") + }) + + t.Run("normalization with trailing dot and whitespace", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "foo.example.com") + }) + + t.Run("longest wildcard match wins", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir1 := t.TempDir() + extraCert1, extraKey1 := writeSelfSignedCert(t, extraDir1, []string{"*.a.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert1, KeyPath: extraKey1}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.a.example.com") + }) + + t.Run("main cert wildcard match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("multiple extra certs", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir1 := t.TempDir() + extraCert1, extraKey1 := writeSelfSignedCert(t, extraDir1, []string{"*.test.com"}) + + extraDir2 := t.TempDir() + extraCert2, extraKey2 := writeSelfSignedCert(t, extraDir2, []string{"*.dev.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert1, KeyPath: extraKey1}, + {CertPath: extraCert2, KeyPath: extraKey2}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"}) + require.NoError(t, err) + leaf1, err := x509.ParseCertificate(cert1.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf1.DNSNames, "*.test.com") + + cert2, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.dev.com"}) + require.NoError(t, err) + leaf2, err := x509.ParseCertificate(cert2.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf2.DNSNames, "*.dev.com") + }) + + t.Run("multiple DNSNames in cert", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com", "bar.example.com", "*.test.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + leaf1, err := x509.ParseCertificate(cert1.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf1.DNSNames, "foo.example.com") + + cert2, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"}) + require.NoError(t, err) + leaf2, err := x509.ParseCertificate(cert2.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf2.DNSNames, "bar.example.com") + + cert3, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "baz.test.com"}) + require.NoError(t, err) + leaf3, err := x509.ParseCertificate(cert3.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf3.DNSNames, "*.test.com") + }) +} diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 88bbcd53..e114be50 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -2,9 +2,11 @@ package autocert import ( "errors" + "fmt" "os" "github.com/rs/zerolog/log" + gperr "github.com/yusing/goutils/errs" strutils "github.com/yusing/goutils/strings" ) @@ -19,6 +21,10 @@ func (p *Provider) Setup() (err error) { } } + if err = p.setupExtraProviders(); err != nil { + return err + } + for _, expiry := range p.GetExpiries() { log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) break @@ -26,3 +32,70 @@ func (p *Provider) Setup() (err error) { return nil } + +func (p *Provider) setupExtraProviders() error { + p.extraProviders = nil + p.sniMatcher = sniMatcher{} + if len(p.cfg.Extra) == 0 { + p.rebuildSNIMatcher() + return nil + } + + for i := range p.cfg.Extra { + merged := mergeExtraConfig(p.cfg, &p.cfg.Extra[i]) + user, legoCfg, err := merged.GetLegoConfig() + if err != nil { + return err.Subjectf("extra[%d]", i) + } + ep := NewProvider(&merged, user, legoCfg) + if err := ep.Setup(); err != nil { + return gperr.PrependSubject(fmt.Sprintf("extra[%d]", i), err) + } + p.extraProviders = append(p.extraProviders, ep) + } + p.rebuildSNIMatcher() + return nil +} + +func mergeExtraConfig(mainCfg *Config, extraCfg *Config) Config { + merged := *mainCfg + merged.Extra = nil + merged.CertPath = extraCfg.CertPath + merged.KeyPath = extraCfg.KeyPath + + if merged.Email == "" { + merged.Email = mainCfg.Email + } + + if len(extraCfg.Domains) > 0 { + merged.Domains = extraCfg.Domains + } + if extraCfg.ACMEKeyPath != "" { + merged.ACMEKeyPath = extraCfg.ACMEKeyPath + } + if extraCfg.Provider != "" { + merged.Provider = extraCfg.Provider + } + if len(extraCfg.Options) > 0 { + merged.Options = extraCfg.Options + } + if len(extraCfg.Resolvers) > 0 { + merged.Resolvers = extraCfg.Resolvers + } + if extraCfg.CADirURL != "" { + merged.CADirURL = extraCfg.CADirURL + } + if len(extraCfg.CACerts) > 0 { + merged.CACerts = extraCfg.CACerts + } + if extraCfg.EABKid != "" { + merged.EABKid = extraCfg.EABKid + } + if extraCfg.EABHmac != "" { + merged.EABHmac = extraCfg.EABHmac + } + if extraCfg.HTTPClient != nil { + merged.HTTPClient = extraCfg.HTTPClient + } + return merged +} diff --git a/internal/autocert/sni_matcher.go b/internal/autocert/sni_matcher.go new file mode 100644 index 00000000..7859e2d5 --- /dev/null +++ b/internal/autocert/sni_matcher.go @@ -0,0 +1,129 @@ +package autocert + +import ( + "crypto/x509" + "strings" +) + +type sniMatcher struct { + exact map[string]*Provider + root sniTreeNode +} + +type sniTreeNode struct { + children map[string]*sniTreeNode + wildcard *Provider +} + +func (m *sniMatcher) match(serverName string) *Provider { + if m == nil { + return nil + } + serverName = normalizeServerName(serverName) + if serverName == "" { + return nil + } + if m.exact != nil { + if p, ok := m.exact[serverName]; ok { + return p + } + } + return m.matchSuffixTree(serverName) +} + +func (m *sniMatcher) matchSuffixTree(serverName string) *Provider { + n := &m.root + labels := strings.Split(serverName, ".") + + var best *Provider + for i := len(labels) - 1; i >= 0; i-- { + if n.children == nil { + break + } + next := n.children[labels[i]] + if next == nil { + break + } + n = next + + consumed := len(labels) - i + remaining := len(labels) - consumed + if remaining == 1 && n.wildcard != nil { + best = n.wildcard + } + } + return best +} + +func normalizeServerName(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimSuffix(s, ".") + return strings.ToLower(s) +} + +func (m *sniMatcher) addProvider(p *Provider) { + if p == nil || p.tlsCert == nil || len(p.tlsCert.Certificate) == 0 { + return + } + leaf, err := x509.ParseCertificate(p.tlsCert.Certificate[0]) + if err != nil { + return + } + + addName := func(name string) { + name = normalizeServerName(name) + if name == "" { + return + } + if after, ok := strings.CutPrefix(name, "*."); ok { + suffix := after + if suffix == "" { + return + } + m.insertWildcardSuffix(suffix, p) + return + } + m.insertExact(name, p) + } + + if leaf.Subject.CommonName != "" { + addName(leaf.Subject.CommonName) + } + for _, n := range leaf.DNSNames { + addName(n) + } +} + +func (m *sniMatcher) insertExact(name string, p *Provider) { + if name == "" || p == nil { + return + } + if m.exact == nil { + m.exact = make(map[string]*Provider) + } + if _, exists := m.exact[name]; !exists { + m.exact[name] = p + } +} + +func (m *sniMatcher) insertWildcardSuffix(suffix string, p *Provider) { + if suffix == "" || p == nil { + return + } + n := &m.root + labels := strings.Split(suffix, ".") + for i := len(labels) - 1; i >= 0; i-- { + if n.children == nil { + n.children = make(map[string]*sniTreeNode) + } + next := n.children[labels[i]] + if next == nil { + next = &sniTreeNode{} + n.children[labels[i]] = next + } + n = next + } + if n.wildcard == nil { + n.wildcard = p + } +} diff --git a/internal/autocert/sni_matcher_bench_test.go b/internal/autocert/sni_matcher_bench_test.go new file mode 100644 index 00000000..e55ffb12 --- /dev/null +++ b/internal/autocert/sni_matcher_bench_test.go @@ -0,0 +1,104 @@ +package autocert + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "testing" + "time" +) + +func createTLSCert(dnsNames []string) (*tls.Certificate, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + if err != nil { + return nil, err + } + + cn := "" + if len(dnsNames) > 0 { + cn = dnsNames[0] + } + + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: dnsNames, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + return nil, err + } + + return &tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: key, + }, nil +} + +func BenchmarkSNIMatcher(b *testing.B) { + matcher := sniMatcher{} + + wildcard1Cert, err := createTLSCert([]string{"*.example.com"}) + if err != nil { + b.Fatal(err) + } + wildcard1 := &Provider{tlsCert: wildcard1Cert} + + wildcard2Cert, err := createTLSCert([]string{"*.test.com"}) + if err != nil { + b.Fatal(err) + } + wildcard2 := &Provider{tlsCert: wildcard2Cert} + + wildcard3Cert, err := createTLSCert([]string{"*.foo.com"}) + if err != nil { + b.Fatal(err) + } + wildcard3 := &Provider{tlsCert: wildcard3Cert} + + exact1Cert, err := createTLSCert([]string{"bar.example.com"}) + if err != nil { + b.Fatal(err) + } + exact1 := &Provider{tlsCert: exact1Cert} + + exact2Cert, err := createTLSCert([]string{"baz.test.com"}) + if err != nil { + b.Fatal(err) + } + exact2 := &Provider{tlsCert: exact2Cert} + + matcher.addProvider(wildcard1) + matcher.addProvider(wildcard2) + matcher.addProvider(wildcard3) + matcher.addProvider(exact1) + matcher.addProvider(exact2) + + b.Run("MatchWildcard", func(b *testing.B) { + for b.Loop() { + _ = matcher.match("sub.example.com") + } + }) + + b.Run("MatchExact", func(b *testing.B) { + for b.Loop() { + _ = matcher.match("bar.example.com") + } + }) +}