Compare commits

...

35 Commits

Author SHA1 Message Date
yusing
1b9cfa6540 fix(autocert): forceRenewalDoneCh was never closed 2026-01-04 20:40:38 +08:00
yusing
f1d906ac11 fix(test): update test expectations 2026-01-04 20:31:11 +08:00
yusing
2835fd5fb0 fix(autocert): ensure extra certificate registration and renewal scheduling
Extra providers were not being properly initialized during NewProvider(),
causing certificate registration and renewal scheduling to be skipped.

- Add ConfigExtra type with idx field for provider indexing
- Add MergeExtraConfig() for inheriting main provider settings
- Add setupExtraProviders() for recursive extra provider initialization
- Refactor NewProvider to return error and call setupExtraProviders()
- Add provider-scoped logger with "main" or "extra[N]" name
- Add batch operations: ObtainCertIfNotExistsAll(), ObtainCertAll()
- Add ForceExpiryAll() with completion tracking via WaitRenewalDone()
- Add RenewMode (force/ifNeeded) for controlling renewal behavior
- Add PrintCertExpiriesAll() for logging all provider certificate expiries

Summary of staged changes:
- config.go: Added ConfigExtra type, MergeExtraConfig(), recursive validation with path uniqueness checking
- provider.go: Added provider indexing, scoped logger, batch cert operations, force renewal with completion tracking, RenewMode control
- setup.go: New file with setupExtraProviders() for proper extra provider initialization
- setup_test.go: New tests for extra provider setup
- multi_cert_test.go: New tests for multi-certificate functionality
- renew.go: Updated to use new provider API with error handling
- state.go: Updated to handle NewProvider error return
2026-01-04 20:30:58 +08:00
yusing
11d0c61b9c refactor(state): replace Entrypoint method with ShortLinkMatcher interface
- Cleaned up agent go.mod by removing unused indirect dependencies.
2026-01-04 12:43:05 +08:00
Yuzerion
c00854a124 feat(autocert): add multi-certificate support (#185)
Multi-certificate, SNI matching with exact map and suffix tree

Add support for multiple TLS certificates with SNI-based selection. The
root provider maintains a single centralized SNI matcher that uses an
exact match map for O(1) lookups, falling back to a suffix tree for
wildcard matching.

Key features:
- Add `Extra []Config` field to autocert.Config for additional certificates
- Each extra entry must specify unique `cert_path` and `key_path`
- Extra certs inherit main config (except `email` and `extra` fields)
- Extra certs participate in ACME obtain/renew cycles independently
- SNI selection precedence: exact match > wildcard match, main > extra
- Single centralized SNI matcher on root provider rebuilt after cert changes

The SNI matcher structure:
- Exact match map: O(1) lookup for exact domain matches
- Suffix tree: Efficient wildcard matching (e.g., *.example.com)

Implementation details:
- Provider.GetCert() now uses SNI from ClientHelloInfo for selection
- Main cert is returned as fallback when no SNI match is found
- Extra providers are created as child providers with merged configs
- SNI matcher is rebuilt after Setup() and after ObtainCert() completes
2026-01-04 00:37:26 +08:00
yusing
117dbb62f4 refactor(docker): accept unix and ssh scheme for providers 2026-01-03 20:06:31 +08:00
yusing
2c28bc116c fix(h2c_test_server): correct listening on message 2026-01-03 12:58:14 +08:00
yusing
1d90bec9ed refactor(benchmark): restart bench server after each run 2026-01-03 12:54:18 +08:00
yusing
b2df749cd1 refactor(io,reverseproxy): suppress "client disconnected" error; optimize CopyClose method 2026-01-03 12:41:11 +08:00
yusing
1916f73e78 refactor(route): modernize code with unsafe.Add 2026-01-03 12:40:55 +08:00
yusing
99ab9beb4a refactor(http/transport): increase MaxIdleConnsPerHost to 1000 2026-01-03 12:40:28 +08:00
yusing
5de064aa47 refactor(benchmark): replace whoami service with bench server
- Updated dev.compose.yml to define a new bench service that serves 4096 bytes of random data.
- Modified configurations for Traefik, Caddy, and Nginx to route traffic to the new bench service.
- Added Dockerfile and Go application for the bench server, including necessary Go modules.
- Updated benchmark script to target the new bench service endpoint.
2026-01-03 12:40:10 +08:00
yusing
880e11c414 refactor(http/reverseproxy): performance improvement
- Replaced req.Clone with req.WithContext and url/header/trailer cloning.
- Added conditional handling for "Expect" headers to manage 1xx responses with appropriate tracing.
2026-01-03 02:30:15 +08:00
yusing
0dfce823bf refactor(http): performance improvement
- Introduced a sync.Pool for ResponseRecorder to optimize memory usage.
- Updated ServeHTTP method to utilize the new GetResponseRecorder and PutResponseRecorder functions.
- Adjusted NewResponseRecorder to leverage the pooling mechanism.
2026-01-03 02:20:01 +08:00
yusing
c2583fc756 refactor(benchmark): update whoami service configuration to use FQDN alias 2026-01-03 02:10:00 +08:00
yusing
cf6246d58a refactor(benchmark): remove unused Docker socket configuration from benchmark service 2026-01-03 02:04:49 +08:00
yusing
fb040afe90 refactor(benchmark): benchmark script functionality and fairness 2026-01-03 00:57:50 +08:00
yusing
dc8abe943d feat(benchmark): enhance dev.compose.yml with benchmark services and scripts
- Added benchmark services (whoami, godoxy, traefik, caddy, nginx) to dev.compose.yml.
- Introduced a new benchmark.sh script for load testing using wrk and h2load.
- Updated Makefile to include a benchmark target for easy execution of the new script.
2026-01-03 00:28:59 +08:00
yusing
587b83cf14 fix(idlewatcher): pass context to ProxmoxProvider 2026-01-02 22:17:40 +08:00
yusing
a4658caf02 refactor(config): correct logic in InitFromFile 2026-01-02 21:56:34 +08:00
yusing
ef9ee0e169 feat(websocket): update goutils - deduplicate data to avoid unnecessary traffic 2026-01-02 18:04:08 +08:00
yusing
7eadec9752 chore: remove unused utils/deep_equal.go 2026-01-02 18:03:13 +08:00
yusing
dd35a4159f refactor(api/health): simplify health info type
- Updated health-related functions to return simplified health information.
- Introduced HealthStatusString type for correct swagger and schema generation.
- Refactored HealthJSON structure to utilize the new HealthStatusString type.
2026-01-02 18:02:49 +08:00
yusing
f28667e23e refactor: add context handling in various functions
- Modified functions to accept context.Context as a parameter for better context management.
- Updated Init methods in Proxmox and Config to use the provided context.
- Adjusted UpdatePorts and NewProxmoxProvider to utilize the context for operations.
2026-01-02 17:41:36 +08:00
yusing
8009da9e4d chore: go mod tidy 2026-01-02 15:49:03 +08:00
yusing
590743f1ef feat(entrypoint): implement short link #177
- Added ShortLinkMatcher to handle short link routing.
- Integrated short link handling in Entrypoint.
- Introduced tests for short link matching and dispatching.
- Configured default domain suffix for subdomain aliases.
2026-01-02 15:42:15 +08:00
yusing
1f4c30a48e fix(docker): update scheme validation to include 'tcp' in DockerProviderConfigDetailed 2026-01-02 10:55:42 +08:00
yusing
bae7387a5d feat(dev): add jotty and postgres-test services to dev.compose.yml 2026-01-02 01:20:05 +08:00
yusing
67fc48383d refactor(monitor): include detail in service down notification log 2026-01-02 01:17:47 +08:00
yusing
1406881071 feat(http/h2c): h2c test server with a Dockerfile
- Implemented a basic HTTP/2 server that responds with "ok" to requests.
- Updated dev.compose.yml to include a service for it
2026-01-02 01:17:28 +08:00
yusing
7976befda4 feat(http): enable HTTP/2 support in server configuration
- Added NextProtos to TLSConfig to prefer HTTP/2 and fallback to HTTP/1.1.
- Configured the server to handle HTTP/2 connections, with error logging for configuration failures.
2026-01-02 01:11:07 +08:00
yusing
8139311074 feat(healthcheck/http): implement h2c health check support and refactor request handling
- Added support for health checks using the h2c scheme.
- Refactored common header setting into a dedicated function.
- Updated CheckHealth method to differentiate between HTTP and h2c checks.
2026-01-02 00:46:48 +08:00
yusing
2690bf548d chore: update swagger add h2c scheme type 2026-01-01 18:56:11 +08:00
yusing
d3358ebd89 feat(http/reverseproxy): h2c support with scheme: h2c 2026-01-01 18:54:49 +08:00
yusing
fd74bfedf0 fix(agent): improve url handling to not break urls with encoded characters 2026-01-01 18:25:27 +08:00
61 changed files with 2974 additions and 1012 deletions

View File

@@ -123,6 +123,15 @@ dev:
dev-build: build
docker compose -f dev.compose.yml up -t 0 -d app --force-recreate
benchmark:
@if [ -z "$(TARGET)" ]; then \
docker compose -f dev.compose.yml up -d --force-recreate godoxy traefik caddy nginx; \
else \
docker compose -f dev.compose.yml up -d --force-recreate $(TARGET); \
fi
sleep 1
@./scripts/benchmark.sh
dev-run: build
cd dev-data && ${BIN_PATH}
@@ -142,7 +151,7 @@ ci-test:
act -n --artifact-server-path /tmp/artifacts -s GITHUB_TOKEN="$$(gh auth token)"
cloc:
scc -w -i go --not-match '_test.go$'
scc -w -i go --not-match '_test.go$$'
push-github:
git push origin $(shell git rev-parse --abbrev-ref HEAD)

View File

@@ -22,7 +22,7 @@ require (
github.com/rs/zerolog v1.34.0
github.com/stretchr/testify v1.11.1
github.com/valyala/fasthttp v1.68.0
github.com/yusing/godoxy v0.20.10
github.com/yusing/godoxy v0.0.0-00010101000000-000000000000
github.com/yusing/godoxy/socketproxy v0.0.0-00010101000000-000000000000
github.com/yusing/goutils v0.7.0
github.com/yusing/goutils/http/reverseproxy v0.0.0-20251217162119-cb0f79b51ce2

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/yusing/godoxy/agent/pkg/agent"
@@ -43,10 +44,22 @@ func ProxyHTTP(w http.ResponseWriter, r *http.Request) {
return
}
r.URL.Scheme = ""
r.URL.Host = ""
r.URL.Path = r.URL.Path[agent.HTTPProxyURLPrefixLen:] // strip the {API_BASE}/proxy/http prefix
r.RequestURI = r.URL.String()
// Strip the {API_BASE}/proxy/http prefix while preserving URL escaping.
//
// NOTE: `r.URL.Path` is decoded. If we rewrite it without keeping `RawPath`
// in sync, Go may re-escape the path (e.g. turning "%5B" into "%255B"),
// which breaks urls with percent-encoded characters, like Next.js static chunk URLs.
prefix := agent.APIEndpointBase + agent.EndpointProxyHTTP
r.URL.Path = strings.TrimPrefix(r.URL.Path, prefix)
if r.URL.RawPath != "" {
if after, ok := strings.CutPrefix(r.URL.RawPath, prefix); ok {
r.URL.RawPath = after
} else {
// RawPath is no longer a valid encoding for Path; force Go to re-derive it.
r.URL.RawPath = ""
}
}
r.RequestURI = ""
rp := &httputil.ReverseProxy{
Director: func(r *http.Request) {

View File

@@ -0,0 +1,18 @@
FROM golang:1.25.5-alpine AS builder
HEALTHCHECK NONE
WORKDIR /src
COPY go.mod go.sum ./
COPY main.go ./
RUN go build -o bench_server main.go
FROM scratch
COPY --from=builder /src/bench_server /app/run
USER 1001:1001
CMD ["/app/run"]

3
cmd/bench_server/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module github.com/yusing/godoxy/cmd/bench_server
go 1.25.5

0
cmd/bench_server/go.sum Normal file
View File

34
cmd/bench_server/main.go Normal file
View File

@@ -0,0 +1,34 @@
package main
import (
"log"
"net/http"
"math/rand/v2"
)
var printables = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var random = make([]byte, 4096)
func init() {
for i := range random {
random[i] = printables[rand.IntN(len(printables))]
}
}
func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write(random)
})
server := &http.Server{
Addr: ":80",
Handler: handler,
}
log.Println("Bench server listening on :80")
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("ListenAndServe: %v", err)
}
}

View File

@@ -0,0 +1,18 @@
FROM golang:1.25.5-alpine AS builder
HEALTHCHECK NONE
WORKDIR /src
COPY go.mod go.sum ./
COPY main.go ./
RUN go build -o h2c_test_server main.go
FROM scratch
COPY --from=builder /src/h2c_test_server /app/run
USER 1001:1001
CMD ["/app/run"]

View File

@@ -0,0 +1,7 @@
module github.com/yusing/godoxy/cmd/h2c_test_server
go 1.25.5
require golang.org/x/net v0.48.0
require golang.org/x/text v0.32.0 // indirect

View File

@@ -0,0 +1,4 @@
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=

View File

@@ -0,0 +1,26 @@
package main
import (
"log"
"net/http"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
func main() {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
})
server := &http.Server{
Addr: ":80",
Handler: h2c.NewHandler(handler, &http2.Server{}),
}
log.Println("H2C server listening on :80")
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("ListenAndServe: %v", err)
}
}

View File

@@ -1,3 +1,8 @@
x-benchmark: &benchmark
restart: no
labels:
proxy.exclude: true
proxy.#1.healthcheck.disable: true
services:
app:
image: godoxy-dev
@@ -54,7 +59,190 @@ services:
- USERS=user:$$2a$$10$$UdLYoJ5lgPsC0RKqYH/jMua7zIn0g9kPqWmhYayJYLaZQ/FTmH2/u # user:password
labels:
proxy.tinyauth.port: "3000"
jotty: # issue #182
image: ghcr.io/fccview/jotty:latest
container_name: jotty
user: "1000:1000"
tmpfs:
- /app/data:rw,uid=1000,gid=1000
- /app/config:rw,uid=1000,gid=1000
- /app/.next/cache:rw,uid=1000,gid=1000
restart: unless-stopped
environment:
- NODE_ENV=production
labels:
proxy.aliases: "jotty.my.app"
postgres-test:
image: postgres:18-alpine
container_name: postgres-test
restart: unless-stopped
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
- POSTGRES_DB=postgres
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 10s
timeout: 5s
retries: 5
start_period: 30s
h2c_test_server:
build:
context: cmd/h2c_test_server
dockerfile: Dockerfile
container_name: h2c_test
restart: unless-stopped
labels:
proxy.#1.scheme: h2c
proxy.#1.port: 80
bench: # returns 4096 bytes of random data
<<: *benchmark
build:
context: cmd/bench_server
dockerfile: Dockerfile
container_name: bench
godoxy:
<<: *benchmark
build: .
container_name: godoxy-benchmark
ports:
- 8080:80
configs:
- source: godoxy_config
target: /app/config/config.yml
- source: godoxy_provider
target: /app/config/providers.yml
traefik:
<<: *benchmark
image: traefik:latest
container_name: traefik
command:
- --api.insecure=true
- --entrypoints.web.address=:8081
- --providers.file.directory=/etc/traefik/dynamic
- --providers.file.watch=true
- --log.level=ERROR
ports:
- 8081:8081
configs:
- source: traefik_config
target: /etc/traefik/dynamic/routes.yml
caddy:
<<: *benchmark
image: caddy:latest
container_name: caddy
ports:
- 8082:80
configs:
- source: caddy_config
target: /etc/caddy/Caddyfile
tmpfs:
- /data
- /config
nginx:
<<: *benchmark
image: nginx:latest
container_name: nginx
command: nginx -g 'daemon off;' -c /etc/nginx/nginx.conf
ports:
- 8083:80
configs:
- source: nginx_config
target: /etc/nginx/nginx.conf
configs:
godoxy_config:
content: |
providers:
include:
- providers.yml
godoxy_provider:
content: |
bench.domain.com:
host: bench
traefik_config:
content: |
http:
routers:
bench:
rule: "Host(`bench.domain.com`)"
entryPoints:
- web
service: bench
services:
bench:
loadBalancer:
servers:
- url: "http://bench:80"
caddy_config:
content: |
{
admin off
auto_https off
default_bind 0.0.0.0
servers {
protocols h1 h2c
}
}
http://bench.domain.com {
reverse_proxy bench:80
}
nginx_config:
content: |
worker_processes auto;
worker_rlimit_nofile 65535;
error_log /dev/null;
pid /var/run/nginx.pid;
events {
worker_connections 10240;
multi_accept on;
use epoll;
}
http {
include /etc/nginx/mime.types;
default_type application/octet-stream;
access_log off;
sendfile on;
tcp_nopush on;
tcp_nodelay on;
keepalive_timeout 65;
keepalive_requests 10000;
upstream backend {
server bench:80;
keepalive 128;
}
server {
listen 80 default_server;
server_name _;
http2 on;
return 404;
}
server {
listen 80;
server_name bench.domain.com;
http2 on;
location / {
proxy_pass http://backend;
proxy_http_version 1.1;
proxy_set_header Connection "";
proxy_set_header Host $$host;
proxy_set_header X-Real-IP $$remote_addr;
proxy_set_header X-Forwarded-For $$proxy_add_x_forwarded_for;
proxy_buffering off;
}
}
}
parca:
content: |
object_storage:

Submodule goutils updated: 51a75d684b...785deb23bd

View File

@@ -9,7 +9,6 @@ import (
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/logging/memlogger"
apitypes "github.com/yusing/goutils/apitypes"
gperr "github.com/yusing/goutils/errs"
"github.com/yusing/goutils/http/websocket"
)
@@ -40,33 +39,33 @@ func Renew(c *gin.Context) {
logs, cancel := memlogger.Events()
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
// Stream logs until WebSocket connection closes (renewal runs in background)
for {
select {
case <-manager.Context().Done():
return
case l := <-logs:
if err != nil {
return
}
err = autocert.ObtainCert()
if err != nil {
gperr.LogError("failed to obtain cert", err)
_ = manager.WriteData(websocket.TextMessage, []byte(err.Error()), 10*time.Second)
} else {
log.Info().Msg("cert obtained successfully")
err = manager.WriteData(websocket.TextMessage, l, 10*time.Second)
if err != nil {
return
}
}
}
}()
for {
select {
case l := <-logs:
if err != nil {
return
}
err = manager.WriteData(websocket.TextMessage, l, 10*time.Second)
if err != nil {
return
}
case <-done:
return
}
// renewal happens in background
ok := autocert.ForceExpiryAll()
if !ok {
log.Error().Msg("cert renewal already in progress")
time.Sleep(1 * time.Second) // wait for the log above to be sent
return
}
log.Info().Msg("cert force renewal requested")
autocert.WaitRenewalDone(manager.Context())
}

View File

@@ -2956,43 +2956,6 @@
"x-nullable": false,
"x-omitempty": false
},
"HealthInfo": {
"type": "object",
"properties": {
"detail": {
"type": "string",
"x-nullable": false,
"x-omitempty": false
},
"latency": {
"description": "latency in microseconds",
"type": "number",
"x-nullable": false,
"x-omitempty": false
},
"status": {
"type": "string",
"enum": [
"healthy",
"unhealthy",
"napping",
"starting",
"error",
"unknown"
],
"x-nullable": false,
"x-omitempty": false
},
"uptime": {
"description": "uptime in milliseconds",
"type": "number",
"x-nullable": false,
"x-omitempty": false
}
},
"x-nullable": false,
"x-omitempty": false
},
"HealthInfoWithoutDetail": {
"type": "object",
"properties": {
@@ -3047,22 +3010,14 @@
"x-nullable": true
},
"lastSeen": {
"description": "unix timestamp in seconds",
"type": "integer",
"x-nullable": false,
"x-omitempty": false
},
"lastSeenStr": {
"type": "string",
"x-nullable": false,
"x-omitempty": false
},
"latency": {
"type": "number",
"x-nullable": false,
"x-omitempty": false
},
"latencyStr": {
"type": "string",
"description": "latency in milliseconds",
"type": "integer",
"x-nullable": false,
"x-omitempty": false
},
@@ -3072,30 +3027,22 @@
"x-omitempty": false
},
"started": {
"description": "unix timestamp in seconds",
"type": "integer",
"x-nullable": false,
"x-omitempty": false
},
"startedStr": {
"type": "string",
"x-nullable": false,
"x-omitempty": false
},
"status": {
"type": "string",
"$ref": "#/definitions/HealthStatusString",
"x-nullable": false,
"x-omitempty": false
},
"uptime": {
"description": "uptime in seconds",
"type": "number",
"x-nullable": false,
"x-omitempty": false
},
"uptimeStr": {
"type": "string",
"x-nullable": false,
"x-omitempty": false
},
"url": {
"type": "string",
"x-nullable": false,
@@ -3108,11 +3055,32 @@
"HealthMap": {
"type": "object",
"additionalProperties": {
"$ref": "#/definitions/HealthInfo"
"$ref": "#/definitions/HealthStatusString"
},
"x-nullable": false,
"x-omitempty": false
},
"HealthStatusString": {
"type": "string",
"enum": [
"unknown",
"healthy",
"napping",
"starting",
"unhealthy",
"error"
],
"x-enum-varnames": [
"StatusUnknownStr",
"StatusHealthyStr",
"StatusNappingStr",
"StatusStartingStr",
"StatusUnhealthyStr",
"StatusErrorStr"
],
"x-nullable": false,
"x-omitempty": false
},
"HomepageCategory": {
"type": "object",
"properties": {
@@ -4357,6 +4325,7 @@
"enum": [
"http",
"https",
"h2c",
"tcp",
"udp",
"fileserver"
@@ -5494,6 +5463,7 @@
"enum": [
"http",
"https",
"h2c",
"tcp",
"udp",
"fileserver"

View File

@@ -302,26 +302,6 @@ definitions:
additionalProperties: {}
type: object
type: object
HealthInfo:
properties:
detail:
type: string
latency:
description: latency in microseconds
type: number
status:
enum:
- healthy
- unhealthy
- napping
- starting
- error
- unknown
type: string
uptime:
description: uptime in milliseconds
type: number
type: object
HealthInfoWithoutDetail:
properties:
latency:
@@ -351,32 +331,44 @@ definitions:
- $ref: '#/definitions/HealthExtra'
x-nullable: true
lastSeen:
description: unix timestamp in seconds
type: integer
lastSeenStr:
type: string
latency:
type: number
latencyStr:
type: string
description: latency in milliseconds
type: integer
name:
type: string
started:
description: unix timestamp in seconds
type: integer
startedStr:
type: string
status:
type: string
$ref: '#/definitions/HealthStatusString'
uptime:
description: uptime in seconds
type: number
uptimeStr:
type: string
url:
type: string
type: object
HealthMap:
additionalProperties:
$ref: '#/definitions/HealthInfo'
$ref: '#/definitions/HealthStatusString'
type: object
HealthStatusString:
enum:
- unknown
- healthy
- napping
- starting
- unhealthy
- error
type: string
x-enum-varnames:
- StatusUnknownStr
- StatusHealthyStr
- StatusNappingStr
- StatusStartingStr
- StatusUnhealthyStr
- StatusErrorStr
HomepageCategory:
properties:
items:
@@ -963,6 +955,7 @@ definitions:
enum:
- http
- https
- h2c
- tcp
- udp
- fileserver
@@ -1578,6 +1571,7 @@ definitions:
enum:
- http
- https
- h2c
- tcp
- udp
- fileserver

View File

@@ -12,8 +12,6 @@ import (
_ "github.com/yusing/goutils/apitypes"
)
type HealthMap = map[string]routes.HealthInfo // @name HealthMap
// @x-id "health"
// @BasePath /api/v1
// @Summary Get routes health info
@@ -21,16 +19,16 @@ type HealthMap = map[string]routes.HealthInfo // @name HealthMap
// @Tags v1,websocket
// @Accept json
// @Produce json
// @Success 200 {object} HealthMap "Health info by route name"
// @Success 200 {object} routes.HealthMap "Health info by route name"
// @Failure 403 {object} apitypes.ErrorResponse
// @Failure 500 {object} apitypes.ErrorResponse
// @Router /health [get]
func Health(c *gin.Context) {
if httpheaders.IsWebsocket(c.Request.Header) {
websocket.PeriodicWrite(c, 1*time.Second, func() (any, error) {
return routes.GetHealthInfo(), nil
return routes.GetHealthInfoSimple(), nil
})
} else {
c.JSON(http.StatusOK, routes.GetHealthInfo())
c.JSON(http.StatusOK, routes.GetHealthInfoSimple())
}
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/yusing/godoxy/internal/common"
"golang.org/x/oauth2"
"golang.org/x/time/rate"
expect "github.com/yusing/goutils/testing"
)
@@ -42,6 +43,7 @@ func setupMockOIDC(t *testing.T) {
}),
allowedUsers: []string{"test-user"},
allowedGroups: []string{"test-group1", "test-group2"},
rateLimit: rate.NewLimiter(rate.Every(common.OIDCRateLimitPeriod), common.OIDCRateLimit),
}
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"fmt"
"net/http"
"os"
"regexp"
@@ -19,12 +20,14 @@ import (
strutils "github.com/yusing/goutils/strings"
)
type ConfigExtra Config
type Config struct {
Email string `json:"email,omitempty"`
Domains []string `json:"domains,omitempty"`
CertPath string `json:"cert_path,omitempty"`
KeyPath string `json:"key_path,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"`
Extra []ConfigExtra `json:"extra,omitempty"`
ACMEKeyPath string `json:"acme_key_path,omitempty"` // shared by all extra providers
Provider string `json:"provider,omitempty"`
Options map[string]strutils.Redacted `json:"options,omitempty"`
@@ -41,13 +44,13 @@ type Config struct {
HTTPClient *http.Client `json:"-"` // for tests only
challengeProvider challenge.Provider
idx int // 0: main, 1+: extra[i]
}
var (
ErrMissingDomain = gperr.New("missing field 'domains'")
ErrMissingEmail = gperr.New("missing field 'email'")
ErrMissingProvider = gperr.New("missing field 'provider'")
ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'")
ErrMissingField = gperr.New("missing field")
ErrDuplicatedPath = gperr.New("duplicated path")
ErrInvalidDomain = gperr.New("invalid domain")
ErrUnknownProvider = gperr.New("unknown provider")
)
@@ -62,69 +65,22 @@ var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
// Validate implements the utils.CustomValidator interface.
func (cfg *Config) Validate() gperr.Error {
if cfg == nil {
return nil
}
seenPaths := make(map[string]int) // path -> provider idx (0 for main, 1+ for extras)
return cfg.validate(seenPaths)
}
func (cfg *ConfigExtra) Validate() gperr.Error {
return nil // done by main config's validate
}
func (cfg *ConfigExtra) AsConfig() *Config {
return (*Config)(cfg)
}
func (cfg *Config) validate(seenPaths map[string]int) gperr.Error {
if cfg.Provider == "" {
cfg.Provider = ProviderLocal
return nil
}
b := gperr.NewBuilder("autocert errors")
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
b.Add(ErrMissingCADirURL)
}
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
if cfg.Email == "" {
b.Add(ErrMissingEmail)
}
if cfg.Provider != ProviderCustom {
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
}
// check if provider is implemented
providerConstructor, ok := Providers[cfg.Provider]
if !ok {
if cfg.Provider != ProviderCustom {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
With(gperr.DoYouMeanField(cfg.Provider, Providers)))
}
} else {
provider, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
} else {
cfg.challengeProvider = provider
}
}
}
if cfg.challengeProvider == nil {
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
}
return b.Error()
}
func (cfg *Config) dns01Options() []dns01.ChallengeOption {
return []dns01.ChallengeOption{
dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)),
}
}
func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
if err := cfg.Validate(); err != nil {
return nil, nil, err
}
if cfg.CertPath == "" {
cfg.CertPath = CertFileDefault
}
@@ -135,6 +91,83 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
cfg.ACMEKeyPath = ACMEKeyFileDefault
}
b := gperr.NewBuilder("certificate error")
// check if cert_path is unique
if first, ok := seenPaths[cfg.CertPath]; ok {
b.Add(ErrDuplicatedPath.Subjectf("cert_path %s", cfg.CertPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first)))
} else {
seenPaths[cfg.CertPath] = cfg.idx
}
// check if key_path is unique
if first, ok := seenPaths[cfg.KeyPath]; ok {
b.Add(ErrDuplicatedPath.Subjectf("key_path %s", cfg.KeyPath).Withf("first seen in %s", fmt.Sprintf("extra[%d]", first)))
} else {
seenPaths[cfg.KeyPath] = cfg.idx
}
if cfg.Provider == ProviderCustom && cfg.CADirURL == "" {
b.Add(ErrMissingField.Subject("ca_dir_url"))
}
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingField.Subject("domains"))
}
if cfg.Email == "" {
b.Add(ErrMissingField.Subject("email"))
}
if cfg.Provider != ProviderCustom {
for i, d := range cfg.Domains {
if !domainOrWildcardRE.MatchString(d) {
b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i))
}
}
}
}
// check if provider is implemented
providerConstructor, ok := Providers[cfg.Provider]
if !ok {
if cfg.Provider != ProviderCustom {
b.Add(ErrUnknownProvider.
Subject(cfg.Provider).
With(gperr.DoYouMeanField(cfg.Provider, Providers)))
}
} else {
provider, err := providerConstructor(cfg.Options)
if err != nil {
b.Add(err)
} else {
cfg.challengeProvider = provider
}
}
if cfg.challengeProvider == nil {
cfg.challengeProvider, _ = Providers[ProviderLocal](nil)
}
if len(cfg.Extra) > 0 {
for i := range cfg.Extra {
cfg.Extra[i] = MergeExtraConfig(cfg, &cfg.Extra[i])
cfg.Extra[i].AsConfig().idx = i + 1
err := cfg.Extra[i].AsConfig().validate(seenPaths)
if err != nil {
b.Add(err.Subjectf("extra[%d]", i))
}
}
}
return b.Error()
}
func (cfg *Config) dns01Options() []dns01.ChallengeOption {
return []dns01.ChallengeOption{
dns01.CondOption(len(cfg.Resolvers) > 0, dns01.AddRecursiveNameservers(cfg.Resolvers)),
}
}
func (cfg *Config) GetLegoConfig() (*User, *lego.Config, error) {
var privKey *ecdsa.PrivateKey
var err error
@@ -178,6 +211,46 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) {
return user, legoCfg, nil
}
func MergeExtraConfig(mainCfg *Config, extraCfg *ConfigExtra) ConfigExtra {
merged := ConfigExtra(*mainCfg)
merged.Extra = nil
merged.CertPath = extraCfg.CertPath
merged.KeyPath = extraCfg.KeyPath
// NOTE: Using same ACME key as main provider
if extraCfg.Provider != "" {
merged.Provider = extraCfg.Provider
}
if extraCfg.Email != "" {
merged.Email = extraCfg.Email
}
if len(extraCfg.Domains) > 0 {
merged.Domains = extraCfg.Domains
}
if len(extraCfg.Options) > 0 {
merged.Options = extraCfg.Options
}
if len(extraCfg.Resolvers) > 0 {
merged.Resolvers = extraCfg.Resolvers
}
if extraCfg.CADirURL != "" {
merged.CADirURL = extraCfg.CADirURL
}
if len(extraCfg.CACerts) > 0 {
merged.CACerts = extraCfg.CACerts
}
if extraCfg.EABKid != "" {
merged.EABKid = extraCfg.EABKid
}
if extraCfg.EABHmac != "" {
merged.EABHmac = extraCfg.EABHmac
}
if extraCfg.HTTPClient != nil {
merged.HTTPClient = extraCfg.HTTPClient
}
return merged
}
func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) {
if common.IsTest {
return nil, os.ErrNotExist

View File

@@ -1,27 +1,32 @@
package autocert
package autocert_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/dnsproviders"
"github.com/yusing/godoxy/internal/serialization"
)
func TestEABConfigRequired(t *testing.T) {
dnsproviders.InitProviders()
tests := []struct {
name string
cfg *Config
cfg *autocert.Config
wantErr bool
}{
{name: "Missing EABKid", cfg: &Config{EABHmac: "1234567890"}, wantErr: true},
{name: "Missing EABHmac", cfg: &Config{EABKid: "1234567890"}, wantErr: true},
{name: "Valid EAB", cfg: &Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false},
{name: "Missing EABKid", cfg: &autocert.Config{EABHmac: "1234567890"}, wantErr: true},
{name: "Missing EABHmac", cfg: &autocert.Config{EABKid: "1234567890"}, wantErr: true},
{name: "Valid EAB", cfg: &autocert.Config{EABKid: "1234567890", EABHmac: "1234567890"}, wantErr: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
yaml := fmt.Appendf(nil, "eab_kid: %s\neab_hmac: %s", test.cfg.EABKid, test.cfg.EABHmac)
cfg := Config{}
cfg := autocert.Config{}
err := serialization.UnmarshalValidateYAML(yaml, &cfg)
if (err != nil) != test.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, test.wantErr)
@@ -29,3 +34,27 @@ func TestEABConfigRequired(t *testing.T) {
})
}
}
func TestExtraCertKeyPathsUnique(t *testing.T) {
t.Run("duplicate cert_path rejected", func(t *testing.T) {
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
Extra: []autocert.ConfigExtra{
{CertPath: "a.crt", KeyPath: "a.key"},
{CertPath: "a.crt", KeyPath: "b.key"},
},
}
require.Error(t, cfg.Validate())
})
t.Run("duplicate key_path rejected", func(t *testing.T) {
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
Extra: []autocert.ConfigExtra{
{CertPath: "a.crt", KeyPath: "a.key"},
{CertPath: "b.crt", KeyPath: "a.key"},
},
}
require.Error(t, cfg.Validate())
})
}

