mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-14 12:39:40 +02:00
compileInboundMTLSProfiles now returns a nil map when compilation fails, instead of a partially populated map alongside the error. This avoids callers accidentally using incomplete state when err != nil. Add TestCompileInboundMTLSProfilesReturnsNilMapOnError for a mixed ok/bad profile map. Reformat fakeHTTPRoute stub methods in the test file.
472 lines
16 KiB
Go
472 lines
16 KiB
Go
package entrypoint
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/rs/zerolog"
|
|
"github.com/stretchr/testify/require"
|
|
agentcert "github.com/yusing/godoxy/agent/pkg/agent"
|
|
"github.com/yusing/godoxy/internal/agentpool"
|
|
autocert "github.com/yusing/godoxy/internal/autocert/types"
|
|
"github.com/yusing/godoxy/internal/common"
|
|
"github.com/yusing/godoxy/internal/homepage"
|
|
nettypes "github.com/yusing/godoxy/internal/net/types"
|
|
"github.com/yusing/godoxy/internal/types"
|
|
"github.com/yusing/goutils/pool"
|
|
"github.com/yusing/goutils/task"
|
|
)
|
|
|
|
type fakeHTTPRoute struct {
|
|
key string
|
|
name string
|
|
inboundMTLSProfile string
|
|
listenURL *nettypes.URL
|
|
task *task.Task
|
|
}
|
|
|
|
func newFakeHTTPRoute(t *testing.T, alias, profile string) *fakeHTTPRoute {
|
|
return newFakeHTTPRouteAt(t, alias, profile, "https://:1000")
|
|
}
|
|
|
|
func newFakeHTTPRouteAt(t *testing.T, alias, profile, listenURL string) *fakeHTTPRoute {
|
|
t.Helper()
|
|
|
|
return &fakeHTTPRoute{
|
|
key: alias,
|
|
name: alias,
|
|
inboundMTLSProfile: profile,
|
|
listenURL: nettypes.MustParseURL(listenURL),
|
|
task: task.GetTestTask(t),
|
|
}
|
|
}
|
|
|
|
func (r *fakeHTTPRoute) Key() string { return r.key }
|
|
func (r *fakeHTTPRoute) Name() string { return r.name }
|
|
func (r *fakeHTTPRoute) Start(task.Parent) error { return nil }
|
|
func (r *fakeHTTPRoute) Task() *task.Task { return r.task }
|
|
func (r *fakeHTTPRoute) Finish(any) {
|
|
// no-op: test stub
|
|
}
|
|
func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {
|
|
// no-op: test stub
|
|
}
|
|
func (r *fakeHTTPRoute) ProviderName() string { return "" }
|
|
func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }
|
|
func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL }
|
|
func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }
|
|
func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }
|
|
func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {
|
|
// no-op: test stub
|
|
}
|
|
func (r *fakeHTTPRoute) References() []string { return nil }
|
|
func (r *fakeHTTPRoute) ShouldExclude() bool { return false }
|
|
func (r *fakeHTTPRoute) Started() <-chan struct{} { return nil }
|
|
func (r *fakeHTTPRoute) IdlewatcherConfig() *types.IdlewatcherConfig { return nil }
|
|
func (r *fakeHTTPRoute) HealthCheckConfig() types.HealthCheckConfig { return types.HealthCheckConfig{} }
|
|
func (r *fakeHTTPRoute) LoadBalanceConfig() *types.LoadBalancerConfig {
|
|
return nil
|
|
}
|
|
func (r *fakeHTTPRoute) HomepageItem() homepage.Item { return homepage.Item{} }
|
|
func (r *fakeHTTPRoute) DisplayName() string { return r.name }
|
|
func (r *fakeHTTPRoute) ContainerInfo() *types.Container {
|
|
return nil
|
|
}
|
|
func (r *fakeHTTPRoute) GetAgent() *agentpool.Agent { return nil }
|
|
func (r *fakeHTTPRoute) IsDocker() bool { return false }
|
|
func (r *fakeHTTPRoute) IsAgent() bool { return false }
|
|
func (r *fakeHTTPRoute) UseLoadBalance() bool { return false }
|
|
func (r *fakeHTTPRoute) UseIdleWatcher() bool { return false }
|
|
func (r *fakeHTTPRoute) UseHealthCheck() bool { return false }
|
|
func (r *fakeHTTPRoute) UseAccessLog() bool { return false }
|
|
func (r *fakeHTTPRoute) ServeHTTP(http.ResponseWriter, *http.Request) {
|
|
// no-op: test stub
|
|
}
|
|
func (r *fakeHTTPRoute) InboundMTLSProfileRef() string { return r.inboundMTLSProfile }
|
|
|
|
func newTestHTTPServer(t *testing.T, ep *Entrypoint) *httpServer {
|
|
t.Helper()
|
|
|
|
srv, ok := ep.servers.Load(common.ProxyHTTPAddr)
|
|
if ok {
|
|
return srv
|
|
}
|
|
|
|
srv = &httpServer{
|
|
ep: ep,
|
|
addr: common.ProxyHTTPAddr,
|
|
routes: pool.New[types.HTTPRoute]("test-http-routes", "test-http-routes"),
|
|
}
|
|
ep.servers.Store(common.ProxyHTTPAddr, srv)
|
|
return srv
|
|
}
|
|
|
|
func TestMutateServerTLSConfigWithGlobalProfile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
|
|
srv := newTestHTTPServer(t, ep)
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"global": {UseSystemCAs: true},
|
|
}))
|
|
|
|
base := &tls.Config{MinVersion: tls.VersionTLS12}
|
|
mutated := srv.mutateServerTLSConfig(base)
|
|
|
|
require.Equal(t, tls.RequireAndVerifyClientCert, mutated.ClientAuth)
|
|
require.NotNil(t, mutated.ClientCAs)
|
|
require.Nil(t, mutated.GetConfigForClient)
|
|
}
|
|
|
|
func TestMutateServerTLSConfigWithRouteProfiles(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, nil)
|
|
ep.SetFindRouteDomains([]string{".example.com"})
|
|
srv := newTestHTTPServer(t, ep)
|
|
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "route"))
|
|
srv.AddRoute(newFakeHTTPRoute(t, "open-app", ""))
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"route": {UseSystemCAs: true},
|
|
}))
|
|
|
|
base := &tls.Config{MinVersion: tls.VersionTLS12}
|
|
mutated := srv.mutateServerTLSConfig(base)
|
|
|
|
require.Zero(t, mutated.ClientAuth)
|
|
require.Nil(t, mutated.ClientCAs)
|
|
require.NotNil(t, mutated.GetConfigForClient)
|
|
|
|
secureCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
|
|
require.NoError(t, err)
|
|
require.Equal(t, tls.RequireAndVerifyClientCert, secureCfg.ClientAuth)
|
|
require.NotNil(t, secureCfg.ClientCAs)
|
|
require.Nil(t, secureCfg.GetConfigForClient)
|
|
|
|
openCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "open-app.example.com"})
|
|
require.NoError(t, err)
|
|
require.Zero(t, openCfg.ClientAuth)
|
|
require.Nil(t, openCfg.ClientCAs)
|
|
require.Nil(t, openCfg.GetConfigForClient)
|
|
|
|
unknownCfg, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "unknown.example.com"})
|
|
require.NoError(t, err)
|
|
require.Zero(t, unknownCfg.ClientAuth)
|
|
require.Nil(t, unknownCfg.ClientCAs)
|
|
require.Nil(t, unknownCfg.GetConfigForClient)
|
|
}
|
|
|
|
func TestSetInboundMTLSProfilesRejectsUnknownGlobalProfile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "missing"})
|
|
err := ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"known": {UseSystemCAs: true},
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, `entrypoint inbound mTLS profile "missing" not found`)
|
|
}
|
|
|
|
func TestSetInboundMTLSProfilesRejectsBadCAFile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "broken"})
|
|
err := ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"broken": {CAFiles: []string{filepath.Join(t.TempDir(), "missing.pem")}},
|
|
})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "missing.pem")
|
|
}
|
|
|
|
func TestCompileInboundMTLSProfilesReturnsNilMapOnError(t *testing.T) {
|
|
compiled, err := compileInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"ok": {UseSystemCAs: true},
|
|
"bad": {CAFiles: []string{filepath.Join(t.TempDir(), "missing.pem")}},
|
|
})
|
|
require.Nil(t, compiled)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, "missing.pem")
|
|
}
|
|
|
|
func TestMutateServerTLSConfigRejectsUnknownRouteProfile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, nil)
|
|
ep.SetFindRouteDomains([]string{".example.com"})
|
|
srv := newTestHTTPServer(t, ep)
|
|
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))
|
|
|
|
base := &tls.Config{MinVersion: tls.VersionTLS12}
|
|
mutated := srv.mutateServerTLSConfig(base)
|
|
|
|
_, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
|
|
}
|
|
|
|
func TestResolveRequestRouteRejectsUnknownRouteProfile(t *testing.T) {
|
|
ep := NewTestEntrypoint(t, nil)
|
|
ep.SetFindRouteDomains([]string{".example.com"})
|
|
srv := newTestHTTPServer(t, ep)
|
|
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "https://secure-app.example.com", nil)
|
|
req.Host = "secure-app.example.com"
|
|
req.TLS = &tls.ConnectionState{ServerName: "secure-app.example.com"}
|
|
|
|
route, err := srv.resolveRequestRoute(req)
|
|
require.Nil(t, route)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
|
|
}
|
|
|
|
func TestInboundMTLSGlobalHandshake(t *testing.T) {
|
|
ca, srv, client, err := agentcert.NewAgent()
|
|
require.NoError(t, err)
|
|
|
|
serverCert, err := srv.ToTLSCert()
|
|
require.NoError(t, err)
|
|
clientCert, err := client.ToTLSCert()
|
|
require.NoError(t, err)
|
|
|
|
caPath := writeTempFile(t, "ca.pem", ca.Cert)
|
|
provider := &staticCertProvider{cert: serverCert}
|
|
|
|
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
|
|
t.Cleanup(func() {
|
|
closeTestServers(t, ep)
|
|
})
|
|
autocert.SetCtx(task.GetTestTask(t), provider)
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"global": {CAFiles: []string{caPath}},
|
|
}))
|
|
|
|
listener, releaseListener := reserveTCPAddr(t)
|
|
listenAddr := listener.Addr().String()
|
|
addHTTPRouteAt(t, ep, "app1", "", listenAddr, listener)
|
|
releaseListener()
|
|
|
|
t.Run("trusted client succeeds", func(t *testing.T) {
|
|
resp, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*clientCert},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
_ = resp.Body.Close()
|
|
})
|
|
|
|
t.Run("missing client cert fails handshake", func(t *testing.T) {
|
|
_, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("wrong client cert fails handshake", func(t *testing.T) {
|
|
_, _, badClient, err := agentcert.NewAgent()
|
|
require.NoError(t, err)
|
|
badClientCert, err := badClient.ToTLSCert()
|
|
require.NoError(t, err)
|
|
|
|
_, err = doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*badClientCert},
|
|
})
|
|
require.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
|
|
ca, srv, client, err := agentcert.NewAgent()
|
|
require.NoError(t, err)
|
|
|
|
serverCert, err := srv.ToTLSCert()
|
|
require.NoError(t, err)
|
|
clientCert, err := client.ToTLSCert()
|
|
require.NoError(t, err)
|
|
|
|
caPath := writeTempFile(t, "ca.pem", ca.Cert)
|
|
provider := &staticCertProvider{cert: serverCert}
|
|
|
|
ep := NewTestEntrypoint(t, nil)
|
|
t.Cleanup(func() {
|
|
closeTestServers(t, ep)
|
|
})
|
|
ep.SetFindRouteDomains([]string{".example.com"})
|
|
autocert.SetCtx(task.GetTestTask(t), provider)
|
|
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
|
"route": {CAFiles: []string{caPath}},
|
|
}))
|
|
|
|
listener, releaseListener := reserveTCPAddr(t)
|
|
listenAddr := listener.Addr().String()
|
|
addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr, listener)
|
|
releaseListener()
|
|
addHTTPRouteAt(t, ep, "open-app", "", listenAddr, nil)
|
|
|
|
t.Run("secure route requires client cert when sni matches", func(t *testing.T) {
|
|
_, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.Error(t, err)
|
|
})
|
|
|
|
t.Run("secure route accepts trusted client cert", func(t *testing.T) {
|
|
resp, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*clientCert},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
_ = resp.Body.Close()
|
|
})
|
|
|
|
t.Run("open route without client cert succeeds", func(t *testing.T) {
|
|
resp, err := doHTTPSRequest(listenAddr, "open-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
_ = resp.Body.Close()
|
|
})
|
|
|
|
t.Run("secure route rejects requests without sni", func(t *testing.T) {
|
|
resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "secure-app.example.com", "", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = tlsConn.Close() }()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("secure route rejects host and sni mismatch without cert", func(t *testing.T) {
|
|
resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "secure-app.example.com", "open-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = tlsConn.Close() }()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
|
})
|
|
|
|
t.Run("open route rejects host and sni mismatch when sni selects secure route", func(t *testing.T) {
|
|
resp, tlsConn, err := doHTTPSRequestWithServerName(listenAddr, "open-app.example.com", "secure-app.example.com", &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*clientCert},
|
|
})
|
|
require.NoError(t, err)
|
|
defer func() { _ = tlsConn.Close() }()
|
|
defer func() { _ = resp.Body.Close() }()
|
|
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
|
})
|
|
}
|
|
|
|
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string, listener net.Listener) {
|
|
t.Helper()
|
|
|
|
route := newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)
|
|
if listener == nil {
|
|
require.NoError(t, ep.StartAddRoute(route))
|
|
return
|
|
}
|
|
require.NoError(t, ep.addHTTPRouteWithListener(route, listenAddr, HTTPProtoHTTPS, listener))
|
|
}
|
|
|
|
func closeTestServers(t *testing.T, ep *Entrypoint) {
|
|
t.Helper()
|
|
for _, srv := range ep.servers.Range {
|
|
srv.Close()
|
|
}
|
|
}
|
|
|
|
func reserveTCPAddr(t *testing.T) (net.Listener, func()) {
|
|
t.Helper()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
owned := true
|
|
t.Cleanup(func() {
|
|
if owned {
|
|
_ = ln.Close()
|
|
}
|
|
})
|
|
return ln, func() {
|
|
owned = false
|
|
}
|
|
}
|
|
|
|
func writeTempFile(t *testing.T, name string, data []byte) string {
|
|
t.Helper()
|
|
path := filepath.Join(t.TempDir(), name)
|
|
require.NoError(t, os.WriteFile(path, data, 0o600))
|
|
return path
|
|
}
|
|
|
|
func doHTTPSRequest(addr, host string, tlsConfig *tls.Config) (*http.Response, error) {
|
|
req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Host = host
|
|
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: cloneTLSConfigWithServerName(tlsConfig, host),
|
|
},
|
|
}
|
|
return client.Do(req)
|
|
}
|
|
|
|
// doHTTPSRequestWithServerName sends GET https://addr/ with HTTP Host set to host and TLS
|
|
// ServerName set to serverName (SNI may differ from Host). The returned connection stays open
|
|
// until the caller closes it after finishing with resp (typically close resp.Body first, then
|
|
// the tls connection).
|
|
func doHTTPSRequestWithServerName(addr, host, serverName string, tlsConfig *tls.Config) (*http.Response, io.Closer, error) {
|
|
conn, err := tls.Dial("tcp", addr, cloneTLSConfigWithServerName(tlsConfig, serverName))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
req.Host = host
|
|
if err := req.Write(conn); err != nil {
|
|
_ = conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, nil, err
|
|
}
|
|
return resp, conn, nil
|
|
}
|
|
|
|
// cloneTLSConfigWithServerName returns a shallow copy of cfg with ServerName
// overridden, leaving cfg itself untouched. A nil cfg yields an otherwise
// empty config carrying only serverName.
func cloneTLSConfigWithServerName(cfg *tls.Config, serverName string) *tls.Config {
	// (*tls.Config).Clone returns nil for a nil receiver, so fall back to a
	// zero config in that case.
	out := cfg.Clone()
	if out == nil {
		out = new(tls.Config)
	}
	out.ServerName = serverName
	return out
}
|
|
|
|
// staticCertProvider is a test certificate provider, installed through
// autocert.SetCtx by the handshake tests. It serves one fixed server
// certificate and stubs every other provider operation.
type staticCertProvider struct {
	cert *tls.Certificate // returned as-is by GetCert for every handshake
}
|
|
|
|
func (p *staticCertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
return p.cert, nil
|
|
}
|
|
func (p *staticCertProvider) GetCertInfos() ([]autocert.CertInfo, error) { return nil, nil }
|
|
func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {
|
|
// no-op: test stub
|
|
}
|
|
func (p *staticCertProvider) ObtainCertAll() error { return nil }
|
|
func (p *staticCertProvider) ForceExpiryAll() bool { return false }
|
|
func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }
|