Files
godoxy/internal/entrypoint/inbound_mtls_test.go
yusing 2a6ad90b72 fix(entrypoint): return nil map from compileInboundMTLSProfiles on error
compileInboundMTLSProfiles now returns a nil map when compilation fails,
instead of a partially populated map alongside the error. This avoids
callers accidentally using incomplete state when err != nil.

Add TestCompileInboundMTLSProfilesReturnsNilMapOnError for a mixed ok/bad
profile map. Reformat fakeHTTPRoute stub methods in the test file.
2026-04-13 16:46:34 +08:00

472 lines
16 KiB
Go

package entrypoint
import (
"bufio"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
agentcert "github.com/yusing/godoxy/agent/pkg/agent"
"github.com/yusing/godoxy/internal/agentpool"
autocert "github.com/yusing/godoxy/internal/autocert/types"
"github.com/yusing/godoxy/internal/common"
"github.com/yusing/godoxy/internal/homepage"
nettypes "github.com/yusing/godoxy/internal/net/types"
"github.com/yusing/godoxy/internal/types"
"github.com/yusing/goutils/pool"
"github.com/yusing/goutils/task"
)
// fakeHTTPRoute is a lightweight route stand-in for the inbound mTLS tests.
// It only carries the handful of values the tests vary; its methods hand
// these fields back and stub out everything else.
type fakeHTTPRoute struct {
	// key is returned by Key; name by Name and DisplayName.
	key, name string
	// inboundMTLSProfile is the profile reference returned by
	// InboundMTLSProfileRef ("" is used by the tests for routes without one).
	inboundMTLSProfile string
	// listenURL is returned by ListenURL.
	listenURL *nettypes.URL
	// task is returned by Task.
	task *task.Task
}
// newFakeHTTPRoute returns a fakeHTTPRoute for alias that references the
// given inbound mTLS profile and listens on the fixed URL "https://:1000".
func newFakeHTTPRoute(t *testing.T, alias, profile string) *fakeHTTPRoute {
	const defaultListenURL = "https://:1000"
	return newFakeHTTPRouteAt(t, alias, profile, defaultListenURL)
}
// newFakeHTTPRouteAt returns a fakeHTTPRoute whose key and name are both
// alias, whose inbound mTLS profile reference is profile, and whose listen
// URL is parsed from listenURL. The route's task is the test task for t.
func newFakeHTTPRouteAt(t *testing.T, alias, profile, listenURL string) *fakeHTTPRoute {
	t.Helper()

	route := new(fakeHTTPRoute)
	route.key = alias
	route.name = alias
	route.inboundMTLSProfile = profile
	route.listenURL = nettypes.MustParseURL(listenURL)
	route.task = task.GetTestTask(t)
	return route
}
// The methods below make up the route method set the tests rely on
// (presumably types.HTTPRoute — confirm against the interface definition).
// Accessors backed by a struct field return that field; every other method
// is an inert stub that returns a zero value or does nothing.

func (f *fakeHTTPRoute) Key() string { return f.key }

func (f *fakeHTTPRoute) Name() string { return f.name }

func (f *fakeHTTPRoute) Start(task.Parent) error { return nil }

func (f *fakeHTTPRoute) Task() *task.Task { return f.task }

func (f *fakeHTTPRoute) Finish(any) {
	// intentionally empty: stub
}

