mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-14 05:00:19 +02:00
Add root-level inbound_mtls_profiles combining optional system CAs with PEM CA files, and entrypoint.inbound_mtls_profile to require client certificates on every HTTPS connection. Route-level inbound_mtls_profile is allowed only without a global profile; per-handshake TLS picks ClientCAs from SNI, and requests fail with 421 when Host and SNI would select different mTLS routes. Compile pools at init (SetInboundMTLSProfiles from state.initEntrypoint) and reject unknown profile refs or mixed global-plus-route configuration. Extend config.example.yml and package READMEs; add entrypoint and config tests for TLS mutation, handshakes, and validation.
402 lines
14 KiB
Go
402 lines
14 KiB
Go
package entrypoint
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/stretchr/testify/require"
|
|
agentcert "github.com/yusing/godoxy/agent/pkg/agent"
|
|
"github.com/yusing/godoxy/internal/agentpool"
|
|
autocert "github.com/yusing/godoxy/internal/autocert/types"
|
|
"github.com/yusing/godoxy/internal/common"
|
|
"github.com/yusing/godoxy/internal/homepage"
|
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
|
"github.com/yusing/godoxy/internal/types"
|
|
"github.com/yusing/goutils/pool"
|
|
"github.com/yusing/goutils/task"
|
|
)
|
|
|
|
type fakeHTTPRoute struct {
|
|
key string
|
|
name string
|
|
inboundMTLSProfile string
|
|
listenURL *nettypes.URL
|
|
task *task.Task
|
|
}
|
|
|
|
func newFakeHTTPRoute(t *testing.T, alias, profile string) *fakeHTTPRoute {
|
|
return newFakeHTTPRouteAt(t, alias, profile, "https://:1000")
|
|
}
|
|
|
|
func newFakeHTTPRouteAt(t *testing.T, alias, profile, listenURL string) *fakeHTTPRoute {
|
|
t.Helper()
|
|
|
|
return &fakeHTTPRoute{
|
|
key: alias,
|
|
name: alias,
|
|
inboundMTLSProfile: profile,
|
|
listenURL: nettypes.MustParseURL(listenURL),
|
|
task: task.GetTestTask(t),
|
|
}
|
|
}
|
|
|
|
func (r *fakeHTTPRoute) Key() string { return r.key }
|
|
func (r *fakeHTTPRoute) Name() string { return r.name }
|
|
func (r *fakeHTTPRoute) Start(task.Parent) error { return nil }
|
|
func (r *fakeHTTPRoute) Task() *task.Task { return r.task }
|
|
func (r *fakeHTTPRoute) Finish(any) {}
|
|
func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {}
|
|
func (r *fakeHTTPRoute) ProviderName() string { return "" }
|
|
func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }
|
|
func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL }
|
|
func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }
|
|
func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }
|
|
func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {}
|
|
func (r *fakeHTTPRoute) References() []string { return nil }
|
|
func (r *fakeHTTPRoute) ShouldExclude() bool { return false }
|
|
func (r *fakeHTTPRoute) Started() <-chan struct{} { return nil }
|
|
func (r *fakeHTTPRoute) IdlewatcherConfig() *types.IdlewatcherConfig { return nil }
|
|
func (r *fakeHTTPRoute) HealthCheckConfig() types.HealthCheckConfig { return types.HealthCheckConfig{} }
|
|
func (r *fakeHTTPRoute) LoadBalanceConfig() *types.LoadBalancerConfig {
|
|
return nil
|
|
}
|
|
func (r *fakeHTTPRoute) HomepageItem() homepage.Item { return homepage.Item{} }
|
|
func (r *fakeHTTPRoute) DisplayName() string { return r.name }
|
|
func (r *fakeHTTPRoute) ContainerInfo() *types.Container {
|
|
return nil
|
|
}
|
|
func (r *fakeHTTPRoute) GetAgent() *agentpool.Agent { return nil }
|
|
func (r *fakeHTTPRoute) IsDocker() bool { return false }
|
|
func (r *fakeHTTPRoute) IsAgent() bool { return false }
|
|
func (r *fakeHTTPRoute) UseLoadBalance() bool { return false }
|
|
func (r *fakeHTTPRoute) UseIdleWatcher() bool { return false }
|
|
func (r *fakeHTTPRoute) UseHealthCheck() bool { return false }
|
|
func (r *fakeHTTPRoute) UseAccessLog() bool { return false }
|
|
func (r *fakeHTTPRoute) ServeHTTP(http.ResponseWriter, *http.Request) {}
|
|
func (r *fakeHTTPRoute) InboundMTLSProfileRef() string { return r.inboundMTLSProfile }
|
|
|
|
func newTestHTTPServer(t *testing.T, ep *Entrypoint) *httpServer {
|
|
t.Helper()
|
|
|
|
srv, ok := ep.servers.Load(common.ProxyHTTPAddr)
|
|
if ok {
|
|
return srv
|
|
}
|
|
|
|
srv = &httpServer{
|
|
ep: ep,
|
|
addr: common.ProxyHTTPAddr,
|
|
routes: pool.New[types.HTTPRoute]("test-http-routes", "test-http-routes"),
|
|
}
|
|
ep.servers.Store(common.ProxyHTTPAddr, srv)
|
|
return srv
|
|
}
|
|
|
|
func TestMutateServerTLSConfigWithGlobalProfile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
|
|
srv := newTestHTTPServer(t, ep)
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"global": {UseSystemCAs: true},
|
|
}))
|
|
|
|
base := &tls.Config{MinVersion: tls.VersionTLS12}
|
|
mutated := srv.mutateServerTLSConfig(base)
|
|
|
|
require.Equal(t, tls.RequireAndVerifyClientCert, mutated.ClientAuth)
|
|
require.NotNil(t, mutated.ClientCAs)
|
|
require.Nil(t, mutated.GetConfigForClient)
|
|
}
|
|
|
|
func TestMutateServerTLSConfigWithRouteProfiles(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, nil)
|
|
ep.SetFindRouteDomains([]string{".example.com"})
|
|
srv := newTestHTTPServer(t, ep)
|
|
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "route"))
|
|
srv.AddRoute(newFakeHTTPRoute(t, "open-app", ""))
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"route": {UseSystemCAs: true},
|
|
}))
|
|
|
|
base := &tls.Config{MinVersion: tls.VersionTLS12}
|
|
mutated := srv.mutateServerTLSConfig(base)
|
|
|
|
require.Zero(t, mutated.ClientAuth)
|
|
require.Nil(t, mutated.ClientCAs)
|
|
require.NotNil(t, mutated.GetConfigForClient)
|
|
|
|
secureCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
|
|
require.NoError(t, err)
|
|
require.Equal(t, tls.RequireAndVerifyClientCert, secureCfg.ClientAuth)
|
|
require.NotNil(t, secureCfg.ClientCAs)
|
|
require.Nil(t, secureCfg.GetConfigForClient)
|
|
|
|
openCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "open-app.example.com"})
|
|
require.NoError(t, err)
|
|
require.Zero(t, openCfg.ClientAuth)
|
|
require.Nil(t, openCfg.ClientCAs)
|
|
require.Nil(t, openCfg.GetConfigForClient)
|
|
|
|
unknownCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "unknown.example.com"})
|
|
require.NoError(t, err)
|
|
require.Zero(t, unknownCfg.ClientAuth)
|
|
require.Nil(t, unknownCfg.ClientCAs)
|
|
require.Nil(t, unknownCfg.GetConfigForClient)
|
|
}
|
|
|
|
func TestSetInboundMTLSProfilesRejectsUnknownGlobalProfile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "missing"})
|
|
err := ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"known": {UseSystemCAs: true},
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, `entrypoint inbound mTLS profile "missing" not found`)
|
|
}
|
|
|
|
func TestSetInboundMTLSProfilesRejectsBadCAFile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "broken"})
|
|
err := ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"broken": {CAFiles: []string{filepath.Join(t.TempDir(), "missing.pem")}},
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "missing.pem")
|
|
}
|
|
|
|
func TestInboundMTLSGlobalHandshake(t *testing.T) {
|
|
ca, srv, client, err := agentcert.NewAgent()
|
|
require.NoError(t, err)
|
|
|
|
serverCert, err := srv.ToTLSCert()
|
|
require.NoError(t, err)
|
|
clientCert, err := client.ToTLSCert()
|
|
require.NoError(t, err)
|
|
|
|
caPath := writeTempFile(t, "ca.pem", ca.Cert)
|
|
provider := &staticCertProvider{cert: serverCert}
|
|
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
|
|
autocert.SetCtx(task.GetTestTask(t), provider)
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"global": {CAFiles: []string{caPath}},
|
|
}))
|
|
|
|
listenAddr := reserveTCPAddr(t)
|
|
addHTTPRouteAt(t, ep, "app1", "", listenAddr)
|
|
|
|
t.Run("trusted client succeeds", func(t *testing.T) {
|
|
resp, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*clientCert},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
_ = resp.Body.Close()
|
|
})
|
|
|
|
t.Run("missing client cert fails handshake", func(t *testing.T) {
|
|
_, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("wrong client cert fails handshake", func(t *testing.T) {
|
|
_, _, badClient, err := agentcert.NewAgent()
|
|
require.NoError(t, err)
|
|
badClientCert, err := badClient.ToTLSCert()
|
|
require.NoError(t, err)
|
|
|
|
_, err = doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*badClientCert},
|
|
})
|
|
require.Error(t, err)
|
|
})
|
|
|
|
closeTestServers(t, ep)
|
|
}
|
|
|
|
func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
|
|
ca, srv, client, err := agentcert.NewAgent()
|
|
require.NoError(t, err)
|
|
|
|
serverCert, err := srv.ToTLSCert()
|
|
require.NoError(t, err)
|
|
clientCert, err := client.ToTLSCert()
|
|
require.NoError(t, err)
|
|
|
|
caPath := writeTempFile(t, "ca.pem", ca.Cert)
|
|
provider := &staticCertProvider{cert: serverCert}
|
|
|
|
ep := NewTestEntrypoint(t, nil)
|
|
ep.SetFindRouteDomains([]string{".example.com"})
|
|
autocert.SetCtx(task.GetTestTask(t), provider)
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"route": {CAFiles: []string{caPath}},
|
|
}))
|
|
|
|
listenAddr := reserveTCPAddr(t)
|
|
addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr)
|
|
addHTTPRouteAt(t, ep, "open-app", "", listenAddr)
|
|
|
|
t.Run("secure route requires client cert when sni matches", func(t *testing.T) {
|
|
_, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("secure route accepts trusted client cert", func(t *testing.T) {
|
|
resp, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*clientCert},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
_ = resp.Body.Close()
|
|
})
|
|
|
|
t.Run("open route without client cert succeeds", func(t *testing.T) {
|
|
resp, err := doHTTPSRequest(listenAddr, "open-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
_ = resp.Body.Close()
|
|
})
|
|
|
|
t.Run("secure route rejects requests without sni", func(t *testing.T) {
|
|
resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "secure-app.example.com", "", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = tlsConn.Close() }()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("secure route rejects host and sni mismatch without cert", func(t *testing.T) {
|
|
resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "secure-app.example.com", "open-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = tlsConn.Close() }()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("open route rejects host and sni mismatch when sni selects secure route", func(t *testing.T) {
|
|
resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "open-app.example.com", "secure-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*clientCert},
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = tlsConn.Close() }()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
|
})
|
|
|
|
closeTestServers(t, ep)
|
|
}
|
|
|
|
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string) {
|
|
t.Helper()
|
|
|
|
require.NoError(t, ep.StartAddRoute(newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)))
|
|
}
|
|
|
|
func closeTestServers(t *testing.T, ep *Entrypoint) {
|
|
t.Helper()
|
|
for _, srv := range ep.servers.Range {
|
|
srv.Close()
|
|
}
|
|
}
|
|
|
|
func reserveTCPAddr(t *testing.T) string {
|
|
t.Helper()
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
addr := ln.Addr().String()
|
|
require.NoError(t, ln.Close())
|
|
return addr
|
|
}
|
|
|
|
func writeTempFile(t *testing.T, name string, data []byte) string {
|
|
t.Helper()
|
|
path := filepath.Join(t.TempDir(), name)
|
|
require.NoError(t, os.WriteFile(path, data, 0o600))
|
|
return path
|
|
}
|
|
|
|
func doHTTPSRequest(addr, host string, tlsConfig *tls.Config) (*http.Response, error) {
|
|
req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Host = host
|
|
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: cloneTLSConfigWithServerName(tlsConfig, host),
|
|
},
|
|
}
|
|
return client.Do(req)
|
|
}
|
|
|
|
// doHTTPSRequestWithServerName sends GET https://addr/ with HTTP Host set to host and TLS
|
|
// ServerName set to serverName (SNI may differ from Host). The returned connection stays open
|
|
// until the caller closes it after finishing with resp (typically close resp.Body first, then
|
|
// the tls connection).
|
|
func doHTTPSRequestWithServerName(addr, host, serverName string, tlsConfig *tls.Config) (*http.Response, io.Closer, error) {
|
|
conn, err := tls.Dial("tcp", addr, cloneTLSConfigWithServerName(tlsConfig, serverName))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
req.Host = host
|
|
if err := req.Write(conn); err != nil {
|
|
_ = conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
return resp, conn, nil
|
|
}
|
|
|
|
// cloneTLSConfigWithServerName returns a copy of cfg with ServerName set to
// serverName; cfg itself is left untouched. A nil cfg is treated as an empty
// configuration.
func cloneTLSConfigWithServerName(cfg *tls.Config, serverName string) *tls.Config {
	// Clone is nil-safe and yields nil for a nil receiver.
	out := cfg.Clone()
	if out == nil {
		out = new(tls.Config)
	}
	out.ServerName = serverName
	return out
}
|
|
|
|
// staticCertProvider is a test certificate provider backed by one fixed
// certificate.
type staticCertProvider struct {
	// cert is the certificate handed out by GetCert.
	cert *tls.Certificate
}
|
|
|
|
func (p *staticCertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return p.cert, nil
|
|
}
|
|
func (p *staticCertProvider) GetCertInfos() ([]autocert.CertInfo, error) { return nil, nil }
|
|
func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {}
|
|
func (p *staticCertProvider) ObtainCertAll() error { return nil }
|
|
func (p *staticCertProvider) ForceExpiryAll() bool { return false }
|
|
func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }
|