View File

@@ -5,5 +5,4 @@ const (
CertFileDefault = certBasePath + "cert.crt"
KeyFileDefault = certBasePath + "priv.key"
ACMEKeyFileDefault = certBasePath + "acme.key"
LastFailureFile = certBasePath + ".last_failure"
)

View File

@@ -1,15 +1,19 @@
package autocert
import (
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"maps"
"os"
"path"
"path/filepath"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -27,21 +31,34 @@ import (
type (
Provider struct {
logger zerolog.Logger
cfg *Config
user *User
legoCfg *lego.Config
client *lego.Client
lastFailure time.Time
lastFailureFile string
legoCert *certificate.Resource
tlsCert *tls.Certificate
certExpiries CertExpiries
extraProviders []*Provider
sniMatcher sniMatcher
forceRenewalCh chan struct{}
forceRenewalDoneCh atomic.Value // chan struct{}
scheduleRenewalOnce sync.Once
}
CertExpiries map[string]time.Time
RenewMode uint8
)
var ErrGetCertFailure = errors.New("get certificate failed")
var ErrNoCertificate = errors.New("no certificate found")
const (
// renew failed for whatever reason, 1 hour cooldown
@@ -50,26 +67,57 @@ const (
requestCooldownDuration = 15 * time.Second
)
const (
renewModeForce = iota
renewModeIfNeeded
)
// could be nil
var ActiveProvider atomic.Pointer[Provider]
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) *Provider {
return &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
func NewProvider(cfg *Config, user *User, legoCfg *lego.Config) (*Provider, error) {
p := &Provider{
cfg: cfg,
user: user,
legoCfg: legoCfg,
lastFailureFile: lastFailureFileFor(cfg.CertPath, cfg.KeyPath),
forceRenewalCh: make(chan struct{}, 1),
}
p.forceRenewalDoneCh.Store(emptyForceRenewalDoneCh)
if cfg.idx == 0 {
p.logger = log.With().Str("provider", "main").Logger()
} else {
p.logger = log.With().Str("provider", fmt.Sprintf("extra[%d]", cfg.idx)).Logger()
}
if err := p.setupExtraProviders(); err != nil {
return nil, err
}
return p, nil
}
func (p *Provider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
func (p *Provider) GetCert(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if p.tlsCert == nil {
return nil, ErrGetCertFailure
return nil, ErrNoCertificate
}
if hello == nil || hello.ServerName == "" {
return p.tlsCert, nil
}
if prov := p.sniMatcher.match(hello.ServerName); prov != nil && prov.tlsCert != nil {
return prov.tlsCert, nil
}
return p.tlsCert, nil
}
func (p *Provider) GetName() string {
return p.cfg.Provider
if p.cfg.idx == 0 {
return "main"
}
return fmt.Sprintf("extra[%d]", p.cfg.idx)
}
func (p *Provider) fmtError(err error) error {
return gperr.PrependSubject(fmt.Sprintf("provider: %s", p.GetName()), err)
}
func (p *Provider) GetCertPath() string {
@@ -90,7 +138,7 @@ func (p *Provider) GetLastFailure() (time.Time, error) {
}
if p.lastFailure.IsZero() {
data, err := os.ReadFile(LastFailureFile)
data, err := os.ReadFile(p.lastFailureFile)
if err != nil {
if !os.IsNotExist(err) {
return time.Time{}, err
@@ -108,7 +156,7 @@ func (p *Provider) UpdateLastFailure() error {
}
t := time.Now()
p.lastFailure = t
return os.WriteFile(LastFailureFile, t.AppendFormat(nil, time.RFC3339), 0o600)
return os.WriteFile(p.lastFailureFile, t.AppendFormat(nil, time.RFC3339), 0o600)
}
func (p *Provider) ClearLastFailure() error {
@@ -116,29 +164,88 @@ func (p *Provider) ClearLastFailure() error {
return nil
}
p.lastFailure = time.Time{}
return os.Remove(LastFailureFile)
err := os.Remove(p.lastFailureFile)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return err
}
return nil
}
// allProviders returns all providers including this provider and all extra providers.
func (p *Provider) allProviders() []*Provider {
return append([]*Provider{p}, p.extraProviders...)
}
// ObtainCertIfNotExistsAll obtains a new certificate for this provider and all extra providers if they do not exist.
func (p *Provider) ObtainCertIfNotExistsAll() error {
errs := gperr.NewGroup("obtain cert error")
for _, provider := range p.allProviders() {
errs.Go(func() error {
if err := provider.obtainCertIfNotExists(); err != nil {
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
}
return nil
})
}
p.rebuildSNIMatcher()
return errs.Wait().Error()
}
// obtainCertIfNotExists obtains a new certificate for this provider if it does not exist.
func (p *Provider) obtainCertIfNotExists() error {
err := p.LoadCert()
if err == nil {
return nil
}
if !errors.Is(err, fs.ErrNotExist) {
return err
}
// check last failure
lastFailure, err := p.GetLastFailure()
if err != nil {
return fmt.Errorf("failed to get last failure: %w", err)
}
if !lastFailure.IsZero() && time.Since(lastFailure) < requestCooldownDuration {
return fmt.Errorf("still in cooldown until %s", strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
}
p.logger.Info().Msg("cert not found, obtaining new cert")
return p.ObtainCert()
}
// ObtainCertAll renews existing certificates or obtains new certificates for this provider and all extra providers.
func (p *Provider) ObtainCertAll() error {
errs := gperr.NewGroup("obtain cert error")
for _, provider := range p.allProviders() {
errs.Go(func() error {
if err := provider.obtainCertIfNotExists(); err != nil {
return fmt.Errorf("failed to obtain cert for %s: %w", provider.GetName(), err)
}
return nil
})
}
return errs.Wait().Error()
}
// ObtainCert renews existing certificate or obtains a new certificate for this provider.
func (p *Provider) ObtainCert() error {
if p.cfg.Provider == ProviderLocal {
return nil
}
if p.cfg.Provider == ProviderPseudo {
log.Info().Msg("init client for pseudo provider")
p.logger.Info().Msg("init client for pseudo provider")
<-time.After(time.Second)
log.Info().Msg("registering acme for pseudo provider")
p.logger.Info().Msg("registering acme for pseudo provider")
<-time.After(time.Second)
log.Info().Msg("obtained cert for pseudo provider")
p.logger.Info().Msg("obtained cert for pseudo provider")
return nil
}
if lastFailure, err := p.GetLastFailure(); err != nil {
return err
} else if time.Since(lastFailure) < requestCooldownDuration {
return fmt.Errorf("%w: still in cooldown until %s", ErrGetCertFailure, strutils.FormatTime(lastFailure.Add(requestCooldownDuration).Local()))
}
if p.client == nil {
if err := p.initClient(); err != nil {
return err
@@ -198,6 +305,7 @@ func (p *Provider) ObtainCert() error {
}
p.tlsCert = &tlsCert
p.certExpiries = expiries
p.rebuildSNIMatcher()
if err := p.ClearLastFailure(); err != nil {
return fmt.Errorf("failed to clear last failure: %w", err)
@@ -206,19 +314,37 @@ func (p *Provider) ObtainCert() error {
}
func (p *Provider) LoadCert() error {
var errs gperr.Builder
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
if err != nil {
return fmt.Errorf("load SSL certificate: %w", err)
errs.Addf("load SSL certificate: %w", p.fmtError(err))
}
expiries, err := getCertExpiries(&cert)
if err != nil {
return fmt.Errorf("parse SSL certificate: %w", err)
errs.Addf("parse SSL certificate: %w", p.fmtError(err))
}
p.tlsCert = &cert
p.certExpiries = expiries
log.Info().Msgf("next cert renewal in %s", strutils.FormatDuration(time.Until(p.ShouldRenewOn())))
return p.renewIfNeeded()
for _, ep := range p.extraProviders {
if err := ep.LoadCert(); err != nil {
errs.Add(err)
}
}
p.rebuildSNIMatcher()
return errs.Error()
}
// PrintCertExpiriesAll prints the certificate expiries for this provider and all extra providers.
func (p *Provider) PrintCertExpiriesAll() {
for _, provider := range p.allProviders() {
for domain, expiry := range provider.certExpiries {
p.logger.Info().Str("domain", domain).Msgf("certificate expire on %s", strutils.FormatTime(expiry))
}
}
}
// ShouldRenewOn returns the time at which the certificate should be renewed.
@@ -226,59 +352,126 @@ func (p *Provider) ShouldRenewOn() time.Time {
for _, expiry := range p.certExpiries {
return expiry.AddDate(0, -1, 0) // 1 month before
}
// this line should never be reached
panic("no certificate available")
// this line should never be reached in production, but will be useful for testing
return time.Now().AddDate(0, 1, 0) // 1 month after
}
func (p *Provider) ScheduleRenewal(parent task.Parent) {
// ForceExpiryAll triggers immediate certificate renewal for this provider and all extra providers.
// Returns true if the renewal was triggered, false if the renewal was dropped.
//
// If at least one renewal is triggered, returns true.
func (p *Provider) ForceExpiryAll() (ok bool) {
doneCh := make(chan struct{})
if swapped := p.forceRenewalDoneCh.CompareAndSwap(emptyForceRenewalDoneCh, doneCh); !swapped { // already in progress
close(doneCh)
return false
}
select {
case p.forceRenewalCh <- struct{}{}:
ok = true
default:
}
for _, ep := range p.extraProviders {
if ep.ForceExpiryAll() {
ok = true
}
}
return ok
}
// WaitRenewalDone waits for the renewal to complete.
// Returns false if the renewal was dropped.
func (p *Provider) WaitRenewalDone(ctx context.Context) bool {
done, ok := p.forceRenewalDoneCh.Load().(chan struct{})
if !ok || done == nil {
return false
}
select {
case <-done:
case <-ctx.Done():
return false
}
for _, ep := range p.extraProviders {
if !ep.WaitRenewalDone(ctx) {
return false
}
}
return true
}
// ScheduleRenewalAll schedules the renewal of the certificate for this provider and all extra providers.
func (p *Provider) ScheduleRenewalAll(parent task.Parent) {
p.scheduleRenewalOnce.Do(func() {
p.scheduleRenewal(parent)
})
for _, ep := range p.extraProviders {
ep.scheduleRenewalOnce.Do(func() {
ep.scheduleRenewal(parent)
})
}
}
var emptyForceRenewalDoneCh any = chan struct{}(nil)
// scheduleRenewal schedules the renewal of the certificate for this provider.
func (p *Provider) scheduleRenewal(parent task.Parent) {
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
return
}
go func() {
renewalTime := p.ShouldRenewOn()
timer := time.NewTimer(time.Until(renewalTime))
defer timer.Stop()
task := parent.Subtask("cert-renew-scheduler", true)
timer := time.NewTimer(time.Until(p.ShouldRenewOn()))
task := parent.Subtask("cert-renew-scheduler:"+filepath.Base(p.cfg.CertPath), true)
renew := func(renewMode RenewMode) {
defer func() {
if done, ok := p.forceRenewalDoneCh.Swap(emptyForceRenewalDoneCh).(chan struct{}); ok && done != nil {
close(done)
}
}()
renewed, err := p.renew(renewMode)
if err != nil {
gperr.LogWarn("autocert: cert renew failed", p.fmtError(err))
notif.Notify(&notif.LogMessage{
Level: zerolog.ErrorLevel,
Title: fmt.Sprintf("SSL certificate renewal failed for %s", p.GetName()),
Body: notif.MessageBody(err.Error()),
})
return
}
if renewed {
p.rebuildSNIMatcher()
notif.Notify(&notif.LogMessage{
Level: zerolog.InfoLevel,
Title: fmt.Sprintf("SSL certificate renewed for %s", p.GetName()),
Body: notif.ListBody(p.cfg.Domains),
})
// Reset on success
if err := p.ClearLastFailure(); err != nil {
gperr.LogWarn("autocert: failed to clear last failure", p.fmtError(err))
}
timer.Reset(time.Until(p.ShouldRenewOn()))
}
}
go func() {
defer timer.Stop()
defer task.Finish(nil)
for {
select {
case <-task.Context().Done():
return
case <-p.forceRenewalCh:
renew(renewModeForce)
case <-timer.C:
// Retry after 1 hour on failure
lastFailure, err := p.GetLastFailure()
if err != nil {
gperr.LogWarn("autocert: failed to get last failure", err)
continue
}
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
continue
}
if err := p.renewIfNeeded(); err != nil {
gperr.LogWarn("autocert: cert renew failed", err)
if err := p.UpdateLastFailure(); err != nil {
gperr.LogWarn("autocert: failed to update last failure", err)
}
notif.Notify(&notif.LogMessage{
Level: zerolog.ErrorLevel,
Title: "SSL certificate renewal failed",
Body: notif.MessageBody(err.Error()),
})
continue
}
notif.Notify(&notif.LogMessage{
Level: zerolog.InfoLevel,
Title: "SSL certificate renewed",
Body: notif.ListBody(p.cfg.Domains),
})
// Reset on success
if err := p.ClearLastFailure(); err != nil {
gperr.LogWarn("autocert: failed to clear last failure", err)
}
renewalTime = p.ShouldRenewOn()
timer.Reset(time.Until(renewalTime))
renew(renewModeIfNeeded)
}
}
}()
@@ -334,10 +527,10 @@ func (p *Provider) saveCert(cert *certificate.Resource) error {
}
/* This should have been done in setup
but double check is always a good choice.*/
_, err := os.Stat(path.Dir(p.cfg.CertPath))
_, err := os.Stat(filepath.Dir(p.cfg.CertPath))
if err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(path.Dir(p.cfg.CertPath), 0o755); err != nil {
if err = os.MkdirAll(filepath.Dir(p.cfg.CertPath), 0o755); err != nil {
return err
}
} else {
@@ -377,21 +570,42 @@ func (p *Provider) certState() CertState {
return CertStateValid
}
func (p *Provider) renewIfNeeded() error {
func (p *Provider) renew(mode RenewMode) (renewed bool, err error) {
if p.cfg.Provider == ProviderLocal {
return nil
return false, nil
}
switch p.certState() {
case CertStateExpired:
log.Info().Msg("certs expired, renewing")
case CertStateMismatch:
log.Info().Msg("cert domains mismatch with config, renewing")
default:
return nil
if mode != renewModeForce {
// Retry after 1 hour on failure
lastFailure, err := p.GetLastFailure()
if err != nil {
return false, fmt.Errorf("failed to get last failure: %w", err)
}
if !lastFailure.IsZero() && time.Since(lastFailure) < renewalCooldownDuration {
until := lastFailure.Add(renewalCooldownDuration).Local()
return false, fmt.Errorf("still in cooldown until %s", strutils.FormatTime(until))
}
}
return p.ObtainCert()
if mode == renewModeIfNeeded {
switch p.certState() {
case CertStateExpired:
log.Info().Msg("certs expired, renewing")
case CertStateMismatch:
log.Info().Msg("cert domains mismatch with config, renewing")
default:
return false, nil
}
}
if mode == renewModeForce {
log.Info().Msg("force renewing cert by user request")
}
if err := p.ObtainCert(); err != nil {
return false, err
}
return true, nil
}
func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
@@ -411,3 +625,21 @@ func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) {
}
return r, nil
}
func lastFailureFileFor(certPath, keyPath string) string {
dir := filepath.Dir(certPath)
sum := sha256.Sum256([]byte(certPath + "|" + keyPath))
return filepath.Join(dir, fmt.Sprintf(".last_failure-%x", sum[:6]))
}
func (p *Provider) rebuildSNIMatcher() {
if p.cfg.idx != 0 { // only main provider has extra providers
return
}
p.sniMatcher = sniMatcher{}
p.sniMatcher.addProvider(p)
for _, ep := range p.extraProviders {
p.sniMatcher.addProvider(ep)
}
}

View File

@@ -10,12 +10,15 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"sort"
"strings"
"sync"
"testing"
"time"
@@ -24,6 +27,368 @@ import (
"github.com/yusing/godoxy/internal/dnsproviders"
)
// TestACMEServer implements a minimal ACME server for testing with request tracking.
type TestACMEServer struct {
server *httptest.Server
caCert *x509.Certificate
caKey *rsa.PrivateKey
clientCSRs map[string]*x509.CertificateRequest
orderDomains map[string][]string
authzDomains map[string]string
orderSeq int
certRequestCount map[string]int
renewalRequestCount map[string]int
mu sync.Mutex
}
func newTestACMEServer(t *testing.T) *TestACMEServer {
t.Helper()
// Generate CA certificate and key
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caCert, err := x509.ParseCertificate(caCertDER)
require.NoError(t, err)
acme := &TestACMEServer{
caCert: caCert,
caKey: caKey,
clientCSRs: make(map[string]*x509.CertificateRequest),
orderDomains: make(map[string][]string),
authzDomains: make(map[string]string),
orderSeq: 0,
certRequestCount: make(map[string]int),
renewalRequestCount: make(map[string]int),
}
mux := http.NewServeMux()
acme.setupRoutes(mux)
acme.server = httptest.NewUnstartedServer(mux)
acme.server.TLS = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{caCert.Raw},
PrivateKey: caKey,
},
},
MinVersion: tls.VersionTLS12,
}
acme.server.StartTLS()
return acme
}
func (s *TestACMEServer) Close() {
s.server.Close()
}
func (s *TestACMEServer) URL() string {
return s.server.URL
}
func (s *TestACMEServer) httpClient() *http.Client {
certPool := x509.NewCertPool()
certPool.AddCert(s.caCert)
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: certPool,
MinVersion: tls.VersionTLS12,
},
},
}
}
func (s *TestACMEServer) setupRoutes(mux *http.ServeMux) {
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
mux.HandleFunc("/acme/chall/", s.handleChallenge)
mux.HandleFunc("/acme/order/", s.handleOrder)
mux.HandleFunc("/acme/cert/", s.handleCertificate)
}
func (s *TestACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
directory := map[string]any{
"newNonce": s.server.URL + "/acme/new-nonce",
"newAccount": s.server.URL + "/acme/new-account",
"newOrder": s.server.URL + "/acme/new-order",
"keyChange": s.server.URL + "/acme/key-change",
"meta": map[string]any{
"termsOfService": s.server.URL + "/terms",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(directory)
}
func (s *TestACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "test-nonce-12345")
w.WriteHeader(http.StatusOK)
}
func (s *TestACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
account := map[string]any{
"status": "valid",
"contact": []string{"mailto:test@example.com"},
"orders": s.server.URL + "/acme/orders",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/account/1")
w.Header().Set("Replay-Nonce", "test-nonce-67890")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(account)
}
func (s *TestACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
var jws struct {
Payload string `json:"payload"`
}
json.Unmarshal(body, &jws)
payloadBytes, _ := base64.RawURLEncoding.DecodeString(jws.Payload)
var orderReq struct {
Identifiers []map[string]string `json:"identifiers"`
}
json.Unmarshal(payloadBytes, &orderReq)
domains := []string{}
for _, id := range orderReq.Identifiers {
domains = append(domains, id["value"])
}
sort.Strings(domains)
domainKey := strings.Join(domains, ",")
s.mu.Lock()
s.orderSeq++
orderID := fmt.Sprintf("test-order-%d", s.orderSeq)
authzID := fmt.Sprintf("test-authz-%d", s.orderSeq)
s.orderDomains[orderID] = domains
if len(domains) > 0 {
s.authzDomains[authzID] = domains[0]
}
s.certRequestCount[domainKey]++
s.mu.Unlock()
order := map[string]any{
"status": "ready",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": orderReq.Identifiers,
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
"finalize": s.server.URL + "/acme/order/" + orderID + "/finalize",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/order/"+orderID)
w.Header().Set("Replay-Nonce", "test-nonce-order")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(order)
}
func (s *TestACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authzID := strings.TrimPrefix(r.URL.Path, "/acme/authz/")
domain := s.authzDomains[authzID]
if domain == "" {
domain = "test.example.com"
}
authz := map[string]any{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifier": map[string]string{"type": "dns", "value": domain},
"challenges": []map[string]any{
{
"type": "dns-01",
"status": "valid",
"url": s.server.URL + "/acme/chall/test-chall-789",
"token": "test-token-abc123",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-authz")
json.NewEncoder(w).Encode(authz)
}
func (s *TestACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
challenge := map[string]any{
"type": "dns-01",
"status": "valid",
"url": r.URL.String(),
"token": "test-token-abc123",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-chall")
json.NewEncoder(w).Encode(challenge)
}
func (s *TestACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/finalize") {
s.handleFinalize(w, r)
return
}
orderID := strings.TrimPrefix(r.URL.Path, "/acme/order/")
domains := s.orderDomains[orderID]
if len(domains) == 0 {
domains = []string{"test.example.com"}
}
certURL := s.server.URL + "/acme/cert/" + orderID
order := map[string]any{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": func() []map[string]string {
out := make([]map[string]string, 0, len(domains))
for _, d := range domains {
out = append(out, map[string]string{"type": "dns", "value": d})
}
return out
}(),
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
json.NewEncoder(w).Encode(order)
}
func (s *TestACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request", http.StatusBadRequest)
return
}
csr, err := s.extractCSRFromJWS(body)
if err != nil {
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
return
}
orderID := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/acme/order/"), "/finalize")
s.mu.Lock()
s.clientCSRs[orderID] = csr
// Detect renewal: if we already have a certificate for these domains, it's a renewal
domains := csr.DNSNames
sort.Strings(domains)
domainKey := strings.Join(domains, ",")
if s.certRequestCount[domainKey] > 1 {
s.renewalRequestCount[domainKey]++
}
s.mu.Unlock()
certURL := s.server.URL + "/acme/cert/" + orderID
order := map[string]any{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": func() []map[string]string {
out := make([]map[string]string, 0, len(domains))
for _, d := range domains {
out = append(out, map[string]string{"type": "dns", "value": d})
}
return out
}(),
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
json.NewEncoder(w).Encode(order)
}
func (s *TestACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
var jws struct {
Payload string `json:"payload"`
}
if err := json.Unmarshal(jwsData, &jws); err != nil {
return nil, err
}
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, err
}
var finalizeReq struct {
CSR string `json:"csr"`
}
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
return nil, err
}
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
if err != nil {
return nil, err
}
return x509.ParseCertificateRequest(csrBytes)
}
func (s *TestACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
csr, exists := s.clientCSRs[orderID]
if !exists {
http.Error(w, "No CSR found for order", http.StatusBadRequest)
return
}
template := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Cert"},
Country: []string{"US"},
},
DNSNames: csr.DNSNames,
NotBefore: time.Now(),
NotAfter: time.Now().Add(90 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
w.Header().Set("Content-Type", "application/pem-certificate-chain")
w.Header().Set("Replay-Nonce", "test-nonce-cert")
w.Write(append(certPEM, caPEM...))
}
func TestMain(m *testing.M) {
dnsproviders.InitProviders()
m.Run()
@@ -41,7 +406,7 @@ func TestCustomProvider(t *testing.T) {
ACMEKeyPath: "certs/custom-acme.key",
}
err := cfg.Validate()
err := error(cfg.Validate())
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
@@ -62,7 +427,8 @@ func TestCustomProvider(t *testing.T) {
err := cfg.Validate()
require.Error(t, err)
require.Contains(t, err.Error(), "missing field 'ca_dir_url'")
require.Contains(t, err.Error(), "missing field")
require.Contains(t, err.Error(), "ca_dir_url")
})
t.Run("custom provider with step-ca internal CA", func(t *testing.T) {
@@ -76,7 +442,7 @@ func TestCustomProvider(t *testing.T) {
ACMEKeyPath: "certs/internal-acme.key",
}
err := cfg.Validate()
err := error(cfg.Validate())
require.NoError(t, err)
user, legoCfg, err := cfg.GetLegoConfig()
@@ -86,9 +452,10 @@ func TestCustomProvider(t *testing.T) {
require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL)
require.Equal(t, "admin@internal.com", user.Email)
provider := autocert.NewProvider(cfg, user, legoCfg)
provider, err := autocert.NewProvider(cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
require.Equal(t, autocert.ProviderCustom, provider.GetName())
require.Equal(t, "main", provider.GetName())
require.Equal(t, "certs/internal.crt", provider.GetCertPath())
require.Equal(t, "certs/internal.key", provider.GetKeyPath())
})
@@ -119,7 +486,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
require.NotNil(t, user)
require.NotNil(t, legoCfg)
provider := autocert.NewProvider(cfg, user, legoCfg)
provider, err := autocert.NewProvider(cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
// Test obtaining certificate
@@ -161,7 +529,8 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
require.NotNil(t, user)
require.NotNil(t, legoCfg)
provider := autocert.NewProvider(cfg, user, legoCfg)
provider, err := autocert.NewProvider(cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
err = provider.ObtainCert()
@@ -178,330 +547,3 @@ func TestObtainCertFromCustomProvider(t *testing.T) {
require.True(t, time.Now().After(x509Cert.NotBefore))
})
}
// testACMEServer implements a minimal ACME server for testing.
type testACMEServer struct {
server *httptest.Server
caCert *x509.Certificate
caKey *rsa.PrivateKey
clientCSRs map[string]*x509.CertificateRequest
orderID string
}
func newTestACMEServer(t *testing.T) *testACMEServer {
t.Helper()
// Generate CA certificate and key
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caCert, err := x509.ParseCertificate(caCertDER)
require.NoError(t, err)
acme := &testACMEServer{
caCert: caCert,
caKey: caKey,
clientCSRs: make(map[string]*x509.CertificateRequest),
orderID: "test-order-123",
}
mux := http.NewServeMux()
acme.setupRoutes(mux)
acme.server = httptest.NewUnstartedServer(mux)
acme.server.TLS = &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{caCert.Raw},
PrivateKey: caKey,
},
},
MinVersion: tls.VersionTLS12,
}
acme.server.StartTLS()
return acme
}
func (s *testACMEServer) Close() {
s.server.Close()
}
func (s *testACMEServer) URL() string {
return s.server.URL
}
func (s *testACMEServer) httpClient() *http.Client {
certPool := x509.NewCertPool()
certPool.AddCert(s.caCert)
return &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 30 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: certPool,
MinVersion: tls.VersionTLS12,
},
},
}
}
func (s *testACMEServer) setupRoutes(mux *http.ServeMux) {
// ACME directory endpoint
mux.HandleFunc("/acme/acme/directory", s.handleDirectory)
// ACME endpoints
mux.HandleFunc("/acme/new-nonce", s.handleNewNonce)
mux.HandleFunc("/acme/new-account", s.handleNewAccount)
mux.HandleFunc("/acme/new-order", s.handleNewOrder)
mux.HandleFunc("/acme/authz/", s.handleAuthorization)
mux.HandleFunc("/acme/chall/", s.handleChallenge)
mux.HandleFunc("/acme/order/", s.handleOrder)
mux.HandleFunc("/acme/cert/", s.handleCertificate)
}
func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) {
directory := map[string]interface{}{
"newNonce": s.server.URL + "/acme/new-nonce",
"newAccount": s.server.URL + "/acme/new-account",
"newOrder": s.server.URL + "/acme/new-order",
"keyChange": s.server.URL + "/acme/key-change",
"meta": map[string]interface{}{
"termsOfService": s.server.URL + "/terms",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(directory)
}
func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Replay-Nonce", "test-nonce-12345")
w.WriteHeader(http.StatusOK)
}
func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) {
account := map[string]interface{}{
"status": "valid",
"contact": []string{"mailto:test@example.com"},
"orders": s.server.URL + "/acme/orders",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/account/1")
w.Header().Set("Replay-Nonce", "test-nonce-67890")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(account)
}
func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) {
authzID := "test-authz-456"
order := map[string]interface{}{
"status": "ready", // Skip pending state for simplicity
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"authorizations": []string{s.server.URL + "/acme/authz/" + authzID},
"finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID)
w.Header().Set("Replay-Nonce", "test-nonce-order")
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authz := map[string]interface{}{
"status": "valid", // Skip challenge validation for simplicity
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifier": map[string]string{"type": "dns", "value": "test.example.com"},
"challenges": []map[string]interface{}{
{
"type": "dns-01",
"status": "valid",
"url": s.server.URL + "/acme/chall/test-chall-789",
"token": "test-token-abc123",
},
},
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-authz")
json.NewEncoder(w).Encode(authz)
}
func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) {
challenge := map[string]interface{}{
"type": "dns-01",
"status": "valid",
"url": r.URL.String(),
"token": "test-token-abc123",
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-chall")
json.NewEncoder(w).Encode(challenge)
}
func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/finalize") {
s.handleFinalize(w, r)
return
}
certURL := s.server.URL + "/acme/cert/" + s.orderID
order := map[string]interface{}{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Replay-Nonce", "test-nonce-order-get")
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) {
// Read the JWS payload
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request", http.StatusBadRequest)
return
}
// Extract CSR from JWS payload
csr, err := s.extractCSRFromJWS(body)
if err != nil {
http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest)
return
}
// Store the CSR for certificate generation
s.clientCSRs[s.orderID] = csr
certURL := s.server.URL + "/acme/cert/" + s.orderID
order := map[string]interface{}{
"status": "valid",
"expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339),
"identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}},
"certificate": certURL,
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize"))
w.Header().Set("Replay-Nonce", "test-nonce-finalize")
json.NewEncoder(w).Encode(order)
}
func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) {
// Parse the JWS structure
var jws struct {
Protected string `json:"protected"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}
if err := json.Unmarshal(jwsData, &jws); err != nil {
return nil, err
}
// Decode the payload
payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload)
if err != nil {
return nil, err
}
// Parse the finalize request
var finalizeReq struct {
CSR string `json:"csr"`
}
if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil {
return nil, err
}
// Decode the CSR
csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR)
if err != nil {
return nil, err
}
// Parse the CSR
csr, err := x509.ParseCertificateRequest(csrBytes)
if err != nil {
return nil, err
}
return csr, nil
}
func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) {
// Extract order ID from URL
orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/")
// Get the CSR for this order
csr, exists := s.clientCSRs[orderID]
if !exists {
http.Error(w, "No CSR found for order", http.StatusBadRequest)
return
}
// Create certificate using the public key from the client's CSR
template := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Cert"},
Country: []string{"US"},
},
DNSNames: csr.DNSNames,
NotBefore: time.Now(),
NotAfter: time.Now().Add(90 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
// Use the public key from the CSR and sign with CA key
certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Return certificate chain
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw})
w.Header().Set("Content-Type", "application/pem-certificate-chain")
w.Header().Set("Replay-Nonce", "test-nonce-cert")
w.Write(append(certPEM, caPEM...))
}

View File

@@ -0,0 +1,90 @@
//nolint:errchkjson,errcheck
package provider_test
import (
"fmt"
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/serialization"
"github.com/yusing/goutils/task"
)
func buildMultiCertYAML(serverURL string) []byte {
return fmt.Appendf(nil, `
email: main@example.com
domains: [main.example.com]
provider: custom
ca_dir_url: %s/acme/acme/directory
cert_path: certs/main.crt
key_path: certs/main.key
extra:
- email: extra1@example.com
domains: [extra1.example.com]
cert_path: certs/extra1.crt
key_path: certs/extra1.key
- email: extra2@example.com
domains: [extra2.example.com]
cert_path: certs/extra2.crt
key_path: certs/extra2.key
`, serverURL)
}
func TestMultipleCertificatesLifecycle(t *testing.T) {
acmeServer := newTestACMEServer(t)
defer acmeServer.Close()
yamlConfig := buildMultiCertYAML(acmeServer.URL())
var cfg autocert.Config
cfg.HTTPClient = acmeServer.httpClient()
/* unmarshal yaml config with multiple certs */
err := error(serialization.UnmarshalValidateYAML(yamlConfig, &cfg))
require.NoError(t, err)
require.Equal(t, []string{"main.example.com"}, cfg.Domains)
require.Len(t, cfg.Extra, 2)
require.Equal(t, []string{"extra1.example.com"}, cfg.Extra[0].Domains)
require.Equal(t, []string{"extra2.example.com"}, cfg.Extra[1].Domains)
var provider *autocert.Provider
/* initialize autocert with multi-cert config */
user, legoCfg, gerr := cfg.GetLegoConfig()
require.NoError(t, gerr)
provider, err = autocert.NewProvider(&cfg, user, legoCfg)
require.NoError(t, err)
require.NotNil(t, provider)
// Start renewal scheduler
root := task.RootTask("test", false)
defer root.Finish(nil)
provider.ScheduleRenewalAll(root)
require.Equal(t, "custom", cfg.Provider)
require.Equal(t, "custom", cfg.Extra[0].Provider)
require.Equal(t, "custom", cfg.Extra[1].Provider)
/* track cert requests for all configs */
os.MkdirAll("certs", 0755)
defer os.RemoveAll("certs")
err = provider.ObtainCertIfNotExistsAll()
require.NoError(t, err)
require.Equal(t, 1, acmeServer.certRequestCount["main.example.com"])
require.Equal(t, 1, acmeServer.certRequestCount["extra1.example.com"])
require.Equal(t, 1, acmeServer.certRequestCount["extra2.example.com"])
/* track renewal scheduling and requests */
// force renewal for all providers and wait for completion
ok := provider.ForceExpiryAll()
require.True(t, ok)
provider.WaitRenewalDone(t.Context())
require.Equal(t, 1, acmeServer.renewalRequestCount["main.example.com"])
require.Equal(t, 1, acmeServer.renewalRequestCount["extra1.example.com"])
require.Equal(t, 1, acmeServer.renewalRequestCount["extra2.example.com"])
}

View File

@@ -0,0 +1,416 @@
package provider_test
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
)
func writeSelfSignedCert(t *testing.T, dir string, dnsNames []string) (string, string) {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
require.NoError(t, err)
cn := ""
if len(dnsNames) > 0 {
cn = dnsNames[0]
}
template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: cn,
},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: dnsNames,
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
require.NoError(t, os.WriteFile(certPath, certPEM, 0o644))
require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600))
return certPath, keyPath
}
func TestGetCertBySNI(t *testing.T) {
t.Run("extra cert used when main does not match", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir := t.TempDir()
extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"*.internal.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "a.internal.example.com"})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "*.internal.example.com")
})
t.Run("exact match wins over wildcard match", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir := t.TempDir()
extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "foo.example.com")
})
t.Run("main cert fallback when no match", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir := t.TempDir()
extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"*.test.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "unknown.domain.com"})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "*.example.com")
})
t.Run("nil ServerName returns main cert", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(nil)
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "*.example.com")
})
t.Run("empty ServerName returns main cert", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: ""})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "*.example.com")
})
t.Run("case insensitive matching", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir := t.TempDir()
extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"Foo.Example.COM"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "FOO.EXAMPLE.COM"})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "Foo.Example.COM")
})
t.Run("normalization with trailing dot and whitespace", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir := t.TempDir()
extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: " foo.example.com. "})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "foo.example.com")
})
t.Run("longest wildcard match wins", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir1 := t.TempDir()
extraCert1, extraKey1 := writeSelfSignedCert(t, extraDir1, []string{"*.a.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert1, KeyPath: extraKey1},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.a.example.com"})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "*.a.example.com")
})
t.Run("main cert wildcard match", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
require.NoError(t, err)
leaf, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf.DNSNames, "*.example.com")
})
t.Run("multiple extra certs", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir1 := t.TempDir()
extraCert1, extraKey1 := writeSelfSignedCert(t, extraDir1, []string{"*.test.com"})
extraDir2 := t.TempDir()
extraCert2, extraKey2 := writeSelfSignedCert(t, extraDir2, []string{"*.dev.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert1, KeyPath: extraKey1},
{CertPath: extraCert2, KeyPath: extraKey2},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.test.com"})
require.NoError(t, err)
leaf1, err := x509.ParseCertificate(cert1.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf1.DNSNames, "*.test.com")
cert2, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.dev.com"})
require.NoError(t, err)
leaf2, err := x509.ParseCertificate(cert2.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf2.DNSNames, "*.dev.com")
})
t.Run("multiple DNSNames in cert", func(t *testing.T) {
mainDir := t.TempDir()
mainCert, mainKey := writeSelfSignedCert(t, mainDir, []string{"*.example.com"})
extraDir := t.TempDir()
extraCert, extraKey := writeSelfSignedCert(t, extraDir, []string{"foo.example.com", "bar.example.com", "*.test.com"})
cfg := &autocert.Config{
Provider: autocert.ProviderLocal,
CertPath: mainCert,
KeyPath: mainKey,
Extra: []autocert.ConfigExtra{
{CertPath: extraCert, KeyPath: extraKey},
},
}
require.NoError(t, cfg.Validate())
p, err := autocert.NewProvider(cfg, nil, nil)
require.NoError(t, err)
err = p.LoadCert()
require.NoError(t, err)
cert1, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "foo.example.com"})
require.NoError(t, err)
leaf1, err := x509.ParseCertificate(cert1.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf1.DNSNames, "foo.example.com")
cert2, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "bar.example.com"})
require.NoError(t, err)
leaf2, err := x509.ParseCertificate(cert2.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf2.DNSNames, "bar.example.com")
cert3, err := p.GetCert(&tls.ClientHelloInfo{ServerName: "baz.test.com"})
require.NoError(t, err)
leaf3, err := x509.ParseCertificate(cert3.Certificate[0])
require.NoError(t, err)
require.Contains(t, leaf3.DNSNames, "*.test.com")
})
}

View File

@@ -1,28 +1,30 @@
package autocert
import (
"errors"
"os"
"github.com/rs/zerolog/log"
strutils "github.com/yusing/goutils/strings"
gperr "github.com/yusing/goutils/errs"
)
func (p *Provider) Setup() (err error) {
if err = p.LoadCert(); err != nil {
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
return err
}
log.Debug().Msg("obtaining cert due to error loading cert")
if err = p.ObtainCert(); err != nil {
return err
}
func (p *Provider) setupExtraProviders() gperr.Error {
p.sniMatcher = sniMatcher{}
if len(p.cfg.Extra) == 0 {
return nil
}
for _, expiry := range p.GetExpiries() {
log.Info().Msg("certificate expire on " + strutils.FormatTime(expiry))
break
}
p.extraProviders = make([]*Provider, 0, len(p.cfg.Extra))
return nil
errs := gperr.NewBuilder("setup extra providers error")
for _, extra := range p.cfg.Extra {
user, legoCfg, err := extra.AsConfig().GetLegoConfig()
if err != nil {
errs.Add(p.fmtError(err))
continue
}
ep, err := NewProvider(extra.AsConfig(), user, legoCfg)
if err != nil {
errs.Add(p.fmtError(err))
continue
}
p.extraProviders = append(p.extraProviders, ep)
}
return errs.Error()
}

View File

@@ -0,0 +1,82 @@
package autocert_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/dnsproviders"
"github.com/yusing/godoxy/internal/serialization"
strutils "github.com/yusing/goutils/strings"
)
func TestSetupExtraProviders(t *testing.T) {
dnsproviders.InitProviders()
cfgYAML := `
email: test@example.com
domains: [example.com]
provider: custom
ca_dir_url: https://ca.example.com:9000/acme/acme/directory
cert_path: certs/test.crt
key_path: certs/test.key
options: {key: value}
resolvers: [8.8.8.8]
ca_certs: [ca.crt]
eab_kid: eabKid
eab_hmac: eabHmac
extra:
- cert_path: certs/extra.crt
key_path: certs/extra.key
- cert_path: certs/extra2.crt
key_path: certs/extra2.key
email: override@example.com
provider: pseudo
domains: [override.com]
ca_dir_url: https://ca2.example.com/directory
options: {opt2: val2}
resolvers: [1.1.1.1]
ca_certs: [ca2.crt]
eab_kid: eabKid2
eab_hmac: eabHmac2
`
var cfg autocert.Config
err := error(serialization.UnmarshalValidateYAML([]byte(cfgYAML), &cfg))
require.NoError(t, err)
// Test: extra[0] inherits all fields from main except CertPath and KeyPath.
merged0 := cfg.Extra[0]
require.Equal(t, "certs/extra.crt", merged0.CertPath)
require.Equal(t, "certs/extra.key", merged0.KeyPath)
// Inherited fields from main config:
require.Equal(t, "test@example.com", merged0.Email) // inherited
require.Equal(t, "custom", merged0.Provider) // inherited
require.Equal(t, []string{"example.com"}, merged0.Domains) // inherited
require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", merged0.CADirURL) // inherited
require.Equal(t, map[string]strutils.Redacted{"key": "value"}, merged0.Options) // inherited
require.Equal(t, []string{"8.8.8.8"}, merged0.Resolvers) // inherited
require.Equal(t, []string{"ca.crt"}, merged0.CACerts) // inherited
require.Equal(t, "eabKid", merged0.EABKid) // inherited
require.Equal(t, "eabHmac", merged0.EABHmac) // inherited
require.Equal(t, cfg.HTTPClient, merged0.HTTPClient) // inherited
require.Nil(t, merged0.Extra)
// Test: extra[1] overrides some fields, and inherits others.
merged1 := cfg.Extra[1]
require.Equal(t, "certs/extra2.crt", merged1.CertPath)
require.Equal(t, "certs/extra2.key", merged1.KeyPath)
// Overridden fields:
require.Equal(t, "override@example.com", merged1.Email) // overridden
require.Equal(t, "pseudo", merged1.Provider) // overridden
require.Equal(t, []string{"override.com"}, merged1.Domains) // overridden
require.Equal(t, "https://ca2.example.com/directory", merged1.CADirURL) // overridden
require.Equal(t, map[string]strutils.Redacted{"opt2": "val2"}, merged1.Options) // overridden
require.Equal(t, []string{"1.1.1.1"}, merged1.Resolvers) // overridden
require.Equal(t, []string{"ca2.crt"}, merged1.CACerts) // overridden
require.Equal(t, "eabKid2", merged1.EABKid) // overridden
require.Equal(t, "eabHmac2", merged1.EABHmac) // overridden
// Inherited field:
require.Equal(t, cfg.HTTPClient, merged1.HTTPClient) // inherited
require.Nil(t, merged1.Extra)
}

View File

@@ -0,0 +1,129 @@
package autocert
import (
"crypto/x509"
"strings"
)
type sniMatcher struct {
exact map[string]*Provider
root sniTreeNode
}
type sniTreeNode struct {
children map[string]*sniTreeNode
wildcard *Provider
}
func (m *sniMatcher) match(serverName string) *Provider {
if m == nil {
return nil
}
serverName = normalizeServerName(serverName)
if serverName == "" {
return nil
}
if m.exact != nil {
if p, ok := m.exact[serverName]; ok {
return p
}
}
return m.matchSuffixTree(serverName)
}
func (m *sniMatcher) matchSuffixTree(serverName string) *Provider {
n := &m.root
labels := strings.Split(serverName, ".")
var best *Provider
for i := len(labels) - 1; i >= 0; i-- {
if n.children == nil {
break
}
next := n.children[labels[i]]
if next == nil {
break
}
n = next
consumed := len(labels) - i
remaining := len(labels) - consumed
if remaining == 1 && n.wildcard != nil {
best = n.wildcard
}
}
return best
}
func normalizeServerName(s string) string {
s = strings.TrimSpace(s)
s = strings.TrimSuffix(s, ".")
return strings.ToLower(s)
}
func (m *sniMatcher) addProvider(p *Provider) {
if p == nil || p.tlsCert == nil || len(p.tlsCert.Certificate) == 0 {
return
}
leaf, err := x509.ParseCertificate(p.tlsCert.Certificate[0])
if err != nil {
return
}
addName := func(name string) {
name = normalizeServerName(name)
if name == "" {
return
}
if after, ok := strings.CutPrefix(name, "*."); ok {
suffix := after
if suffix == "" {
return
}
m.insertWildcardSuffix(suffix, p)
return
}
m.insertExact(name, p)
}
if leaf.Subject.CommonName != "" {
addName(leaf.Subject.CommonName)
}
for _, n := range leaf.DNSNames {
addName(n)
}
}
func (m *sniMatcher) insertExact(name string, p *Provider) {
if name == "" || p == nil {
return
}
if m.exact == nil {
m.exact = make(map[string]*Provider)
}
if _, exists := m.exact[name]; !exists {
m.exact[name] = p
}
}
func (m *sniMatcher) insertWildcardSuffix(suffix string, p *Provider) {
if suffix == "" || p == nil {
return
}
n := &m.root
labels := strings.Split(suffix, ".")
for i := len(labels) - 1; i >= 0; i-- {
if n.children == nil {
n.children = make(map[string]*sniTreeNode)
}
next := n.children[labels[i]]
if next == nil {
next = &sniTreeNode{}
n.children[labels[i]] = next
}
n = next
}
if n.wildcard == nil {
n.wildcard = p
}
}

View File

@@ -0,0 +1,104 @@
package autocert
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
)
func createTLSCert(dnsNames []string) (*tls.Certificate, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
if err != nil {
return nil, err
}
cn := ""
if len(dnsNames) > 0 {
cn = dnsNames[0]
}
template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
CommonName: cn,
},
NotBefore: time.Now().Add(-time.Minute),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: dnsNames,
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
return nil, err
}
return &tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: key,
}, nil
}
func BenchmarkSNIMatcher(b *testing.B) {
matcher := sniMatcher{}
wildcard1Cert, err := createTLSCert([]string{"*.example.com"})
if err != nil {
b.Fatal(err)
}
wildcard1 := &Provider{tlsCert: wildcard1Cert}
wildcard2Cert, err := createTLSCert([]string{"*.test.com"})
if err != nil {
b.Fatal(err)
}
wildcard2 := &Provider{tlsCert: wildcard2Cert}
wildcard3Cert, err := createTLSCert([]string{"*.foo.com"})
if err != nil {
b.Fatal(err)
}
wildcard3 := &Provider{tlsCert: wildcard3Cert}
exact1Cert, err := createTLSCert([]string{"bar.example.com"})
if err != nil {
b.Fatal(err)
}
exact1 := &Provider{tlsCert: exact1Cert}
exact2Cert, err := createTLSCert([]string{"baz.test.com"})
if err != nil {
b.Fatal(err)
}
exact2 := &Provider{tlsCert: exact2Cert}
matcher.addProvider(wildcard1)
matcher.addProvider(wildcard2)
matcher.addProvider(wildcard3)
matcher.addProvider(exact1)
matcher.addProvider(exact2)
b.Run("MatchWildcard", func(b *testing.B) {
for b.Loop() {
_ = matcher.match("sub.example.com")
}
})
b.Run("MatchExact", func(b *testing.B) {
for b.Loop() {
_ = matcher.match("bar.example.com")
}
})
}

View File

@@ -9,6 +9,6 @@ import (
type Provider interface {
Setup() error
GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error)
ScheduleRenewal(task.Parent)
ObtainCert() error
ScheduleRenewalAll(task.Parent)
ObtainCertAll() error
}

View File

@@ -13,6 +13,8 @@ var (
IsDebug = env.GetEnvBool("DEBUG", IsTest)
IsTrace = env.GetEnvBool("TRACE", false) && IsDebug
ShortLinkPrefix = env.GetEnvString("SHORTLINK_PREFIX", "go")
ProxyHTTPAddr,
ProxyHTTPHost,
ProxyHTTPPort,

View File

@@ -3,12 +3,10 @@ package config
import (
"errors"
"fmt"
"io/fs"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/common"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/notif"
@@ -62,11 +60,6 @@ func Load() error {
cfgWatcher = watcher.NewConfigFileWatcher(common.ConfigFileName)
initErr := state.InitFromFile(common.ConfigPath)
if errors.Is(initErr, fs.ErrNotExist) {
// log only
log.Warn().Msg("config file not found, using default config")
initErr = nil
}
err := errors.Join(initErr, state.StartProviders())
if err != nil {
logNotifyError("init", err)

View File

@@ -3,7 +3,11 @@ package config
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"iter"
"net/http"
"os"
@@ -17,7 +21,6 @@ import (
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/internal/acl"
"github.com/yusing/godoxy/internal/autocert"
"github.com/yusing/godoxy/internal/common"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/entrypoint"
homepage "github.com/yusing/godoxy/internal/homepage/types"
@@ -90,10 +93,13 @@ func Value() *config.Config {
}
func (state *state) InitFromFile(filename string) error {
data, err := os.ReadFile(common.ConfigPath)
data, err := os.ReadFile(filename)
if err != nil {
state.Config = config.DefaultConfig()
return err
if errors.Is(err, fs.ErrNotExist) {
state.Config = config.DefaultConfig()
} else {
return err
}
}
return state.Init(data)
}
@@ -134,6 +140,10 @@ func (state *state) EntrypointHandler() http.Handler {
return &state.entrypoint
}
func (state *state) ShortLinkMatcher() config.ShortLinkMatcher {
return state.entrypoint.ShortLinkMatcher()
}
// AutoCertProvider returns the autocert provider.
//
// If the autocert provider is not configured, it returns nil.
@@ -191,18 +201,52 @@ func (state *state) initAccessLogger() error {
}
func (state *state) initEntrypoint() error {
epCfg := state.Entrypoint
epCfg := state.Config.Entrypoint
matchDomains := state.MatchDomains
state.entrypoint.SetFindRouteDomains(matchDomains)
state.entrypoint.SetNotFoundRules(epCfg.Rules.NotFound)
if len(matchDomains) > 0 {
state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix(matchDomains[0])
}
if state.autocertProvider != nil {
if domain := getAutoCertDefaultDomain(state.autocertProvider); domain != "" {
state.entrypoint.ShortLinkMatcher().SetDefaultDomainSuffix("." + domain)
}
}
errs := gperr.NewBuilder("entrypoint error")
errs.Add(state.entrypoint.SetMiddlewares(epCfg.Middlewares))
errs.Add(state.entrypoint.SetAccessLogger(state.task, epCfg.AccessLog))
return errs.Error()
}
func getAutoCertDefaultDomain(p *autocert.Provider) string {
if p == nil {
return ""
}
cert, err := tls.LoadX509KeyPair(p.GetCertPath(), p.GetKeyPath())
if err != nil || len(cert.Certificate) == 0 {
return ""
}
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return ""
}
domain := x509Cert.Subject.CommonName
if domain == "" && len(x509Cert.DNSNames) > 0 {
domain = x509Cert.DNSNames[0]
}
domain = strings.TrimSpace(domain)
if after, ok := strings.CutPrefix(domain, "*."); ok {
domain = after
}
return strings.ToLower(domain)
}
func (state *state) initMaxMind() error {
maxmindCfg := state.Providers.MaxMind
if maxmindCfg != nil {
@@ -228,6 +272,7 @@ func (state *state) initAutoCert() error {
autocertCfg := state.AutoCert
if autocertCfg == nil {
autocertCfg = new(autocert.Config)
_ = autocertCfg.Validate()
}
user, legoCfg, err := autocertCfg.GetLegoConfig()
@@ -235,12 +280,19 @@ func (state *state) initAutoCert() error {
return err
}
state.autocertProvider = autocert.NewProvider(autocertCfg, user, legoCfg)
if err := state.autocertProvider.Setup(); err != nil {
return fmt.Errorf("autocert error: %w", err)
} else {
state.autocertProvider.ScheduleRenewal(state.task)
p, err := autocert.NewProvider(autocertCfg, user, legoCfg)
if err != nil {
return err
}
if err := p.ObtainCertIfNotExistsAll(); err != nil {
return err
}
p.ScheduleRenewalAll(state.task)
p.PrintCertExpiriesAll()
state.autocertProvider = p
return nil
}
@@ -252,7 +304,7 @@ func (state *state) initProxmox() error {
errs := gperr.NewBuilder()
for _, cfg := range proxmoxCfg {
if err := cfg.Init(); err != nil {
if err := cfg.Init(state.task.Context()); err != nil {
errs.Add(err.Subject(cfg.URL))
}
}

View File

@@ -22,6 +22,7 @@ type State interface {
Value() *Config
EntrypointHandler() http.Handler
ShortLinkMatcher() ShortLinkMatcher
AutoCertProvider() server.CertProvider
LoadOrStoreProvider(key string, value types.RouteProvider) (actual types.RouteProvider, loaded bool)
@@ -33,6 +34,12 @@ type State interface {
FlushTmpLog()
}
type ShortLinkMatcher interface {
AddRoute(alias string)
DelRoute(alias string)
ServeHTTP(w http.ResponseWriter, r *http.Request)
}
// could be nil before first call on Load
var ActiveState synk.Value[State]

View File

@@ -1,7 +1,7 @@
package dnsproviders
type (
DummyConfig struct{}
DummyConfig map[string]any
DummyProvider struct{}
)

View File

@@ -91,14 +91,14 @@ func IsBlacklisted(c *types.Container) bool {
return IsBlacklistedImage(c.Image) || isDatabase(c)
}
func UpdatePorts(c *types.Container) error {
func UpdatePorts(ctx context.Context, c *types.Container) error {
dockerClient, err := NewClient(c.DockerCfg)
if err != nil {
return err
}
defer dockerClient.Close()
inspect, err := dockerClient.ContainerInspect(context.Background(), c.ContainerID, client.ContainerInspectOptions{})
inspect, err := dockerClient.ContainerInspect(ctx, c.ContainerID, client.ContainerInspectOptions{})
if err != nil {
return err
}

View File

@@ -6,6 +6,7 @@ import (
"sync/atomic"
"github.com/rs/zerolog/log"
"github.com/yusing/godoxy/internal/common"
entrypoint "github.com/yusing/godoxy/internal/entrypoint/types"
"github.com/yusing/godoxy/internal/logging/accesslog"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
@@ -21,6 +22,7 @@ type Entrypoint struct {
notFoundHandler http.Handler
accessLogger accesslog.AccessLogger
findRouteFunc func(host string) types.HTTPRoute
shortLinkTree *ShortLinkMatcher
}
// nil-safe
@@ -34,9 +36,14 @@ func init() {
func NewEntrypoint() Entrypoint {
return Entrypoint{
findRouteFunc: findRouteAnyDomain,
shortLinkTree: newShortLinkTree(),
}
}
func (ep *Entrypoint) ShortLinkMatcher() *ShortLinkMatcher {
return ep.shortLinkTree
}
func (ep *Entrypoint) SetFindRouteDomains(domains []string) {
if len(domains) == 0 {
ep.findRouteFunc = findRouteAnyDomain
@@ -90,9 +97,12 @@ func (ep *Entrypoint) FindRoute(s string) types.HTTPRoute {
func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if ep.accessLogger != nil {
rec := accesslog.NewResponseRecorder(w)
rec := accesslog.GetResponseRecorder(w)
w = rec
defer ep.accessLogger.Log(r, rec.Response())
defer func() {
ep.accessLogger.Log(r, rec.Response())
accesslog.PutResponseRecorder(rec)
}()
}
route := ep.findRouteFunc(r.Host)
@@ -104,6 +114,8 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else {
route.ServeHTTP(w, r)
}
case ep.tryHandleShortLink(w, r):
return
case ep.notFoundHandler != nil:
ep.notFoundHandler.ServeHTTP(w, r)
default:
@@ -111,6 +123,22 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (ep *Entrypoint) tryHandleShortLink(w http.ResponseWriter, r *http.Request) (handled bool) {
host := r.Host
if before, _, ok := strings.Cut(host, ":"); ok {
host = before
}
if strings.EqualFold(host, common.ShortLinkPrefix) {
if ep.middleware != nil {
ep.middleware.ServeHTTP(ep.shortLinkTree.ServeHTTP, w, r)
} else {
ep.shortLinkTree.ServeHTTP(w, r)
}
return true
}
return false
}
func (ep *Entrypoint) serveNotFound(w http.ResponseWriter, r *http.Request) {
// Why use StatusNotFound instead of StatusBadRequest or StatusBadGateway?
// On nginx, when route for domain does not exist, it returns StatusBadGateway.

View File

@@ -0,0 +1,110 @@
package entrypoint
import (
"net/http"
"strings"
"github.com/puzpuzpuz/xsync/v4"
)
type ShortLinkMatcher struct {
defaultDomainSuffix string // e.g. ".example.com"
fqdnRoutes *xsync.Map[string, string] // "app" -> "app.example.com"
subdomainRoutes *xsync.Map[string, struct{}]
}
func newShortLinkTree() *ShortLinkMatcher {
return &ShortLinkMatcher{
fqdnRoutes: xsync.NewMap[string, string](),
subdomainRoutes: xsync.NewMap[string, struct{}](),
}
}
func (st *ShortLinkMatcher) SetDefaultDomainSuffix(suffix string) {
if !strings.HasPrefix(suffix, ".") {
suffix = "." + suffix
}
st.defaultDomainSuffix = suffix
}
func (st *ShortLinkMatcher) AddRoute(alias string) {
alias = strings.TrimSpace(alias)
if alias == "" {
return
}
if strings.Contains(alias, ".") { // FQDN alias
st.fqdnRoutes.Store(alias, alias)
key, _, _ := strings.Cut(alias, ".")
if key != "" {
if _, ok := st.subdomainRoutes.Load(key); !ok {
if _, ok := st.fqdnRoutes.Load(key); !ok {
st.fqdnRoutes.Store(key, alias)
}
}
}
return
}
// subdomain alias + defaultDomainSuffix
if st.defaultDomainSuffix == "" {
return
}
st.subdomainRoutes.Store(alias, struct{}{})
}
func (st *ShortLinkMatcher) DelRoute(alias string) {
alias = strings.TrimSpace(alias)
if alias == "" {
return
}
if strings.Contains(alias, ".") {
st.fqdnRoutes.Delete(alias)
key, _, _ := strings.Cut(alias, ".")
if key != "" {
if target, ok := st.fqdnRoutes.Load(key); ok && target == alias {
st.fqdnRoutes.Delete(key)
}
}
return
}
st.subdomainRoutes.Delete(alias)
}
func (st *ShortLinkMatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.EscapedPath()
trim := strings.TrimPrefix(path, "/")
key, rest, _ := strings.Cut(trim, "/")
if key == "" {
http.Error(w, "short link key is required", http.StatusBadRequest)
return
}
if rest != "" {
rest = "/" + rest
} else {
rest = "/"
}
targetHost := ""
if strings.Contains(key, ".") {
targetHost, _ = st.fqdnRoutes.Load(key)
} else if target, ok := st.fqdnRoutes.Load(key); ok {
targetHost = target
} else if _, ok := st.subdomainRoutes.Load(key); ok && st.defaultDomainSuffix != "" {
targetHost = key + st.defaultDomainSuffix
}
if targetHost == "" {
http.Error(w, "short link not found", http.StatusNotFound)
return
}
targetURL := "https://" + targetHost + rest
if q := r.URL.RawQuery; q != "" {
targetURL += "?" + q
}
http.Redirect(w, r, targetURL, http.StatusTemporaryRedirect)
}

View File

@@ -0,0 +1,194 @@
package entrypoint_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/yusing/godoxy/internal/common"
. "github.com/yusing/godoxy/internal/entrypoint"
)
func TestShortLinkMatcher_FQDNAlias(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.AddRoute("app.domain.com")
t.Run("exact path", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/", w.Header().Get("Location"))
})
t.Run("with path remainder", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo/bar", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/foo/bar", w.Header().Get("Location"))
})
t.Run("with query", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo?x=y&z=1", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/foo?x=y&z=1", w.Header().Get("Location"))
})
}
func TestShortLinkMatcher_SubdomainAlias(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.SetDefaultDomainSuffix(".example.com")
matcher.AddRoute("app")
t.Run("exact path", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/", w.Header().Get("Location"))
})
t.Run("with path remainder", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app/foo/bar", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/foo/bar", w.Header().Get("Location"))
})
}
func TestShortLinkMatcher_NotFound(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.SetDefaultDomainSuffix(".example.com")
matcher.AddRoute("app")
t.Run("missing key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
})
t.Run("unknown key", func(t *testing.T) {
req := httptest.NewRequest("GET", "/unknown", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
})
}
func TestShortLinkMatcher_AddDelRoute(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
matcher.SetDefaultDomainSuffix(".example.com")
matcher.AddRoute("app1")
matcher.AddRoute("app2.domain.com")
t.Run("both routes work", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app1", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app1.example.com/", w.Header().Get("Location"))
req = httptest.NewRequest("GET", "/app2.domain.com", nil)
w = httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app2.domain.com/", w.Header().Get("Location"))
})
t.Run("delete route", func(t *testing.T) {
matcher.DelRoute("app1")
req := httptest.NewRequest("GET", "/app1", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
req = httptest.NewRequest("GET", "/app2.domain.com", nil)
w = httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app2.domain.com/", w.Header().Get("Location"))
})
}
func TestShortLinkMatcher_NoDefaultDomainSuffix(t *testing.T) {
ep := NewEntrypoint()
matcher := ep.ShortLinkMatcher()
// no SetDefaultDomainSuffix called
t.Run("subdomain alias ignored", func(t *testing.T) {
matcher.AddRoute("app")
req := httptest.NewRequest("GET", "/app", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
})
t.Run("FQDN alias still works", func(t *testing.T) {
matcher.AddRoute("app.domain.com")
req := httptest.NewRequest("GET", "/app.domain.com", nil)
w := httptest.NewRecorder()
matcher.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.domain.com/", w.Header().Get("Location"))
})
}
func TestEntrypoint_ShortLinkDispatch(t *testing.T) {
ep := NewEntrypoint()
ep.ShortLinkMatcher().SetDefaultDomainSuffix(".example.com")
ep.ShortLinkMatcher().AddRoute("app")
t.Run("shortlink host", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
req.Host = common.ShortLinkPrefix
w := httptest.NewRecorder()
ep.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/", w.Header().Get("Location"))
})
t.Run("shortlink host with port", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
req.Host = common.ShortLinkPrefix + ":8080"
w := httptest.NewRecorder()
ep.ServeHTTP(w, req)
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
assert.Equal(t, "https://app.example.com/", w.Header().Get("Location"))
})
t.Run("normal host", func(t *testing.T) {
req := httptest.NewRequest("GET", "/app", nil)
req.Host = "app.example.com"
w := httptest.NewRecorder()
ep.ServeHTTP(w, req)
// Should not redirect, should try normal route lookup (which will 404)
assert.NotEqual(t, http.StatusTemporaryRedirect, w.Code)
})
}

View File

@@ -25,14 +25,14 @@ const proxmoxStateCheckInterval = 1 * time.Second
var ErrNodeNotFound = gperr.New("node not found in pool")
func NewProxmoxProvider(nodeName string, vmid int) (idlewatcher.Provider, error) {
func NewProxmoxProvider(ctx context.Context, nodeName string, vmid int) (idlewatcher.Provider, error) {
node, ok := proxmox.Nodes.Get(nodeName)
if !ok {
return nil, ErrNodeNotFound.Subject(nodeName).
Withf("available nodes: %s", proxmox.AvailableNodeNames())
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
lxcName, err := node.LXCName(ctx, vmid)

View File

@@ -259,7 +259,7 @@ func NewWatcher(parent task.Parent, r types.Route, cfg *types.IdlewatcherConfig)
p, err = provider.NewDockerProvider(cfg.Docker.DockerCfg, cfg.Docker.ContainerID)
kind = "docker"
default:
p, err = provider.NewProxmoxProvider(cfg.Proxmox.Node, cfg.Proxmox.VMID)
p, err = provider.NewProxmoxProvider(parent.Context(), cfg.Proxmox.Node, cfg.Proxmox.VMID)
kind = "proxmox"
}
targetURL := r.TargetURL()

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/http"
"sync"
)
type ResponseRecorder struct {
@@ -13,14 +14,30 @@ type ResponseRecorder struct {
resp http.Response
}
func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
return &ResponseRecorder{
w: w,
resp: http.Response{
StatusCode: http.StatusOK,
Header: w.Header(),
},
var recorderPool = sync.Pool{
New: func() any {
return &ResponseRecorder{}
},
}
func GetResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
r := recorderPool.Get().(*ResponseRecorder)
r.w = w
r.resp = http.Response{
StatusCode: http.StatusOK,
Header: w.Header(),
}
return r
}
func PutResponseRecorder(r *ResponseRecorder) {
r.w = nil
r.resp = http.Response{}
recorderPool.Put(r)
}
func NewResponseRecorder(w http.ResponseWriter) *ResponseRecorder {
return GetResponseRecorder(w)
}
func (w *ResponseRecorder) Unwrap() http.ResponseWriter {

View File

@@ -106,7 +106,7 @@ func TestReverseProxyBypass(t *testing.T) {
rp := reverseproxy.NewReverseProxy("test", url, fakeRoundTripper{})
err = PatchReverseProxy(rp, map[string]OptionsRaw{
"response": {
"bypass": "path glob(/test/*) | path /api",
"bypass": []string{"path glob(/test/*)", "path /api"},
"set_headers": map[string]string{
"Test-Header": "test-value",
},

View File

@@ -32,6 +32,9 @@ func setup() {
}
func GetStaticFile(filename string) ([]byte, bool) {
if common.IsTest {
return nil, false
}
setupOnce.Do(setup)
return fileContentMap.Load(filename)
}

View File

@@ -16,7 +16,7 @@ func NewTransport() *http.Transport {
Proxy: http.ProxyFromEnvironment,
DialContext: DefaultDialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConnsPerHost: 100,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,

View File

@@ -32,7 +32,7 @@ func (c *Config) Client() *Client {
return c.client
}
func (c *Config) Init() gperr.Error {
func (c *Config) Init(ctx context.Context) gperr.Error {
var tr *http.Transport
if c.NoTLSVerify {
// user specified
@@ -56,7 +56,7 @@ func (c *Config) Init() gperr.Error {
}
c.client = NewClient(c.URL, opts...)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
if err := c.client.UpdateClusterInfo(ctx); err != nil {

View File

@@ -6,6 +6,7 @@ import (
"path"
"path/filepath"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/logging/accesslog"
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
"github.com/yusing/godoxy/internal/net/gphttp/middleware"
@@ -124,8 +125,14 @@ func (s *FileServer) Start(parent task.Parent) gperr.Error {
}
routes.HTTP.Add(s)
if state := config.WorkingState.Load(); state != nil {
state.ShortLinkMatcher().AddRoute(s.Alias)
}
s.task.OnFinished("remove_route_from_http", func() {
routes.HTTP.Del(s)
if state := config.WorkingState.Load(); state != nil {
state.ShortLinkMatcher().DelRoute(s.Alias)
}
})
return nil
}

View File

@@ -79,7 +79,7 @@ func (p *DockerProvider) loadRoutesImpl() (route.Routes, gperr.Error) {
}
if container.IsHostNetworkMode {
err := docker.UpdatePorts(container)
err := docker.UpdatePorts(ctx, container)
if err != nil {
errs.Add(gperr.PrependSubject(container.ContainerName, err))
continue

View File

@@ -6,6 +6,7 @@ import (
"github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/agent/pkg/agentproxy"
config "github.com/yusing/godoxy/internal/config/types"
"github.com/yusing/godoxy/internal/idlewatcher"
"github.com/yusing/godoxy/internal/logging/accesslog"
gphttp "github.com/yusing/godoxy/internal/net/gphttp"
@@ -64,23 +65,25 @@ func NewReverseProxyRoute(base *Route) (*ReveseProxyRoute, gperr.Error) {
scheme := base.Scheme
retried := false
retryLock := sync.Mutex{}
rp.OnSchemeMisMatch = func() (retry bool) { // switch scheme and retry
retryLock.Lock()
defer retryLock.Unlock()
if scheme == route.SchemeHTTP || scheme == route.SchemeHTTPS {
rp.OnSchemeMisMatch = func() (retry bool) { // switch scheme and retry
retryLock.Lock()
defer retryLock.Unlock()
if retried {
return false
if retried {
return false
}
retried = true
if scheme == route.SchemeHTTP {
rp.TargetURL.Scheme = "https"
} else {
rp.TargetURL.Scheme = "http"
}
rp.Info().Msgf("scheme mismatch detected, retrying with %s", rp.TargetURL.Scheme)
return true
}
retried = true
if scheme == route.SchemeHTTP {
rp.TargetURL.Scheme = "https"
} else {
rp.TargetURL.Scheme = "http"
}
rp.Info().Msgf("scheme mismatch detected, retrying with %s", rp.TargetURL.Scheme)
return true
}
if len(base.Middlewares) > 0 {
@@ -164,8 +167,14 @@ func (r *ReveseProxyRoute) Start(parent task.Parent) gperr.Error {
r.addToLoadBalancer(parent)
} else {
routes.HTTP.Add(r)
r.task.OnCancel("remove_route_from_http", func() {
if state := config.WorkingState.Load(); state != nil {
state.ShortLinkMatcher().AddRoute(r.Alias)
}
r.task.OnCancel("remove_route", func() {
routes.HTTP.Del(r)
if state := config.WorkingState.Load(); state != nil {
state.ShortLinkMatcher().DelRoute(r.Alias)
}
})
}
return nil
@@ -206,8 +215,14 @@ func (r *ReveseProxyRoute) addToLoadBalancer(parent task.Parent) {
}
linked.SetHealthMonitor(lb)
routes.HTTP.AddKey(cfg.Link, linked)
if state := config.WorkingState.Load(); state != nil {
state.ShortLinkMatcher().AddRoute(cfg.Link)
}
r.task.OnFinished("remove_loadbalancer_route", func() {
routes.HTTP.DelKey(cfg.Link)
if state := config.WorkingState.Load(); state != nil {
state.ShortLinkMatcher().DelRoute(cfg.Link)
}
})
lbLock.Unlock()
}

View File

@@ -43,7 +43,7 @@ type (
_ utils.NoCopy
Alias string `json:"alias"`
Scheme route.Scheme `json:"scheme,omitempty" swaggertype:"string" enums:"http,https,tcp,udp,fileserver"`
Scheme route.Scheme `json:"scheme,omitempty" swaggertype:"string" enums:"http,https,h2c,tcp,udp,fileserver"`
Host string `json:"host,omitempty"`
Port route.Port `json:"port"`
@@ -271,7 +271,7 @@ func (r *Route) validate() gperr.Error {
r.ProxyURL = gperr.Collect(&errs, nettypes.ParseURL, "file://"+r.Root)
r.Host = ""
r.Port.Proxy = 0
case route.SchemeHTTP, route.SchemeHTTPS:
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
if r.Port.Listening != 0 {
errs.Addf("unexpected listening port for %s scheme", r.Scheme)
}
@@ -294,7 +294,7 @@ func (r *Route) validate() gperr.Error {
switch r.Scheme {
case route.SchemeFileServer:
impl, err = NewFileServer(r)
case route.SchemeHTTP, route.SchemeHTTPS:
case route.SchemeHTTP, route.SchemeHTTPS, route.SchemeH2C:
impl, err = NewReverseProxyRoute(r)
case route.SchemeTCP, route.SchemeUDP:
impl, err = NewStreamRoute(r)
@@ -788,6 +788,15 @@ func (r *Route) Finalize() {
}
r.Port.Listening, r.Port.Proxy = lp, pp
workingState := config.WorkingState.Load()
if workingState == nil {
if common.IsTest { // in tests, working state might be nil
return
}
panic("bug: working state is nil")
}
r.HealthCheck.ApplyDefaults(config.WorkingState.Load().Value().Defaults.HealthCheck)
}

View File

@@ -30,7 +30,7 @@ func (r *RouteContext) Value(key any) any {
func WithRouteContext(r *http.Request, route types.HTTPRoute) *http.Request {
// we don't want to copy the request object every fucking requests
// return r.WithContext(context.WithValue(r.Context(), routeContextKey, route))
ctxFieldPtr := (*context.Context)(unsafe.Pointer(uintptr(unsafe.Pointer(r)) + ctxFieldOffset))
ctxFieldPtr := (*context.Context)(unsafe.Add(unsafe.Pointer(r), ctxFieldOffset))
*ctxFieldPtr = &RouteContext{
Context: r.Context(),
Route: route,

View File

@@ -17,6 +17,8 @@ type HealthInfoWithoutDetail struct {
Latency time.Duration `json:"latency" swaggertype:"number"` // latency in microseconds
} // @name HealthInfoWithoutDetail
type HealthMap = map[string]types.HealthStatusString // @name HealthMap
// GetHealthInfo returns a map of route name to health info.
//
// The health info is for all routes, including excluded routes.
@@ -39,6 +41,14 @@ func GetHealthInfoWithoutDetail() map[string]HealthInfoWithoutDetail {
return healthMap
}
func GetHealthInfoSimple() map[string]types.HealthStatus {
healthMap := make(map[string]types.HealthStatus, NumAllRoutes())
for r := range IterAll {
healthMap[r.Name()] = getHealthInfoSimple(r)
}
return healthMap
}
func getHealthInfo(r types.Route) HealthInfo {
mon := r.HealthMonitor()
if mon == nil {
@@ -73,6 +83,14 @@ func getHealthInfoWithoutDetail(r types.Route) HealthInfoWithoutDetail {
}
}
func getHealthInfoSimple(r types.Route) types.HealthStatus {
mon := r.HealthMonitor()
if mon == nil {
return types.StatusUnknown
}
return mon.Status()
}
// ByProvider returns a map of provider name to routes.
//
// The routes are all routes, including excluded routes.

View File

@@ -270,7 +270,7 @@ func TestLogCommand_ConditionalLogging(t *testing.T) {
errorContent, err := os.ReadFile(errorFile.Name())
require.NoError(t, err)
errorLines := strings.Split(strings.TrimSpace(string(errorContent)), "\n")
assert.Len(t, errorLines, 2)
require.Len(t, errorLines, 2)
assert.Equal(t, "ERROR: GET /notfound 404", errorLines[0])
assert.Equal(t, "ERROR: POST /error 500", errorLines[1])
}
@@ -368,7 +368,7 @@ func TestLogCommand_FilePermissions(t *testing.T) {
logContent := strings.TrimSpace(string(content))
lines := strings.Split(logContent, "\n")
assert.Len(t, lines, 2)
require.Len(t, lines, 2)
assert.Equal(t, "GET 200", lines[0])
assert.Equal(t, "POST 200", lines[1])
}

View File

@@ -14,16 +14,18 @@ var ErrInvalidScheme = gperr.New("invalid scheme")
const (
SchemeHTTP Scheme = 1 << iota
SchemeHTTPS
SchemeH2C
SchemeTCP
SchemeUDP
SchemeFileServer
SchemeNone Scheme = 0
schemeReverseProxy = SchemeHTTP | SchemeHTTPS
schemeReverseProxy = SchemeHTTP | SchemeHTTPS | SchemeH2C
schemeStream = SchemeTCP | SchemeUDP
schemeStrHTTP = "http"
schemeStrHTTPS = "https"
schemeStrH2C = "h2c"
schemeStrTCP = "tcp"
schemeStrUDP = "udp"
schemeStrFileServer = "fileserver"
@@ -36,6 +38,8 @@ func (s Scheme) String() string {
return schemeStrHTTP
case SchemeHTTPS:
return schemeStrHTTPS
case SchemeH2C:
return schemeStrH2C
case SchemeTCP:
return schemeStrTCP
case SchemeUDP:
@@ -66,6 +70,8 @@ func (s *Scheme) Parse(v string) error {
*s = SchemeHTTP
case schemeStrHTTPS:
*s = SchemeHTTPS
case schemeStrH2C:
*s = SchemeH2C
case schemeStrTCP:
*s = SchemeTCP
case schemeStrUDP:

View File

@@ -20,7 +20,7 @@ type DockerProviderConfig struct {
} // @name DockerProviderConfig
type DockerProviderConfigDetailed struct {
Scheme string `json:"scheme,omitempty" validate:"required,oneof=http https tls"`
Scheme string `json:"scheme,omitempty" validate:"required,oneof=http https tcp tls unix ssh"`
Host string `json:"host,omitempty" validate:"required,hostname|ip"`
Port int `json:"port,omitempty" validate:"required,min=1,max=65535"`
TLS *DockerTLSConfig `json:"tls" validate:"omitempty"`
@@ -48,12 +48,14 @@ func (cfg *DockerProviderConfig) Parse(value string) error {
}
switch u.Scheme {
case "http", "https", "tls":
case "http", "https", "tcp", "tls":
cfg.URL = u.String()
case "unix", "ssh":
cfg.URL = value
default:
return fmt.Errorf("invalid scheme: %s", u.Scheme)
}
cfg.URL = u.String()
return nil
}

View File

@@ -1,7 +1,6 @@
package types
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
@@ -27,7 +26,7 @@ test:
ca_file: /etc/ssl/ca.crt
cert_file: /etc/ssl/cert.crt
key_file: /etc/ssl/key.crt`), &cfg)
assert.Error(t, err, os.ErrNotExist)
assert.NoError(t, err)
assert.Equal(t, &DockerProviderConfig{URL: "http://localhost:2375", TLS: &DockerTLSConfig{CAFile: "/etc/ssl/ca.crt", CertFile: "/etc/ssl/cert.crt", KeyFile: "/etc/ssl/key.crt"}}, cfg["test"])
})
}
@@ -38,7 +37,12 @@ func TestDockerProviderConfigValidation(t *testing.T) {
yamlStr string
wantErr bool
}{
{name: "valid url", yamlStr: "test: http://localhost:2375", wantErr: false},
{name: "valid url (http)", yamlStr: "test: http://localhost:2375", wantErr: false},
{name: "valid url (https)", yamlStr: "test: https://localhost:2375", wantErr: false},
{name: "valid url (tcp)", yamlStr: "test: tcp://localhost:2375", wantErr: false},
{name: "valid url (tls)", yamlStr: "test: tls://localhost:2375", wantErr: false},
{name: "valid url (unix)", yamlStr: "test: unix:///var/run/docker.sock", wantErr: false},
{name: "valid url (ssh)", yamlStr: "test: ssh://localhost:2375", wantErr: false},
{name: "invalid url", yamlStr: "test: ftp://localhost/2375", wantErr: true},
{name: "valid scheme", yamlStr: `
test:

View File

@@ -8,12 +8,12 @@ import (
"time"
"github.com/bytedance/sonic"
strutils "github.com/yusing/goutils/strings"
"github.com/yusing/goutils/task"
)
type (
HealthStatus uint8
HealthStatus uint8 // @name HealthStatus
HealthStatusString string // @name HealthStatusString
HealthCheckResult struct {
Healthy bool `json:"healthy"`
@@ -45,20 +45,16 @@ type (
HealthChecker
}
HealthJSON struct {
Name string `json:"name"`
Config *HealthCheckConfig `json:"config"`
Started int64 `json:"started"`
StartedStr string `json:"startedStr"`
Status string `json:"status"`
Uptime float64 `json:"uptime"`
UptimeStr string `json:"uptimeStr"`
Latency float64 `json:"latency"`
LatencyStr string `json:"latencyStr"`
LastSeen int64 `json:"lastSeen"`
LastSeenStr string `json:"lastSeenStr"`
Detail string `json:"detail"`
URL string `json:"url"`
Extra *HealthExtra `json:"extra,omitempty" extensions:"x-nullable"`
Name string `json:"name"`
Config *HealthCheckConfig `json:"config"`
Started int64 `json:"started"` // unix timestamp in seconds
Status HealthStatusString `json:"status"`
Uptime float64 `json:"uptime"` // uptime in seconds
Latency int64 `json:"latency"` // latency in milliseconds
LastSeen int64 `json:"lastSeen"` // unix timestamp in seconds
Detail string `json:"detail"`
URL string `json:"url"`
Extra *HealthExtra `json:"extra,omitempty" extensions:"x-nullable"`
} // @name HealthJSON
HealthJSONRepr struct {
@@ -88,12 +84,12 @@ const (
StatusUnhealthy
StatusError
StatusUnknownStr = "unknown"
StatusHealthyStr = "healthy"
StatusNappingStr = "napping"
StatusStartingStr = "starting"
StatusUnhealthyStr = "unhealthy"
StatusErrorStr = "error"
StatusUnknownStr HealthStatusString = "unknown"
StatusHealthyStr HealthStatusString = "healthy"
StatusNappingStr HealthStatusString = "napping"
StatusStartingStr HealthStatusString = "starting"
StatusUnhealthyStr HealthStatusString = "unhealthy"
StatusErrorStr HealthStatusString = "error"
NumStatuses int = iota - 1
@@ -102,15 +98,15 @@ const (
)
var (
StatusHealthyStr2 = strconv.Itoa(int(StatusHealthy))
StatusNappingStr2 = strconv.Itoa(int(StatusNapping))
StatusStartingStr2 = strconv.Itoa(int(StatusStarting))
StatusUnhealthyStr2 = strconv.Itoa(int(StatusUnhealthy))
StatusErrorStr2 = strconv.Itoa(int(StatusError))
StatusHealthyStr2 HealthStatusString = HealthStatusString(strconv.Itoa(int(StatusHealthy)))
StatusNappingStr2 HealthStatusString = HealthStatusString(strconv.Itoa(int(StatusNapping)))
StatusStartingStr2 HealthStatusString = HealthStatusString(strconv.Itoa(int(StatusStarting)))
StatusUnhealthyStr2 HealthStatusString = HealthStatusString(strconv.Itoa(int(StatusUnhealthy)))
StatusErrorStr2 HealthStatusString = HealthStatusString(strconv.Itoa(int(StatusError)))
)
func NewHealthStatusFromString(s string) HealthStatus {
switch s {
switch HealthStatusString(s) {
case StatusHealthyStr, StatusHealthyStr2:
return StatusHealthy
case StatusUnhealthyStr, StatusUnhealthyStr2:
@@ -126,7 +122,7 @@ func NewHealthStatusFromString(s string) HealthStatus {
}
}
func (s HealthStatus) String() string {
func (s HealthStatus) StatusString() HealthStatusString {
switch s {
case StatusHealthy:
return StatusHealthyStr
@@ -143,6 +139,11 @@ func (s HealthStatus) String() string {
}
}
// String implements fmt.Stringer.
func (s HealthStatus) String() string {
return string(s.StatusString())
}
func (s HealthStatus) Good() bool {
return s&HealthyMask != 0
}
@@ -178,19 +179,15 @@ func (jsonRepr *HealthJSONRepr) MarshalJSON() ([]byte, error) {
url = ""
}
return sonic.Marshal(HealthJSON{
Name: jsonRepr.Name,
Config: jsonRepr.Config,
Started: jsonRepr.Started.Unix(),
StartedStr: strutils.FormatTime(jsonRepr.Started),
Status: jsonRepr.Status.String(),
Uptime: jsonRepr.Uptime.Seconds(),
UptimeStr: strutils.FormatDuration(jsonRepr.Uptime),
Latency: jsonRepr.Latency.Seconds(),
LatencyStr: strconv.Itoa(int(jsonRepr.Latency.Milliseconds())) + " ms",
LastSeen: jsonRepr.LastSeen.Unix(),
LastSeenStr: strutils.FormatLastSeen(jsonRepr.LastSeen),
Detail: jsonRepr.Detail,
URL: url,
Extra: jsonRepr.Extra,
Name: jsonRepr.Name,
Config: jsonRepr.Config,
Started: jsonRepr.Started.Unix(),
Status: HealthStatusString(jsonRepr.Status.String()),
Uptime: jsonRepr.Uptime.Seconds(),
Latency: jsonRepr.Latency.Milliseconds(),
LastSeen: jsonRepr.LastSeen.Unix(),
Detail: jsonRepr.Detail,
URL: url,
Extra: jsonRepr.Extra,
})
}

View File

@@ -1,243 +0,0 @@
package utils
import (
"reflect"
"unsafe"
)
// DeepEqual reports whether x and y are deeply equal.
// It supports numerics, strings, maps, slices, arrays, and structs (exported fields only).
// It's optimized for performance by avoiding reflection for common types and
// adaptively choosing between BFS and DFS traversal strategies.
func DeepEqual(x, y any) bool {
if x == nil || y == nil {
return x == y
}
v1 := reflect.ValueOf(x)
v2 := reflect.ValueOf(y)
if v1.Type() != v2.Type() {
return false
}
return deepEqual(v1, v2, make(map[visit]bool), 0)
}
// visit represents a visit to a pair of values during comparison
type visit struct {
a1, a2 unsafe.Pointer
typ reflect.Type
}
// deepEqual performs the actual deep comparison with cycle detection
func deepEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
if v1.Type() != v2.Type() {
return false
}
// Handle cycle detection for pointer-like types
if v1.CanAddr() && v2.CanAddr() {
addr1 := unsafe.Pointer(v1.UnsafeAddr())
addr2 := unsafe.Pointer(v2.UnsafeAddr())
typ := v1.Type()
v := visit{addr1, addr2, typ}
if visited[v] {
return true // already visiting, assume equal
}
visited[v] = true
defer delete(visited, v)
}
switch v1.Kind() {
case reflect.Bool:
return v1.Bool() == v2.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v1.Uint() == v2.Uint()
case reflect.Float32, reflect.Float64:
return floatEqual(v1.Float(), v2.Float())
case reflect.Complex64, reflect.Complex128:
c1, c2 := v1.Complex(), v2.Complex()
return floatEqual(real(c1), real(c2)) && floatEqual(imag(c1), imag(c2))
case reflect.String:
return v1.String() == v2.String()
case reflect.Array:
return deepEqualArray(v1, v2, visited, depth)
case reflect.Slice:
return deepEqualSlice(v1, v2, visited, depth)
case reflect.Map:
return deepEqualMap(v1, v2, visited, depth)
case reflect.Struct:
return deepEqualStruct(v1, v2, visited, depth)
case reflect.Ptr:
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() && v2.IsNil()
}
return deepEqual(v1.Elem(), v2.Elem(), visited, depth+1)
case reflect.Interface:
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() && v2.IsNil()
}
return deepEqual(v1.Elem(), v2.Elem(), visited, depth+1)
default:
// For unsupported types (func, chan, etc.), fall back to basic equality
return v1.Interface() == v2.Interface()
}
}
// floatEqual handles NaN cases properly
func floatEqual(f1, f2 float64) bool {
return f1 == f2 || (f1 != f1 && f2 != f2) // NaN == NaN
}
// deepEqualArray compares arrays using DFS (since arrays have fixed size)
func deepEqualArray(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
for i := range v1.Len() {
if !deepEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
return false
}
}
return true
}
// deepEqualSlice compares slices, choosing strategy based on size and depth
func deepEqualSlice(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if v1.IsNil() != v2.IsNil() {
return false
}
if v1.Len() != v2.Len() {
return false
}
if v1.IsNil() {
return true
}
// Use BFS for large slices at shallow depth to improve cache locality
// Use DFS for small slices or deep nesting to reduce memory overhead
if shouldUseBFS(v1.Len(), depth) {
return deepEqualSliceBFS(v1, v2, visited, depth)
}
return deepEqualSliceDFS(v1, v2, visited, depth)
}
// deepEqualSliceDFS uses depth-first traversal
func deepEqualSliceDFS(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
for i := range v1.Len() {
if !deepEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
return false
}
}
return true
}
// deepEqualSliceBFS uses breadth-first traversal for better cache locality
func deepEqualSliceBFS(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
length := v1.Len()
// First, check all direct elements
for i := range length {
elem1, elem2 := v1.Index(i), v2.Index(i)
// For simple types, compare directly
if isSimpleType(elem1.Kind()) {
if !deepEqual(elem1, elem2, visited, depth+1) {
return false
}
}
}
// Then, recursively check complex elements
for i := range length {
elem1, elem2 := v1.Index(i), v2.Index(i)
if !isSimpleType(elem1.Kind()) {
if !deepEqual(elem1, elem2, visited, depth+1) {
return false
}
}
}
return true
}
// deepEqualMap compares maps
func deepEqualMap(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if v1.IsNil() != v2.IsNil() {
return false
}
if v1.Len() != v2.Len() {
return false
}
if v1.IsNil() {
return true
}
// Check all keys and values
for _, key := range v1.MapKeys() {
val1 := v1.MapIndex(key)
val2 := v2.MapIndex(key)
if !val2.IsValid() {
return false // key doesn't exist in v2
}
if !deepEqual(val1, val2, visited, depth+1) {
return false
}
}
return true
}
// deepEqualStruct compares structs (exported fields only)
func deepEqualStruct(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
typ := v1.Type()
for i := range typ.NumField() {
field := typ.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
if !deepEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
return false
}
}
return true
}
// shouldUseBFS determines whether to use BFS or DFS based on slice size and depth
func shouldUseBFS(length, depth int) bool {
// Use BFS for large slices at shallow depth (better cache locality)
// Use DFS for small slices or deep nesting (lower memory overhead)
return length > 100 && depth < 3
}
// isSimpleType checks if a type can be compared without deep recursion
func isSimpleType(kind reflect.Kind) bool {
if kind >= reflect.Bool && kind <= reflect.Complex128 {
return true
}
return kind == reflect.String
}

View File

@@ -1,14 +1,18 @@
package monitor
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"net/url"
"time"
"github.com/valyala/fasthttp"
"github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/version"
"golang.org/x/net/http2"
)
type HTTPHealthMonitor struct {
@@ -17,8 +21,6 @@ type HTTPHealthMonitor struct {
}
var pinger = &fasthttp.Client{
ReadTimeout: 5 * time.Second,
WriteTimeout: 3 * time.Second,
MaxConnDuration: 0,
DisableHeaderNamesNormalizing: true,
DisablePathNormalizing: true,
@@ -42,41 +44,30 @@ func NewHTTPHealthMonitor(url *url.URL, config types.HealthCheckConfig) *HTTPHea
var userAgent = "GoDoxy/" + version.Get().String()
func (mon *HTTPHealthMonitor) CheckHealth() (types.HealthCheckResult, error) {
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
func setCommonHeaders(setHeader func(key, value string)) {
setHeader("User-Agent", userAgent)
setHeader("Accept", "text/plain,text/html,*/*;q=0.8")
setHeader("Accept-Encoding", "identity")
setHeader("Cache-Control", "no-cache")
setHeader("Pragma", "no-cache")
}
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(mon.url.Load().JoinPath(mon.config.Path).String())
req.Header.SetMethod(mon.method)
req.Header.Set("User-Agent", userAgent)
req.Header.Set("Accept", "text/plain,text/html,*/*;q=0.8")
req.Header.Set("Accept-Encoding", "identity")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache")
req.SetConnectionClose()
start := time.Now()
respErr := pinger.DoTimeout(req, resp, mon.config.Timeout)
lat := time.Since(start)
if respErr != nil {
// treat TLS error as healthy
func processHealthResponse(lat time.Duration, err error, getStatusCode func() int) (types.HealthCheckResult, error) {
if err != nil {
var tlsErr *tls.CertificateVerificationError
if ok := errors.As(respErr, &tlsErr); !ok {
if ok := errors.As(err, &tlsErr); !ok {
return types.HealthCheckResult{
Latency: lat,
Detail: respErr.Error(),
Detail: err.Error(),
}, nil
}
}
if status := resp.StatusCode(); status >= 500 && status < 600 {
statusCode := getStatusCode()
if statusCode >= 500 && statusCode < 600 {
return types.HealthCheckResult{
Latency: lat,
Detail: fasthttp.StatusMessage(resp.StatusCode()),
Detail: http.StatusText(statusCode),
}, nil
}
@@ -85,3 +76,73 @@ func (mon *HTTPHealthMonitor) CheckHealth() (types.HealthCheckResult, error) {
Healthy: true,
}, nil
}
var h2cClient = &http.Client{
Transport: &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
},
},
}
func (mon *HTTPHealthMonitor) CheckHealth() (types.HealthCheckResult, error) {
if mon.url.Load().Scheme == "h2c" {
return mon.CheckHealthH2C()
}
return mon.CheckHealthHTTP()
}
func (mon *HTTPHealthMonitor) CheckHealthHTTP() (types.HealthCheckResult, error) {
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(mon.url.Load().JoinPath(mon.config.Path).String())
req.Header.SetMethod(mon.method)
setCommonHeaders(req.Header.Set)
req.SetConnectionClose()
start := time.Now()
respErr := pinger.DoTimeout(req, resp, mon.config.Timeout)
lat := time.Since(start)
return processHealthResponse(lat, respErr, resp.StatusCode)
}
func (mon *HTTPHealthMonitor) CheckHealthH2C() (types.HealthCheckResult, error) {
u := mon.url.Load()
u = u.JoinPath(mon.config.Path) // JoinPath returns a copy of the URL with the path joined
u.Scheme = "http"
ctx, cancel := mon.ContextWithTimeout("h2c health check timed out")
defer cancel()
var req *http.Request
var err error
if mon.method == fasthttp.MethodGet {
req, err = http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
} else {
req, err = http.NewRequestWithContext(ctx, http.MethodHead, u.String(), nil)
}
if err != nil {
return types.HealthCheckResult{
Detail: err.Error(),
}, nil
}
setCommonHeaders(req.Header.Set)
start := time.Now()
resp, err := h2cClient.Do(req)
lat := time.Since(start)
if resp != nil {
defer resp.Body.Close()
}
return processHealthResponse(lat, err, func() int { return resp.StatusCode })
}

View File

@@ -295,7 +295,7 @@ func (mon *monitor) notifyServiceUp(logger *zerolog.Logger, result *types.Health
}
func (mon *monitor) notifyServiceDown(logger *zerolog.Logger, result *types.HealthCheckResult) {
logger.Warn().Msg("service went down")
logger.Warn().Str("detail", result.Detail).Msg("service went down")
extras := mon.buildNotificationExtras(result)
extras.Add("Last Seen", strutils.FormatLastSeen(GetLastSeen(mon.service)))
mon.notifyFunc(&notif.LogMessage{

199
scripts/benchmark.sh Normal file
View File

@@ -0,0 +1,199 @@
#!/bin/bash
# Benchmark script to compare GoDoxy, Traefik, Caddy, and Nginx
# Uses wrk for HTTP load testing
set -e
# Configuration
HOST="bench.domain.com"
DURATION="${DURATION:-10s}"
THREADS="${THREADS:-4}"
CONNECTIONS="${CONNECTIONS:-100}"
TARGET="${TARGET-}"
# Color functions for output
red() { echo -e "\033[0;31m$*\033[0m"; }
green() { echo -e "\033[0;32m$*\033[0m"; }
yellow() { echo -e "\033[1;33m$*\033[0m"; }
blue() { echo -e "\033[0;34m$*\033[0m"; }
# Check if wrk is installed
if ! command -v wrk &>/dev/null; then
red "Error: wrk is not installed"
echo "Please install wrk:"
echo " Ubuntu/Debian: sudo apt-get install wrk"
echo " macOS: brew install wrk"
echo " Or build from source: https://github.com/wg/wrk"
exit 1
fi
if ! command -v h2load &>/dev/null; then
red "Error: h2load is not installed"
echo "Please install h2load (nghttp2-client):"
echo " Ubuntu/Debian: sudo apt-get install nghttp2-client"
echo " macOS: brew install nghttp2"
exit 1
fi
OUTFILE="/tmp/reverse_proxy_benchmark_$(date +%Y%m%d_%H%M%S).log"
: >"$OUTFILE"
exec > >(tee -a "$OUTFILE") 2>&1
blue "========================================"
blue "Reverse Proxy Benchmark Comparison"
blue "========================================"
echo ""
echo "Target: $HOST"
echo "Duration: $DURATION"
echo "Threads: $THREADS"
echo "Connections: $CONNECTIONS"
if [ -n "$TARGET" ]; then
echo "Filter: $TARGET"
fi
echo ""
# Define services to test
declare -A services=(
["GoDoxy"]="http://127.0.0.1:8080"
["Traefik"]="http://127.0.0.1:8081"
["Caddy"]="http://127.0.0.1:8082"
["Nginx"]="http://127.0.0.1:8083"
)
# Array to store connection errors
declare -a connection_errors=()
# Function to test connection before benchmarking
test_connection() {
local name=$1
local url=$2
yellow "Testing connection to $name..."
# Test HTTP/1.1
local res1=$(curl -sS -w "\n%{http_code}" --http1.1 -H "Host: $HOST" --max-time 5 "$url")
local body1=$(echo "$res1" | head -n -1)
local status1=$(echo "$res1" | tail -n 1)
# Test HTTP/2
local res2=$(curl -sS -w "\n%{http_code}" --http2-prior-knowledge -H "Host: $HOST" --max-time 5 "$url")
local body2=$(echo "$res2" | head -n -1)
local status2=$(echo "$res2" | tail -n 1)
local failed=false
if [ "$status1" != "200" ] || [ ${#body1} -ne 4096 ]; then
red "$name failed HTTP/1.1 connection test (Status: $status1, Body length: ${#body1})"
failed=true
fi
if [ "$status2" != "200" ] || [ ${#body2} -ne 4096 ]; then
red "$name failed HTTP/2 connection test (Status: $status2, Body length: ${#body2})"
failed=true
fi
if [ "$failed" = true ]; then
connection_errors+=("$name failed connection test (URL: $url)")
return 1
else
green "$name is reachable (HTTP/1.1 & HTTP/2)"
return 0
fi
}
blue "========================================"
blue "Connection Tests"
blue "========================================"
echo ""
# Run connection tests for all services
for name in "${!services[@]}"; do
if [ -z "$TARGET" ] || [ "${name,,}" = "${TARGET,,}" ]; then
test_connection "$name" "${services[$name]}"
fi
done
echo ""
blue "========================================"
# Exit if any connection test failed
if [ ${#connection_errors[@]} -gt 0 ]; then
echo ""
red "Connection test failed for the following services:"
for error in "${connection_errors[@]}"; do
red " - $error"
done
echo ""
red "Please ensure all services are running before benchmarking"
exit 1
fi
echo ""
green "All services are reachable. Starting benchmarks..."
echo ""
blue "========================================"
echo ""
restart_bench() {
local name=$1
echo ""
yellow "Restarting bench service before benchmarking $name HTTP/1.1..."
docker compose -f dev.compose.yml up -d --force-recreate bench >/dev/null 2>&1
sleep 1
}
# Function to run benchmark
run_benchmark() {
local name=$1
local url=$2
local h2_duration="${DURATION%s}"
restart_bench "$name"
yellow "Testing $name..."
echo "========================================"
echo "$name"
echo "URL: $url"
echo "========================================"
echo ""
echo "[HTTP/1.1] wrk"
wrk -t"$THREADS" -c"$CONNECTIONS" -d"$DURATION" \
-H "Host: $HOST" \
"$url"
restart_bench "$name"
echo ""
echo "[HTTP/2] h2load"
h2load -t"$THREADS" -c"$CONNECTIONS" --duration="$h2_duration" \
-H "Host: $HOST" \
-H ":authority: $HOST" \
"$url" | grep -vE "^(starting benchmark...|spawning thread|progress: |Warm-up |Main benchmark duration|Stopped all clients|Process Request Failure)"
echo ""
green "$name benchmark completed"
blue "----------------------------------------"
echo ""
}
# Run benchmarks for each service
for name in "${!services[@]}"; do
if [ -z "$TARGET" ] || [ "${name,,}" = "${TARGET,,}" ]; then
run_benchmark "$name" "${services[$name]}"
fi
done
blue "========================================"
blue "Benchmark Summary"
blue "========================================"
echo ""
echo "All benchmark output saved to: $OUTFILE"
echo ""
echo "Key metrics to compare:"
echo " - Requests/sec (throughput)"
echo " - Latency (mean, stdev)"
echo " - Transfer/sec"
echo ""
green "All benchmarks completed!"