func (f *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {
	// intentionally empty: stub
}

func (f *fakeHTTPRoute) ProviderName() string { return "" }

func (f *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }

func (f *fakeHTTPRoute) ListenURL() *nettypes.URL { return f.listenURL }

func (f *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }

func (f *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }

func (f *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {
	// intentionally empty: stub
}

func (f *fakeHTTPRoute) References() []string { return nil }

func (f *fakeHTTPRoute) ShouldExclude() bool { return false }

func (f *fakeHTTPRoute) Started() <-chan struct{} { return nil }

func (f *fakeHTTPRoute) IdlewatcherConfig() *types.IdlewatcherConfig { return nil }

func (f *fakeHTTPRoute) HealthCheckConfig() types.HealthCheckConfig {
	return types.HealthCheckConfig{}
}

func (f *fakeHTTPRoute) LoadBalanceConfig() *types.LoadBalancerConfig { return nil }

func (f *fakeHTTPRoute) HomepageItem() homepage.Item { return homepage.Item{} }

// DisplayName reuses the route name.
func (f *fakeHTTPRoute) DisplayName() string { return f.name }

func (f *fakeHTTPRoute) ContainerInfo() *types.Container { return nil }

func (f *fakeHTTPRoute) GetAgent() *agentpool.Agent { return nil }

func (f *fakeHTTPRoute) IsDocker() bool { return false }

func (f *fakeHTTPRoute) IsAgent() bool { return false }

func (f *fakeHTTPRoute) UseLoadBalance() bool { return false }

func (f *fakeHTTPRoute) UseIdleWatcher() bool { return false }

func (f *fakeHTTPRoute) UseHealthCheck() bool { return false }

func (f *fakeHTTPRoute) UseAccessLog() bool { return false }

// ServeHTTP writes nothing to the response.
func (f *fakeHTTPRoute) ServeHTTP(http.ResponseWriter, *http.Request) {
	// intentionally empty: stub
}

// InboundMTLSProfileRef returns the profile reference given at construction.
func (f *fakeHTTPRoute) InboundMTLSProfileRef() string { return f.inboundMTLSProfile }
// newTestHTTPServer returns the httpServer registered in ep.servers under
// common.ProxyHTTPAddr. When none is registered yet, it creates one with an
// empty route pool, stores it under that address, and returns it.
func newTestHTTPServer(t *testing.T, ep *Entrypoint) *httpServer {
	t.Helper()

	if existing, found := ep.servers.Load(common.ProxyHTTPAddr); found {
		return existing
	}

	created := &httpServer{
		ep:     ep,
		addr:   common.ProxyHTTPAddr,
		routes: pool.New[types.HTTPRoute]("test-http-routes", "test-http-routes"),
	}
	ep.servers.Store(common.ProxyHTTPAddr, created)
	return created
}
// TestMutateServerTLSConfigWithGlobalProfile checks that, with an
// entrypoint-wide inbound mTLS profile configured, the mutated TLS config
// itself requires and verifies client certificates and installs no
// per-client config callback.
func TestMutateServerTLSConfigWithGlobalProfile(t *testing.T) {
	ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
	srv := newTestHTTPServer(t, ep)

	profiles := map[string]types.InboundMTLSProfile{
		"global": {UseSystemCAs: true},
	}
	require.NoError(t, ep.SetInboundMTLSProfiles(profiles))

	got := srv.mutateServerTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12})

	require.Equal(t, tls.RequireAndVerifyClientCert, got.ClientAuth)
	require.NotNil(t, got.ClientCAs)
	require.Nil(t, got.GetConfigForClient)
}
// TestMutateServerTLSConfigWithRouteProfiles checks the route-scoped case:
// the top-level config stays free of client-auth settings but gains a
// GetConfigForClient callback, which returns a client-cert-requiring config
// only for the SNI of the route that references an mTLS profile.
func TestMutateServerTLSConfigWithRouteProfiles(t *testing.T) {
	ep := NewTestEntrypoint(t, nil)
	ep.SetFindRouteDomains([]string{".example.com"})

	srv := newTestHTTPServer(t, ep)
	srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "route"))
	srv.AddRoute(newFakeHTTPRoute(t, "open-app", ""))

	profiles := map[string]types.InboundMTLSProfile{
		"route": {UseSystemCAs: true},
	}
	require.NoError(t, ep.SetInboundMTLSProfiles(profiles))

	mutated := srv.mutateServerTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12})
	require.Zero(t, mutated.ClientAuth)
	require.Nil(t, mutated.ClientCAs)
	require.NotNil(t, mutated.GetConfigForClient)

	// SNI of the route bound to the "route" profile: client certs enforced.
	secureCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
	require.NoError(t, err)
	require.Equal(t, tls.RequireAndVerifyClientCert, secureCfg.ClientAuth)
	require.NotNil(t, secureCfg.ClientCAs)
	require.Nil(t, secureCfg.GetConfigForClient)

	// SNI of a route without a profile, then SNI matching no route at all:
	// neither may carry client-auth settings.
	for _, sni := range []string{"open-app.example.com", "unknown.example.com"} {
		cfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: sni})
		require.NoError(t, err, sni)
		require.Zero(t, cfg.ClientAuth, sni)
		require.Nil(t, cfg.ClientCAs, sni)
		require.Nil(t, cfg.GetConfigForClient, sni)
	}
}
// TestSetInboundMTLSProfilesRejectsUnknownGlobalProfile checks that setting
// profiles fails when the entrypoint config names a profile that is absent
// from the supplied map.
func TestSetInboundMTLSProfilesRejectsUnknownGlobalProfile(t *testing.T) {
	ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "missing"})

	profiles := map[string]types.InboundMTLSProfile{
		"known": {UseSystemCAs: true},
	}
	err := ep.SetInboundMTLSProfiles(profiles)

	require.Error(t, err)
	require.ErrorContains(t, err, `entrypoint inbound mTLS profile "missing" not found`)
}
// TestSetInboundMTLSProfilesRejectsBadCAFile checks that a profile pointing
// at a CA file that does not exist is rejected, and that the error names the
// offending file.
func TestSetInboundMTLSProfilesRejectsBadCAFile(t *testing.T) {
	ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "broken"})

	// The temp dir is fresh, so this path is guaranteed not to exist.
	missingCA := filepath.Join(t.TempDir(), "missing.pem")
	profiles := map[string]types.InboundMTLSProfile{
		"broken": {CAFiles: []string{missingCA}},
	}
	err := ep.SetInboundMTLSProfiles(profiles)

	require.Error(t, err)
	require.ErrorContains(t, err, "missing.pem")
}
func TestCompileInboundMTLSProfilesReturnsNilMapOnError(t *testing.T) {
compiled, err := compileInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
"ok": {UseSystemCAs: true},
"bad": {CAFiles: []string{filepath.Join(t.TempDir(), "missing.pem")}},
})
require.Nil(t, compiled)
require.Error(t, err)
require.ErrorContains(t, err, "missing.pem")
}
// TestMutateServerTLSConfigRejectsUnknownRouteProfile checks that the
// per-client TLS config callback reports an error when the route selected by
// SNI references an inbound mTLS profile that was never registered.
func TestMutateServerTLSConfigRejectsUnknownRouteProfile(t *testing.T) {
	ep := NewTestEntrypoint(t, nil)
	ep.SetFindRouteDomains([]string{".example.com"})
	srv := newTestHTTPServer(t, ep)
	srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))

	base := &tls.Config{MinVersion: tls.VersionTLS12}
	mutated := srv.mutateServerTLSConfig(base)

	// Guard the callback before invoking it: calling a nil func would panic
	// instead of producing a readable test failure.
	require.NotNil(t, mutated.GetConfigForClient)

	_, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
	require.Error(t, err)
	require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
}
// TestResolveRequestRouteRejectsUnknownRouteProfile checks that resolving
// the route for a TLS request fails, returning no route, when the matched
// route references an inbound mTLS profile that was never registered.
func TestResolveRequestRouteRejectsUnknownRouteProfile(t *testing.T) {
	const host = "secure-app.example.com"

	ep := NewTestEntrypoint(t, nil)
	ep.SetFindRouteDomains([]string{".example.com"})
	srv := newTestHTTPServer(t, ep)
	srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))

	// Host header and TLS server name agree, so only the profile lookup
	// can be the source of the failure.
	req := httptest.NewRequest(http.MethodGet, "https://"+host, nil)
	req.Host = host
	req.TLS = &tls.ConnectionState{ServerName: host}

	resolved, err := srv.resolveRequestRoute(req)

	require.Nil(t, resolved)
	require.Error(t, err)
	require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
}
// TestInboundMTLSGlobalHandshake runs real TLS handshakes against an
// entrypoint configured with an entrypoint-wide inbound mTLS profile and
// checks that only a client certificate issued by the configured CA is
// accepted.
func TestInboundMTLSGlobalHandshake(t *testing.T) {
	// One CA plus a server and a client credential (roles as used below).
	ca, srv, client, err := agentcert.NewAgent()
	require.NoError(t, err)

	serverCert, err := srv.ToTLSCert()
	require.NoError(t, err)
	clientCert, err := client.ToTLSCert()
	require.NoError(t, err)

	caPath := writeTempFile(t, "ca.pem", ca.Cert)
	provider := &staticCertProvider{cert: serverCert}

	ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
	t.Cleanup(func() {
		closeTestServers(t, ep)
	})
	autocert.SetCtx(task.GetTestTask(t), provider)
	require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
		"global": {CAFiles: []string{caPath}},
	}))

	// Reserve a free port, hand the listener to the entrypoint, then release
	// it so the reservation's own cleanup does not close it as well.
	listener, releaseListener := reserveTCPAddr(t)
	listenAddr := listener.Addr().String()
	addHTTPRouteAt(t, ep, "app1", "", listenAddr, listener)
	releaseListener()

	t.Run("trusted client succeeds", func(t *testing.T) {
		resp, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
			InsecureSkipVerify: true,
			Certificates:       []tls.Certificate{*clientCert},
		})
		require.NoError(t, err)
		// Deferred so the body is closed even when the status assertion
		// below aborts the subtest.
		defer func() { _ = resp.Body.Close() }()
		require.Equal(t, http.StatusOK, resp.StatusCode)
	})

	t.Run("missing client cert fails handshake", func(t *testing.T) {
		_, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
			InsecureSkipVerify: true,
		})
		require.Error(t, err)
	})

	t.Run("wrong client cert fails handshake", func(t *testing.T) {
		// A client credential issued under a different, freshly generated CA.
		_, _, badClient, err := agentcert.NewAgent()
		require.NoError(t, err)
		badClientCert, err := badClient.ToTLSCert()
		require.NoError(t, err)

		_, err = doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
			InsecureSkipVerify: true,
			Certificates:       []tls.Certificate{*badClientCert},
		})
		require.Error(t, err)
	})
}
// TestInboundMTLSRouteScopedHandshake runs real TLS handshakes against an
// entrypoint where only one route ("secure-app") references an inbound mTLS
// profile while a second route ("open-app") on the same address does not.
// It covers the matching-SNI cases and the cases where the HTTP Host and the
// TLS server name disagree, which must be answered with 421.
func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
	// One CA plus a server and a client credential (roles as used below).
	ca, srv, client, err := agentcert.NewAgent()
	require.NoError(t, err)

	serverCert, err := srv.ToTLSCert()
	require.NoError(t, err)
	clientCert, err := client.ToTLSCert()
	require.NoError(t, err)

	caPath := writeTempFile(t, "ca.pem", ca.Cert)
	provider := &staticCertProvider{cert: serverCert}

	ep := NewTestEntrypoint(t, nil)
	t.Cleanup(func() {
		closeTestServers(t, ep)
	})
	ep.SetFindRouteDomains([]string{".example.com"})
	autocert.SetCtx(task.GetTestTask(t), provider)
	require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
		"route": {CAFiles: []string{caPath}},
	}))

	// Reserve a free port, hand the listener to the entrypoint with the
	// first route, then release it so the reservation's own cleanup does not
	// close it as well. The second route joins the same address.
	listener, releaseListener := reserveTCPAddr(t)
	listenAddr := listener.Addr().String()
	addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr, listener)
	releaseListener()
	addHTTPRouteAt(t, ep, "open-app", "", listenAddr, nil)

	t.Run("secure route requires client cert when sni matches", func(t *testing.T) {
		_, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
			InsecureSkipVerify: true,
		})
		require.Error(t, err)
	})

	t.Run("secure route accepts trusted client cert", func(t *testing.T) {
		resp, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
			InsecureSkipVerify: true,
			Certificates:       []tls.Certificate{*clientCert},
		})
		require.NoError(t, err)
		// Deferred so the body is closed even when the status assertion
		// below aborts the subtest.
		defer func() { _ = resp.Body.Close() }()
		require.Equal(t, http.StatusOK, resp.StatusCode)
	})

	t.Run("open route without client cert succeeds", func(t *testing.T) {
		resp, err := doHTTPSRequest(listenAddr, "open-app.example.com", &tls.Config{
			InsecureSkipVerify: true,
		})
		require.NoError(t, err)
		// Deferred for the same reason as in the subtest above.
		defer func() { _ = resp.Body.Close() }()
		require.Equal(t, http.StatusOK, resp.StatusCode)
	})

	t.Run("secure route rejects requests without sni", func(t *testing.T) {
		resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "secure-app.example.com", "", &tls.Config{
			InsecureSkipVerify: true,
		})
		require.NoError(t, err)
		defer func() { _ = tlsConn.Close() }()
		defer func() { _ = resp.Body.Close() }()
		require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
	})

	t.Run("secure route rejects host and sni mismatch without cert", func(t *testing.T) {
		resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "secure-app.example.com", "open-app.example.com", &tls.Config{
			InsecureSkipVerify: true,
		})
		require.NoError(t, err)
		defer func() { _ = tlsConn.Close() }()
		defer func() { _ = resp.Body.Close() }()
		require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
	})

	t.Run("open route rejects host and sni mismatch when sni selects secure route", func(t *testing.T) {
		resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "open-app.example.com", "secure-app.example.com", &tls.Config{
			InsecureSkipVerify: true,
			Certificates:       []tls.Certificate{*clientCert},
		})
		require.NoError(t, err)
		defer func() { _ = tlsConn.Close() }()
		defer func() { _ = resp.Body.Close() }()
		require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
	})
}
// addHTTPRouteAt registers a fake HTTPS route named alias, referencing the
// given inbound mTLS profile, at listenAddr. With a non-nil listener the
// route is attached to that listener via addHTTPRouteWithListener; with a
// nil listener it goes through StartAddRoute. Any error fails the test.
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string, listener net.Listener) {
	t.Helper()

	route := newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)
	if listener != nil {
		require.NoError(t, ep.addHTTPRouteWithListener(route, listenAddr, HTTPProtoHTTPS, listener))
		return
	}
	require.NoError(t, ep.StartAddRoute(route))
}
// closeTestServers calls Close on every server currently held in
// ep.servers. It is meant to be used from test cleanup.
func closeTestServers(t *testing.T, ep *Entrypoint) {
	t.Helper()

	for _, server := range ep.servers.Range {
		server.Close()
	}
}
// reserveTCPAddr opens a TCP listener on an ephemeral loopback port and
// returns it together with a release function. Until release is called the
// listener is closed automatically at test cleanup; after release the caller
// (or whoever the listener was handed to) is responsible for closing it.
func reserveTCPAddr(t *testing.T) (net.Listener, func()) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	released := false
	t.Cleanup(func() {
		if released {
			return
		}
		_ = listener.Close()
	})

	release := func() { released = true }
	return listener, release
}
// writeTempFile writes data to a file called name inside a fresh per-test
// temporary directory (mode 0600) and returns the file's path. A write
// failure fails the test.
func writeTempFile(t *testing.T, name string, data []byte) string {
	t.Helper()

	dir := t.TempDir()
	target := filepath.Join(dir, name)
	err := os.WriteFile(target, data, 0o600)
	require.NoError(t, err)
	return target
}
// doHTTPSRequest sends GET https://addr/ with the HTTP Host header and the
// TLS server name both set to host, using a copy of tlsConfig. The caller
// owns the returned response and must close its body.
//
// Every call uses its own single-use transport, so the underlying connection
// is never reused and is torn down once the response has been consumed.
func doHTTPSRequest(addr, host string, tlsConfig *tls.Config) (*http.Response, error) {
	req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
	if err != nil {
		return nil, err
	}
	req.Host = host

	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: cloneTLSConfigWithServerName(tlsConfig, host),
			// The transport is discarded after this one request. With
			// keep-alives on, its connection would sit in an idle pool that
			// nobody can reach (and that has no idle timeout), leaking the
			// socket and its goroutines for the rest of the test run.
			DisableKeepAlives: true,
		},
	}
	return client.Do(req)
}
// doHTTPSRequestWithServerName dials addr over TLS with the config's
// ServerName set to serverName, then writes a GET request for https://addr/
// whose Host header is host, so that SNI and Host can be chosen
// independently. On success it returns the parsed response together with the
// still-open TLS connection; the caller should close resp.Body and then the
// connection. On any failure after the dial the connection is closed before
// the error is returned.
func doHTTPSRequestWithServerName(addr, host, serverName string, tlsConfig *tls.Config) (*http.Response, io.Closer, error) {
	tlsConn, err := tls.Dial("tcp", addr, cloneTLSConfigWithServerName(tlsConfig, serverName))
	if err != nil {
		return nil, nil, err
	}

	// fail releases the connection and forwards cause as the sole result.
	fail := func(cause error) (*http.Response, io.Closer, error) {
		_ = tlsConn.Close()
		return nil, nil, cause
	}

	req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
	if err != nil {
		return fail(err)
	}
	req.Host = host

	if err = req.Write(tlsConn); err != nil {
		return fail(err)
	}

	resp, err := http.ReadResponse(bufio.NewReader(tlsConn), req)
	if err != nil {
		return fail(err)
	}
	return resp, tlsConn, nil
}
// cloneTLSConfigWithServerName returns a copy of cfg whose ServerName is
// replaced by serverName; cfg itself is left untouched. A nil cfg is treated
// as an empty configuration.
func cloneTLSConfigWithServerName(cfg *tls.Config, serverName string) *tls.Config {
	var out *tls.Config
	if cfg != nil {
		out = cfg.Clone()
	} else {
		out = new(tls.Config)
	}
	out.ServerName = serverName
	return out
}
// staticCertProvider is a certificate provider stub that serves one fixed
// certificate. The tests install it through autocert.SetCtx (presumably it
// satisfies the autocert provider interface — confirm against its
// definition). All renewal-related methods are inert.
type staticCertProvider struct {
	cert *tls.Certificate // returned for every client hello
}

// GetCert returns the fixed certificate regardless of the client hello.
func (s *staticCertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) {
	return s.cert, nil
}

// GetCertInfos reports no certificate infos and no error.
func (s *staticCertProvider) GetCertInfos() ([]autocert.CertInfo, error) {
	return nil, nil
}

// ScheduleRenewalAll does nothing.
func (s *staticCertProvider) ScheduleRenewalAll(task.Parent) {
	// intentionally empty: stub
}

// ObtainCertAll always succeeds without doing anything.
func (s *staticCertProvider) ObtainCertAll() error { return nil }

// ForceExpiryAll always returns false.
func (s *staticCertProvider) ForceExpiryAll() bool { return false }

// WaitRenewalDone returns true immediately.
func (s *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }