diff --git a/Makefile b/Makefile index 785a5f75..2a5a805a 100755 --- a/Makefile +++ b/Makefile @@ -123,6 +123,15 @@ dev: dev-build: build docker compose -f dev.compose.yml up -t 0 -d app --force-recreate +benchmark: + @if [ -z "$(TARGET)" ]; then \ + docker compose -f dev.compose.yml up -d --force-recreate godoxy traefik caddy nginx; \ + else \ + docker compose -f dev.compose.yml up -d --force-recreate $(TARGET); \ + fi + sleep 1 + @./scripts/benchmark.sh + dev-run: build cd dev-data && ${BIN_PATH} @@ -142,7 +151,7 @@ ci-test: act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)" cloc: - scc -w -i go --not-match '_test.go$' + scc -w -i go --not-match '_test.go$$' push-github: git push origin $(shell git rev-parse --abbrev-ref HEAD) diff --git a/cmd/bench_server/Dockerfile b/cmd/bench_server/Dockerfile new file mode 100644 index 00000000..c9ea5183 --- /dev/null +++ b/cmd/bench_server/Dockerfile @@ -0,0 +1,18 @@ +FROM golang:1.25.5-alpine AS builder + +HEALTHCHECK NONE + +WORKDIR /src + +COPY go.mod go.sum ./ +COPY main.go ./ + +RUN go build -o bench_server main.go + +FROM scratch + +COPY --from=builder /src/bench_server /app/run + +USER 1001:1001 + +CMD ["/app/run"] \ No newline at end of file diff --git a/cmd/bench_server/go.mod b/cmd/bench_server/go.mod new file mode 100644 index 00000000..6a1facd0 --- /dev/null +++ b/cmd/bench_server/go.mod @@ -0,0 +1,3 @@ +module github.com/yusing/godoxy/cmd/bench_server + +go 1.25.5 diff --git a/cmd/bench_server/go.sum b/cmd/bench_server/go.sum new file mode 100644 index 00000000..e69de29b diff --git a/cmd/bench_server/main.go b/cmd/bench_server/main.go new file mode 100644 index 00000000..92a64e01 --- /dev/null +++ b/cmd/bench_server/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "log" + "net/http" + + "math/rand/v2" +) + +var printables = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +var random = make([]byte, 4096) + +func init() { + for i := range random { + random[i] = printables[rand.IntN(len(printables))] + } +} + +func main() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write(random) + }) + + server := &http.Server{ + Addr: ":80", + Handler: handler, + } + + log.Println("Bench server listening on :80") + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("ListenAndServe: %v", err) + } +} diff --git a/cmd/h2c_test_server/main.go b/cmd/h2c_test_server/main.go index b49011fe..83ed1bf5 100644 --- a/cmd/h2c_test_server/main.go +++ b/cmd/h2c_test_server/main.go @@ -19,7 +19,7 @@ func main() { Handler: h2c.NewHandler(handler, &http2.Server{}), } - log.Println("H2C server listening on :8080") + log.Println("H2C server listening on :80") if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("ListenAndServe: %v", err) } diff --git a/dev.compose.yml b/dev.compose.yml index 21704d63..f869dc61 100644 --- a/dev.compose.yml +++ b/dev.compose.yml @@ -1,3 +1,8 @@ +x-benchmark: &benchmark + restart: no + labels: + proxy.exclude: true + proxy.#1.healthcheck.disable: true services: app: image: godoxy-dev @@ -90,7 +95,154 @@ services: labels: proxy.#1.scheme: h2c proxy.#1.port: 80 + bench: # returns 4096 bytes of random data + <<: *benchmark + build: + context: cmd/bench_server + dockerfile: Dockerfile + container_name: bench + godoxy: + <<: *benchmark + build: . + container_name: godoxy-benchmark + ports: + - 8080:80 + configs: + - source: godoxy_config + target: /app/config/config.yml + - source: godoxy_provider + target: /app/config/providers.yml + traefik: + <<: *benchmark + image: traefik:latest + container_name: traefik + command: + - --api.insecure=true + - --entrypoints.web.address=:8081 + - --providers.file.directory=/etc/traefik/dynamic + - --providers.file.watch=true + - --log.level=ERROR + ports: + - 8081:8081 + configs: + - source: traefik_config + target: /etc/traefik/dynamic/routes.yml + caddy: + <<: *benchmark + image: caddy:latest + container_name: caddy + ports: + - 8082:80 + configs: + - source: caddy_config + target: /etc/caddy/Caddyfile + tmpfs: + - /data + - /config + nginx: + <<: *benchmark + image: nginx:latest + container_name: nginx + command: nginx -g 'daemon off;' -c /etc/nginx/nginx.conf + ports: + - 8083:80 + configs: + - source: nginx_config + target: /etc/nginx/nginx.conf + configs: + godoxy_config: + content: | + providers: + include: + - providers.yml + godoxy_provider: + content: | + bench.domain.com: + host: bench + traefik_config: + content: | + http: + routers: + bench: + rule: "Host(`bench.domain.com`)" + entryPoints: + - web + service: bench + services: + bench: + loadBalancer: + servers: + - url: "http://bench:80" + caddy_config: + content: | + { + admin off + auto_https off + default_bind 0.0.0.0 + + servers { + protocols h1 h2c + } + } + + http://bench.domain.com { + reverse_proxy bench:80 + } + nginx_config: + content: | + worker_processes auto; + worker_rlimit_nofile 65535; + error_log /dev/null; + pid /var/run/nginx.pid; + + events { + worker_connections 10240; + multi_accept on; + use epoll; + } + + http { + include /etc/nginx/mime.types; + default_type application/octet-stream; + + access_log off; + + sendfile on; + tcp_nopush on; + tcp_nodelay on; + keepalive_timeout 65; + keepalive_requests 10000; + + upstream backend { + server bench:80; + keepalive 128; + } + + server { + listen 80 default_server; + server_name _; + http2 on; + + return 404; + } + + server { + listen 80; + server_name bench.domain.com; + http2 on; + + location / { + proxy_pass http://backend; + proxy_http_version 1.1; + proxy_set_header Connection ""; + proxy_set_header Host $$host; + proxy_set_header X-Real-IP $$remote_addr; + proxy_set_header X-Forwarded-For $$proxy_add_x_forwarded_for; + proxy_buffering off; + } + } + } parca: content: | object_storage: diff --git a/goutils b/goutils index 19965cc6..785deb23 160000 --- a/goutils +++ b/goutils @@ -1 +1 @@ -Subproject commit 19965cc6afc016fa41581bedbfc04695e2c726b4 +Subproject commit 785deb23bd64fb9db28875ae39cf3ea6675fb146 diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 2d9f0a42..43e2aa70 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -24,6 +24,7 @@ type Config struct { Domains []string `json:"domains,omitempty"` CertPath string `json:"cert_path,omitempty"` KeyPath string `json:"key_path,omitempty"` + Extra []Config `json:"extra,omitempty"` ACMEKeyPath string `json:"acme_key_path,omitempty"` Provider string `json:"provider,omitempty"` Options map[string]strutils.Redacted `json:"options,omitempty"` @@ -48,6 +49,9 @@ var ( ErrMissingEmail = gperr.New("missing field 'email'") ErrMissingProvider = gperr.New("missing field 'provider'") ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'") + ErrMissingCertPath = gperr.New("missing field 'cert_path'") + ErrMissingKeyPath = gperr.New("missing field 'key_path'") + ErrDuplicatedPath = gperr.New("duplicated path") ErrInvalidDomain = gperr.New("invalid domain") ErrUnknownProvider = gperr.New("unknown provider") ) @@ -68,10 +72,36 @@ func (cfg *Config) Validate() gperr.Error { if cfg.Provider == "" { cfg.Provider = ProviderLocal - return nil } b := gperr.NewBuilder("autocert errors") + if len(cfg.Extra) > 0 { + seenCertPaths := make(map[string]int, len(cfg.Extra)) + seenKeyPaths := make(map[string]int, len(cfg.Extra)) + for i := range cfg.Extra { + if cfg.Extra[i].CertPath == "" { + b.Add(ErrMissingCertPath.Subjectf("extra[%d].cert_path", i)) + } + if cfg.Extra[i].KeyPath == "" { + b.Add(ErrMissingKeyPath.Subjectf("extra[%d].key_path", i)) + } + if cfg.Extra[i].CertPath != "" { + if first, ok := seenCertPaths[cfg.Extra[i].CertPath]; ok { + b.Add(ErrDuplicatedPath.Subjectf("extra[%d].cert_path", i).Withf("first: %d", first)) + } else { + seenCertPaths[cfg.Extra[i].CertPath] = i + } + } + if cfg.Extra[i].KeyPath != "" { + if first, ok := seenKeyPaths[cfg.Extra[i].KeyPath]; ok { + b.Add(ErrDuplicatedPath.Subjectf("extra[%d].key_path", i).Withf("first: %d", first)) + } else { + seenKeyPaths[cfg.Extra[i].KeyPath] = i + } + } + } + } + if cfg.Provider == ProviderCustom && cfg.CADirURL == "" { b.Add(ErrMissingCADirURL) } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index d7032f8e..163d6bbc 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -1,13 +1,14 @@ package autocert import ( + "crypto/sha256" "crypto/tls" "crypto/x509" "errors" "fmt" "maps" "os" - "path" + "path/filepath" "slices" "strings" "sync/atomic" @@ -33,9 +34,14 @@ type ( client *lego.Client lastFailure time.Time + lastFailureFile string + legoCert *certificate.Resource tlsCert *tls.Certificate certExpiries CertExpiries + + extraProviders []*Provider + sniMatcher sniMatcher } CertExpiries map[string]time.Time @@ -55,16 +61,23 @@ var ActiveProvider atomic.Pointer[Provider] func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider { return &Provider{ - cfg: cfg, - user: user, - legoCfg: legoCfg, + cfg: cfg, + user: user, + legoCfg: legoCfg, + lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath), } } -func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { if p.tlsCert == nil { return nil, ErrGetCertFailure } + if hello == nil || hello.ServerName == "" { + return p.tlsCert, nil + } + if prov := p.sniMatcher.match(hello.ServerName); prov != nil && prov.tlsCert != nil { + return prov.tlsCert, nil + } return p.tlsCert, nil } @@ -90,7 +103,7 @@ func (p *Provider) GetLastFailure() (time.Time, error) { } if p.lastFailure.IsZero() { - data, err := os.ReadFile(LastFailureFile) + data, err := os.ReadFile(p.lastFailureFile) if err != nil { if !os.IsNotExist(err) { return time.Time{}, err @@ -108,7 +121,7 @@ func (p *Provider) UpdateLastFailure() error { } t := time.Now() p.lastFailure = t - return os.WriteFile(LastFailureFile, t.AppendFormat(nil, time.RFC3339), 0o600) + return os.WriteFile(p.lastFailureFile, t.AppendFormat(nil, time.RFC3339), 0o600) } func (p *Provider) ClearLastFailure() error { @@ -116,10 +129,26 @@ func (p *Provider) ClearLastFailure() error { return nil } p.lastFailure = time.Time{} - return os.Remove(LastFailureFile) + return os.Remove(p.lastFailureFile) } func (p *Provider) ObtainCert() error { + if len(p.extraProviders) > 0 { + errs := gperr.NewGroup("autocert errors") + errs.Go(p.obtainCertSelf) + for _, ep := range p.extraProviders { + errs.Go(ep.obtainCertSelf) + } + if err := errs.Wait().Error(); err != nil { + return err + } + p.rebuildSNIMatcher() + return nil + } + return p.obtainCertSelf() +} + +func (p *Provider) obtainCertSelf() error { if p.cfg.Provider == ProviderLocal { return nil } @@ -239,7 +268,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { timer := time.NewTimer(time.Until(renewalTime)) defer timer.Stop() - task := parent.Subtask("cert-renew-scheduler", true) + task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true) defer task.Finish(nil) for { @@ -282,6 +311,9 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { } } }() + for _, ep := range p.extraProviders { + ep.ScheduleRenewal(parent) + } } func (p *Provider) initClient() error { @@ -334,10 +366,10 @@ func (p *Provider) saveCert(cert *certificate.Resource) error { } /* This should have been done in setup but double check is always a good choice.*/ - _, err := os.Stat(path.Dir(p.cfg.CertPath)) + _, err := os.Stat(filepath.Dir(p.cfg.CertPath)) if err != nil { if os.IsNotExist(err) { - if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil { + if err = os.MkdirAll(filepath.Dir(p.cfg.CertPath), 0o755); err != nil { return err } } else { @@ -391,7 +423,7 @@ func (p *Provider) renewIfNeeded() error { return nil } - return p.ObtainCert() + return p.obtainCertSelf() } func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { @@ -411,3 +443,20 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { } return r, nil } + +func lastFailureFileFor(certPath, keyPath string) string { + if certPath == "" && keyPath == "" { + return LastFailureFile + } + dir := filepath.Dir(certPath) + sum := sha256.Sum256([]byte(certPath + "|" + keyPath)) + return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6])) +} + +func (p *Provider) rebuildSNIMatcher() { + p.sniMatcher = sniMatcher{} + p.sniMatcher.addProvider(p) + for _, ep := range p.extraProviders { + p.sniMatcher.addProvider(ep) + } +} diff --git a/internal/autocert/provider_test/extra_validation_test.go b/internal/autocert/provider_test/extra_validation_test.go new file mode 100644 index 00000000..3fbb5174 --- /dev/null +++ b/internal/autocert/provider_test/extra_validation_test.go @@ -0,0 +1,32 @@ +package provider_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" +) + +func TestExtraCertKeyPathsUnique(t *testing.T) { + t.Run("duplicate cert_path rejected", func(t *testing.T) { + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + Extra: []autocert.Config{ + {CertPath: "a.crt", KeyPath: "a.key"}, + {CertPath: "a.crt", KeyPath: "b.key"}, + }, + } + require.Error(t, cfg.Validate()) + }) + + t.Run("duplicate key_path rejected", func(t *testing.T) { + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + Extra: []autocert.Config{ + {CertPath: "a.crt", KeyPath: "a.key"}, + {CertPath: "b.crt", KeyPath: "a.key"}, + }, + } + require.Error(t, cfg.Validate()) + }) +} diff --git a/internal/autocert/provider_test/sni_test.go b/internal/autocert/provider_test/sni_test.go new file mode 100644 index 00000000..766593cd --- /dev/null +++ b/internal/autocert/provider_test/sni_test.go @@ -0,0 +1,383 @@ +package provider_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/yusing/godoxy/internal/autocert" +) + +func writeSelfSignedCert(t *testing.T, dir string, dnsNames []string) (string, string) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + require.NoError(t, err) + + cn := "" + if len(dnsNames) > 0 { + cn = dnsNames[0] + } + + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: dnsNames, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + require.NoError(t, os.WriteFile(certPath, certPEM, 0o644)) + require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600)) + + return certPath, keyPath +} + +func TestGetCertBySNI(t *testing.T) { + t.Run("extra cert used when main does not match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"*.internal.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.internal.example.com") + }) + + t.Run("exact match wins over wildcard match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "foo.example.com") + }) + + t.Run("main cert fallback when no match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"*.test.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("nil ServerName returns main cert", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(nil) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("empty ServerName returns main cert", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("case insensitive matching", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"Foo.Example.COM"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "Foo.Example.COM") + }) + + t.Run("normalization with trailing dot and whitespace", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "foo.example.com") + }) + + t.Run("longest wildcard match wins", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir1 := t.TempDir() + extraCert1, extraKey1 := writeSelfSignedCert(t, extraDir1, []string{"*.a.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert1, KeyPath: extraKey1}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.a.example.com") + }) + + t.Run("main cert wildcard match", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"}) + require.NoError(t, err) + + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf.DNSNames, "*.example.com") + }) + + t.Run("multiple extra certs", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir1 := t.TempDir() + extraCert1, extraKey1 := writeSelfSignedCert(t, extraDir1, []string{"*.test.com"}) + + extraDir2 := t.TempDir() + extraCert2, extraKey2 := writeSelfSignedCert(t, extraDir2, []string{"*.dev.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert1, KeyPath: extraKey1}, + {CertPath: extraCert2, KeyPath: extraKey2}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"}) + require.NoError(t, err) + leaf1, err := x509.ParseCertificate(cert1.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf1.DNSNames, "*.test.com") + + cert2, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.dev.com"}) + require.NoError(t, err) + leaf2, err := x509.ParseCertificate(cert2.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf2.DNSNames, "*.dev.com") + }) + + t.Run("multiple DNSNames in cert", func(t *testing.T) { + mainDir := t.TempDir() + mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"}) + + extraDir := t.TempDir() + extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com", "bar.example.com", "*.test.com"}) + + cfg := &autocert.Config{ + Provider: autocert.ProviderLocal, + CertPath: mainCert, + KeyPath: mainKey, + Extra: []autocert.Config{ + {CertPath: extraCert, KeyPath: extraKey}, + }, + } + + require.NoError(t, cfg.Validate()) + + p := autocert.NewProvider(cfg, nil, nil) + require.NoError(t, p.Setup()) + + cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"}) + require.NoError(t, err) + leaf1, err := x509.ParseCertificate(cert1.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf1.DNSNames, "foo.example.com") + + cert2, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"}) + require.NoError(t, err) + leaf2, err := x509.ParseCertificate(cert2.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf2.DNSNames, "bar.example.com") + + cert3, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "baz.test.com"}) + require.NoError(t, err) + leaf3, err := x509.ParseCertificate(cert3.Certificate[0]) + require.NoError(t, err) + require.Contains(t, leaf3.DNSNames, "*.test.com") + }) +} diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 88bbcd53..e114be50 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -2,9 +2,11 @@ package autocert import ( "errors" + "fmt" "os" "github.com/rs/zerolog/log" + gperr "github.com/yusing/goutils/errs" strutils "github.com/yusing/goutils/strings" ) @@ -19,6 +21,10 @@ func (p *Provider) Setup() (err error) { } } + if err = p.setupExtraProviders(); err != nil { + return err + } + for _, expiry := range p.GetExpiries() { log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry)) break @@ -26,3 +32,70 @@ func (p *Provider) Setup() (err error) { return nil } + +func (p *Provider) setupExtraProviders() error { + p.extraProviders = nil + p.sniMatcher = sniMatcher{} + if len(p.cfg.Extra) == 0 { + p.rebuildSNIMatcher() + return nil + } + + for i := range p.cfg.Extra { + merged := mergeExtraConfig(p.cfg, &p.cfg.Extra[i]) + user, legoCfg, err := merged.GetLegoConfig() + if err != nil { + return err.Subjectf("extra[%d]", i) + } + ep := NewProvider(&merged, user, legoCfg) + if err := ep.Setup(); err != nil { + return gperr.PrependSubject(fmt.Sprintf("extra[%d]", i), err) + } + p.extraProviders = append(p.extraProviders, ep) + } + p.rebuildSNIMatcher() + return nil +} + +func mergeExtraConfig(mainCfg *Config, extraCfg *Config) Config { + merged := *mainCfg + merged.Extra = nil + merged.CertPath = extraCfg.CertPath + merged.KeyPath = extraCfg.KeyPath + + if merged.Email == "" { + merged.Email = mainCfg.Email + } + + if len(extraCfg.Domains) > 0 { + merged.Domains = extraCfg.Domains + } + if extraCfg.ACMEKeyPath != "" { + merged.ACMEKeyPath = extraCfg.ACMEKeyPath + } + if extraCfg.Provider != "" { + merged.Provider = extraCfg.Provider + } + if len(extraCfg.Options) > 0 { + merged.Options = extraCfg.Options + } + if len(extraCfg.Resolvers) > 0 { + merged.Resolvers = extraCfg.Resolvers + } + if extraCfg.CADirURL != "" { + merged.CADirURL = extraCfg.CADirURL + } + if len(extraCfg.CACerts) > 0 { + merged.CACerts = extraCfg.CACerts + } + if extraCfg.EABKid != "" { + merged.EABKid = extraCfg.EABKid + } + if extraCfg.EABHmac != "" { + merged.EABHmac = extraCfg.EABHmac + } + if extraCfg.HTTPClient != nil { + merged.HTTPClient = extraCfg.HTTPClient + } + return merged +} diff --git a/internal/autocert/sni_matcher.go b/internal/autocert/sni_matcher.go new file mode 100644 index 00000000..7859e2d5 --- /dev/null +++ b/internal/autocert/sni_matcher.go @@ -0,0 +1,129 @@ +package autocert + +import ( + "crypto/x509" + "strings" +) + +type sniMatcher struct { + exact map[string]*Provider + root sniTreeNode +} + +type sniTreeNode struct { + children map[string]*sniTreeNode + wildcard *Provider +} + +func (m *sniMatcher) match(serverName string) *Provider { + if m == nil { + return nil + } + serverName = normalizeServerName(serverName) + if serverName == "" { + return nil + } + if m.exact != nil { + if p, ok := m.exact[serverName]; ok { + return p + } + } + return m.matchSuffixTree(serverName) +} + +func (m *sniMatcher) matchSuffixTree(serverName string) *Provider { + n := &m.root + labels := strings.Split(serverName, ".") + + var best *Provider + for i := len(labels) - 1; i >= 0; i-- { + if n.children == nil { + break + } + next := n.children[labels[i]] + if next == nil { + break + } + n = next + + consumed := len(labels) - i + remaining := len(labels) - consumed + if remaining == 1 && n.wildcard != nil { + best = n.wildcard + } + } + return best +} + +func normalizeServerName(s string) string { + s = strings.TrimSpace(s) + s = strings.TrimSuffix(s, ".") + return strings.ToLower(s) +} + +func (m *sniMatcher) addProvider(p *Provider) { + if p == nil || p.tlsCert == nil || len(p.tlsCert.Certificate) == 0 { + return + } + leaf, err := x509.ParseCertificate(p.tlsCert.Certificate[0]) + if err != nil { + return + } + + addName := func(name string) { + name = normalizeServerName(name) + if name == "" { + return + } + if after, ok := strings.CutPrefix(name, "*."); ok { + suffix := after + if suffix == "" { + return + } + m.insertWildcardSuffix(suffix, p) + return + } + m.insertExact(name, p) + } + + if leaf.Subject.CommonName != "" { + addName(leaf.Subject.CommonName) + } + for _, n := range leaf.DNSNames { + addName(n) + } +} + +func (m *sniMatcher) insertExact(name string, p *Provider) { + if name == "" || p == nil { + return + } + if m.exact == nil { + m.exact = make(map[string]*Provider) + } + if _, exists := m.exact[name]; !exists { + m.exact[name] = p + } +} + +func (m *sniMatcher) insertWildcardSuffix(suffix string, p *Provider) { + if suffix == "" || p == nil { + return + } + n := &m.root + labels := strings.Split(suffix, ".") + for i := len(labels) - 1; i >= 0; i-- { + if n.children == nil { + n.children = make(map[string]*sniTreeNode) + } + next := n.children[labels[i]] + if next == nil { + next = &sniTreeNode{} + n.children[labels[i]] = next + } + n = next + } + if n.wildcard == nil { + n.wildcard = p + } +} diff --git a/internal/autocert/sni_matcher_bench_test.go b/internal/autocert/sni_matcher_bench_test.go new file mode 100644 index 00000000..e55ffb12 --- /dev/null +++ b/internal/autocert/sni_matcher_bench_test.go @@ -0,0 +1,104 @@ +package autocert + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "testing" + "time" +) + +func createTLSCert(dnsNames []string) (*tls.Certificate, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) + if err != nil { + return nil, err + } + + cn := "" + if len(dnsNames) > 0 { + cn = dnsNames[0] + } + + template := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: cn, + }, + NotBefore: time.Now().Add(-time.Minute), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: dnsNames, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + return nil, err + } + + return &tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: key, + }, nil +} + +func BenchmarkSNIMatcher(b *testing.B) { + matcher := sniMatcher{} + + wildcard1Cert, err := createTLSCert([]string{"*.example.com"}) + if err != nil { + b.Fatal(err) + } + wildcard1 := &Provider{tlsCert: wildcard1Cert} + + wildcard2Cert, err := createTLSCert([]string{"*.test.com"}) + if err != nil { + b.Fatal(err) + } + wildcard2 := &Provider{tlsCert: wildcard2Cert} + + wildcard3Cert, err := createTLSCert([]string{"*.foo.com"}) + if err != nil { + b.Fatal(err) + } + wildcard3 := &Provider{tlsCert: wildcard3Cert} + + exact1Cert, err := createTLSCert([]string{"bar.example.com"}) + if err != nil { + b.Fatal(err) + } + exact1 := &Provider{tlsCert: exact1Cert} + + exact2Cert, err := createTLSCert([]string{"baz.test.com"}) + if err != nil { + b.Fatal(err) + } + exact2 := &Provider{tlsCert: exact2Cert} + + matcher.addProvider(wildcard1) + matcher.addProvider(wildcard2) + matcher.addProvider(wildcard3) + matcher.addProvider(exact1) + matcher.addProvider(exact2) + + b.Run("MatchWildcard", func(b *testing.B) { + for b.Loop() { + _ = matcher.match("sub.example.com") + } + }) + + b.Run("MatchExact", func(b *testing.B) { + for b.Loop() { + _ = matcher.match("bar.example.com") + } + }) +} diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index 1931db49..14627789 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -97,9 +97,12 @@ func (ep *Entrypoint) FindRoute(s string) types.HTTPRoute { func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { if ep.accessLogger != nil { - rec := accesslog.NewResponseRecorder(w) + rec := accesslog.GetResponseRecorder(w) w = rec - defer ep.accessLogger.Log(r, rec.Response()) + defer func() { + ep.accessLogger.Log(r, rec.Response()) + accesslog.PutResponseRecorder(rec) + }() } route := ep.findRouteFunc(r.Host) diff --git a/internal/idlewatcher/watcher.go b/internal/idlewatcher/watcher.go index e9d7ed41..92587f76 100644 --- a/internal/idlewatcher/watcher.go +++ b/internal/idlewatcher/watcher.go @@ -259,7 +259,7 @@ func NewWatcher(parent task.Parent, r types.Route, cfg *types.IdlewatcherConfig) p, err = provider.NewDockerProvider(cfg.Docker.DockerCfg, cfg.Docker.ContainerID) kind = "docker" default: - p, err = provider.NewProxmoxProvider(cfg.Proxmox.Node, cfg.Proxmox.VMID) + p, err = provider.NewProxmoxProvider(parent.Context(), cfg.Proxmox.Node, cfg.Proxmox.VMID) kind = "proxmox" } targetURL := r.TargetURL() diff --git a/internal/logging/accesslog/response_recorder.go b/internal/logging/accesslog/response_recorder.go index 4a3b96ef..250346dd 100644 --- a/internal/logging/accesslog/response_recorder.go +++ b/internal/logging/accesslog/response_recorder.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "sync" ) type ResponseRecorder struct { @@ -13,14 +14,30 @@ type ResponseRecorder struct { resp http.Response } -func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { - return &ResponseRecorder{ - w: w, - resp: http.Response{ - StatusCode: http.StatusOK, - Header: w.Header(), - }, +var recorderPool = sync.Pool{ + New: func() any { + return &ResponseRecorder{} + }, +} + +func GetResponseRecorder(w http.ResponseWriter) *ResponseRecorder { + r := recorderPool.Get().(*ResponseRecorder) + r.w = w + r.resp = http.Response{ + StatusCode: http.StatusOK, + Header: w.Header(), } + return r +} + +func PutResponseRecorder(r *ResponseRecorder) { + r.w = nil + r.resp = http.Response{} + recorderPool.Put(r) +} + +func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder { + return GetResponseRecorder(w) } func (w *ResponseRecorder) Unwrap() http.ResponseWriter { diff --git a/internal/net/gphttp/transport.go b/internal/net/gphttp/transport.go index d633ee6b..c996942c 100644 --- a/internal/net/gphttp/transport.go +++ b/internal/net/gphttp/transport.go @@ -16,7 +16,7 @@ func NewTransport() *http.Transport { Proxy: http.ProxyFromEnvironment, DialContext: DefaultDialer.DialContext, ForceAttemptHTTP2: true, - MaxIdleConnsPerHost: 100, + MaxIdleConnsPerHost: 1000, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, diff --git a/internal/route/routes/context.go b/internal/route/routes/context.go index 3365417a..812cfa21 100644 --- a/internal/route/routes/context.go +++ b/internal/route/routes/context.go @@ -30,7 +30,7 @@ func (r *RouteContext) Value(key any) any { func WithRouteContext(r *http.Request, route types.HTTPRoute) *http.Request { // we don't want to copy the request object every fucking requests // return r.WithContext(context.WithValue(r.Context(), routeContextKey, route)) - ctxFieldPtr := (*context.Context)(unsafe.Pointer(uintptr(unsafe.Pointer(r)) + ctxFieldOffset)) + ctxFieldPtr := (*context.Context)(unsafe.Add(unsafe.Pointer(r), ctxFieldOffset)) *ctxFieldPtr = &RouteContext{ Context: r.Context(), Route: route, diff --git a/internal/types/docker_provider_config.go b/internal/types/docker_provider_config.go index 9118084f..e5117809 100644 --- a/internal/types/docker_provider_config.go +++ b/internal/types/docker_provider_config.go @@ -20,7 +20,7 @@ type DockerProviderConfig struct { } // @name DockerProviderConfig type DockerProviderConfigDetailed struct { - Scheme string `json:"scheme,omitempty" validate:"required,oneof=http https tcp tls"` + Scheme string `json:"scheme,omitempty" validate:"required,oneof=http https tcp tls unix ssh"` Host string `json:"host,omitempty" validate:"required,hostname|ip"` Port int `json:"port,omitempty" validate:"required,min=1,max=65535"` TLS *DockerTLSConfig `json:"tls" validate:"omitempty"` @@ -49,11 +49,13 @@ func (cfg *DockerProviderConfig) Parse(value string) error { switch u.Scheme { case "http", "https", "tcp", "tls": + cfg.URL = u.String() + case "unix", "ssh": + cfg.URL = value default: return fmt.Errorf("invalid scheme: %s", u.Scheme) } - cfg.URL = u.String() return nil } diff --git a/internal/types/docker_provider_config_test.go b/internal/types/docker_provider_config_test.go index c74f377c..3339556a 100644 --- a/internal/types/docker_provider_config_test.go +++ b/internal/types/docker_provider_config_test.go @@ -38,7 +38,12 @@ func TestDockerProviderConfigValidation(t *testing.T) { yamlStr string wantErr bool }{ - {name: "valid url", yamlStr: "test: http://localhost:2375", wantErr: false}, + {name: "valid url (http)", yamlStr: "test: http://localhost:2375", wantErr: false}, + {name: "valid url (https)", yamlStr: "test: https://localhost:2375", wantErr: false}, + {name: "valid url (tcp)", yamlStr: "test: tcp://localhost:2375", wantErr: false}, + {name: "valid url (tls)", yamlStr: "test: tls://localhost:2375", wantErr: false}, + {name: "valid url (unix)", yamlStr: "test: unix:///var/run/docker.sock", wantErr: false}, + {name: "valid url (ssh)", yamlStr: "test: ssh://localhost:2375", wantErr: false}, {name: "invalid url", yamlStr: "test: ftp://localhost/2375", wantErr: true}, {name: "valid scheme", yamlStr: ` test: diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh new file mode 100644 index 00000000..c7fc0e7f --- /dev/null +++ b/scripts/benchmark.sh @@ -0,0 +1,199 @@ +#!/bin/bash +# Benchmark script to compare GoDoxy, Traefik, Caddy, and Nginx +# Uses wrk for HTTP load testing + +set -e + +# Configuration +HOST="bench.domain.com" +DURATION="${DURATION:-10s}" +THREADS="${THREADS:-4}" +CONNECTIONS="${CONNECTIONS:-100}" +TARGET="${TARGET-}" + +# Color functions for output +red() { echo -e "\033[0;31m$*\033[0m"; } +green() { echo -e "\033[0;32m$*\033[0m"; } +yellow() { echo -e "\033[1;33m$*\033[0m"; } +blue() { echo -e "\033[0;34m$*\033[0m"; } + +# Check if wrk is installed +if ! command -v wrk &>/dev/null; then + red "Error: wrk is not installed" + echo "Please install wrk:" + echo " Ubuntu/Debian: sudo apt-get install wrk" + echo " macOS: brew install wrk" + echo " Or build from source: https://github.com/wg/wrk" + exit 1 +fi + +if ! command -v h2load &>/dev/null; then + red "Error: h2load is not installed" + echo "Please install h2load (nghttp2-client):" + echo " Ubuntu/Debian: sudo apt-get install nghttp2-client" + echo " macOS: brew install nghttp2" + exit 1 +fi + +OUTFILE="/tmp/reverse_proxy_benchmark_$(date +%Y%m%d_%H%M%S).log" +: >"$OUTFILE" +exec > >(tee -a "$OUTFILE") 2>&1 + +blue "========================================" +blue "Reverse Proxy Benchmark Comparison" +blue "========================================" +echo "" +echo "Target: $HOST" +echo "Duration: $DURATION" +echo "Threads: $THREADS" +echo "Connections: $CONNECTIONS" +if [ -n "$TARGET" ]; then + echo "Filter: $TARGET" +fi +echo "" + +# Define services to test +declare -A services=( + ["GoDoxy"]="http://127.0.0.1:8080" + ["Traefik"]="http://127.0.0.1:8081" + ["Caddy"]="http://127.0.0.1:8082" + ["Nginx"]="http://127.0.0.1:8083" +) + +# Array to store connection errors +declare -a connection_errors=() + +# Function to test connection before benchmarking +test_connection() { + local name=$1 + local url=$2 + + yellow "Testing connection to $name..." + + # Test HTTP/1.1 + local res1=$(curl -sS -w "\n%{http_code}" --http1.1 -H "Host: $HOST" --max-time 5 "$url") + local body1=$(echo "$res1" | head -n -1) + local status1=$(echo "$res1" | tail -n 1) + + # Test HTTP/2 + local res2=$(curl -sS -w "\n%{http_code}" --http2-prior-knowledge -H "Host: $HOST" --max-time 5 "$url") + local body2=$(echo "$res2" | head -n -1) + local status2=$(echo "$res2" | tail -n 1) + + local failed=false + if [ "$status1" != "200" ] || [ ${#body1} -ne 4096 ]; then + red "✗ $name failed HTTP/1.1 connection test (Status: $status1, Body length: ${#body1})" + failed=true + fi + + if [ "$status2" != "200" ] || [ ${#body2} -ne 4096 ]; then + red "✗ $name failed HTTP/2 connection test (Status: $status2, Body length: ${#body2})" + failed=true + fi + + if [ "$failed" = true ]; then + connection_errors+=("$name failed connection test (URL: $url)") + return 1 + else + green "✓ $name is reachable (HTTP/1.1 & HTTP/2)" + return 0 + fi +} + +blue "========================================" +blue "Connection Tests" +blue "========================================" +echo "" + +# Run connection tests for all services +for name in "${!services[@]}"; do + if [ -z "$TARGET" ] || [ "${name,,}" = "${TARGET,,}" ]; then + test_connection "$name" "${services[$name]}" + fi +done + +echo "" +blue "========================================" + +# Exit if any connection test failed +if [ ${#connection_errors[@]} -gt 0 ]; then + echo "" + red "Connection test failed for the following services:" + for error in "${connection_errors[@]}"; do + red " - $error" + done + echo "" + red "Please ensure all services are running before benchmarking" + exit 1 +fi + +echo "" +green "All services are reachable. Starting benchmarks..." +echo "" +blue "========================================" +echo "" + +restart_bench() { + local name=$1 + echo "" + yellow "Restarting bench service before benchmarking $name HTTP/1.1..." + docker compose -f dev.compose.yml up -d --force-recreate bench >/dev/null 2>&1 + sleep 1 +} + +# Function to run benchmark +run_benchmark() { + local name=$1 + local url=$2 + local h2_duration="${DURATION%s}" + + restart_bench "$name" + + yellow "Testing $name..." + + echo "========================================" + echo "$name" + echo "URL: $url" + echo "========================================" + echo "" + echo "[HTTP/1.1] wrk" + + wrk -t"$THREADS" -c"$CONNECTIONS" -d"$DURATION" \ + -H "Host: $HOST" \ + "$url" + + restart_bench "$name" + + echo "" + echo "[HTTP/2] h2load" + + h2load -t"$THREADS" -c"$CONNECTIONS" --duration="$h2_duration" \ + -H "Host: $HOST" \ + -H ":authority: $HOST" \ + "$url" | grep -vE "^(starting benchmark...|spawning thread|progress: |Warm-up |Main benchmark duration|Stopped all clients|Process Request Failure)" + + echo "" + green "✓ $name benchmark completed" + blue "----------------------------------------" + echo "" +} + +# Run benchmarks for each service +for name in "${!services[@]}"; do + if [ -z "$TARGET" ] || [ "${name,,}" = "${TARGET,,}" ]; then + run_benchmark "$name" "${services[$name]}" + fi +done + +blue "========================================" +blue "Benchmark Summary" +blue "========================================" +echo "" +echo "All benchmark output saved to: $OUTFILE" +echo "" +echo "Key metrics to compare:" +echo " - Requests/sec (throughput)" +echo " - Latency (mean, stdev)" +echo " - Transfer/sec" +echo "" +green "All benchmarks completed!"