fix(entrypoint): reject missing inbound mTLS profile references

Add lookupInboundMTLSProfile so global and route-scoped refs must exist
in the loaded profile map. Propagate resolver errors through TLS
GetConfigForClient; in HTTP dispatch, return 421 only for SNI and
misdirected secure-route cases and log 500 for other resolution
failures.

Support adding routes with an existing listener for tests, reserve the
port via net.Listen without a race, and use t.Cleanup for server
teardown. Move relay_proxy_protocol_header documentation to per-route
TCP config in config.example.yml.
This commit is contained in:
yusing
2026-04-13 14:56:38 +08:00
parent 2a3823091d
commit c7f9c2889b
5 changed files with 154 additions and 51 deletions

View File

@@ -160,6 +160,11 @@ providers:
# secret: aaaa-bbbb-cccc-dddd
# no_tls_verify: true
# To relay the downstream client address to a TCP upstream, set
# `relay_proxy_protocol_header: true` on that specific TCP route in route
# configuration (for example, see providers.example.yml). UDP relay is not
# supported yet.
# Match domains
# See https://docs.godoxy.dev/Certificates-and-domain-matching
#

View File

@@ -3,6 +3,7 @@ package entrypoint
import (
"errors"
"fmt"
"net"
"net/http"
"strings"
@@ -54,6 +55,10 @@ func newHTTPServer(ep *Entrypoint) *httpServer {
// Listen starts the server and stop when entrypoint is stopped.
func (srv *httpServer) Listen(addr string, proto HTTPProto) error {
return srv.listen(addr, proto, nil)
}
func (srv *httpServer) listen(addr string, proto HTTPProto, listener net.Listener) error {
if srv.addr != "" {
return errors.New("server already started")
}
@@ -68,8 +73,10 @@ func (srv *httpServer) Listen(addr string, proto HTTPProto) error {
switch proto {
case HTTPProtoHTTP:
opts.HTTPAddr = addr
opts.HTTPListener = listener
case HTTPProtoHTTPS:
opts.HTTPSAddr = addr
opts.HTTPSListener = listener
opts.CertProvider = autocert.FromCtx(srv.ep.task.Context())
opts.TLSConfigMutator = srv.mutateServerTLSConfig
}
@@ -119,9 +126,13 @@ func (srv *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
route, err := srv.resolveRequestRoute(r)
switch {
case err != nil:
case errors.Is(err, errSecureRouteRequiresSNI), errors.Is(err, errSecureRouteMisdirected):
http.Error(w, err.Error(), http.StatusMisdirectedRequest)
return
case err != nil:
log.Err(err).Msg("failed to resolve HTTP route")
http.Error(w, "internal server error", http.StatusInternalServerError)
return
case route != nil:
r = routes.WithRouteContext(r, route)
if srv.ep.middleware != nil {
@@ -145,27 +156,39 @@ var (
func (srv *httpServer) resolveRequestRoute(req *http.Request) (types.HTTPRoute, error) {
hostRoute := srv.FindRoute(req.Host)
if req.TLS == nil || srv.ep.cfg.InboundMTLSProfile != "" || len(srv.ep.inboundMTLSProfiles) == 0 {
if req.TLS == nil || srv.ep.cfg.InboundMTLSProfile != "" {
return hostRoute, nil
}
serverName := req.TLS.ServerName
if serverName == "" {
if pool := srv.resolveInboundMTLSProfileForRoute(hostRoute); pool != nil {
pool, err := srv.resolveInboundMTLSProfileForRoute(hostRoute)
if err != nil {
return nil, err
}
if pool != nil {
return nil, errSecureRouteRequiresSNI
}
return hostRoute, nil
}
sniRoute := srv.FindRoute(serverName)
if pool := srv.resolveInboundMTLSProfileForRoute(sniRoute); pool != nil {
pool, err := srv.resolveInboundMTLSProfileForRoute(sniRoute)
if err != nil {
return nil, err
}
if pool != nil {
if !sameHTTPRoute(hostRoute, sniRoute) {
return nil, errSecureRouteMisdirected
}
return sniRoute, nil
}
if pool := srv.resolveInboundMTLSProfileForRoute(hostRoute); pool != nil {
pool, err = srv.resolveInboundMTLSProfileForRoute(hostRoute)
if err != nil {
return nil, err
}
if pool != nil {
return nil, errSecureRouteMisdirected
}
return hostRoute, nil

View File

@@ -83,16 +83,21 @@ func (srv *httpServer) mutateServerTLSConfig(base *tls.Config) *tls.Config {
if base == nil {
return base
}
if pool := srv.resolveInboundMTLSProfileForRoute(nil); pool != nil {
return applyInboundMTLSProfile(base, pool)
pool, err := srv.resolveInboundMTLSProfileForRoute(nil)
if err != nil {
panic(err)
}
if len(srv.ep.inboundMTLSProfiles) == 0 {
return base
if pool != nil {
return applyInboundMTLSProfile(base, pool)
}
cfg := base.Clone()
cfg.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
if pool := srv.resolveInboundMTLSProfileForServerName(hello.ServerName); pool != nil {
pool, err := srv.resolveInboundMTLSProfileForServerName(hello.ServerName)
if err != nil {
return nil, err
}
if pool != nil {
return applyInboundMTLSProfile(base, pool), nil
}
return cloneTLSConfig(base), nil
@@ -126,29 +131,37 @@ func ValidateInboundMTLSProfileRef(profileRef, globalProfile string, profiles ma
return nil
}
func (srv *httpServer) resolveInboundMTLSProfileForServerName(serverName string) *x509.CertPool {
func (srv *httpServer) resolveInboundMTLSProfileForServerName(serverName string) (*x509.CertPool, error) {
if serverName == "" || srv.ep.inboundMTLSProfiles == nil {
return nil
return nil, nil
}
route := srv.FindRoute(serverName)
if route == nil {
return nil
return nil, nil
}
return srv.resolveInboundMTLSProfileForRoute(route)
}
func (srv *httpServer) resolveInboundMTLSProfileForRoute(route types.HTTPRoute) *x509.CertPool {
func (srv *httpServer) resolveInboundMTLSProfileForRoute(route types.HTTPRoute) (*x509.CertPool, error) {
if srv.ep.inboundMTLSProfiles == nil {
return nil
return nil, nil
}
if globalRef := srv.ep.cfg.InboundMTLSProfile; globalRef != "" {
return srv.ep.inboundMTLSProfiles[globalRef]
return srv.lookupInboundMTLSProfile(globalRef, "entrypoint")
}
if route == nil {
return nil
return nil, nil
}
if ref := route.InboundMTLSProfileRef(); ref != "" {
return srv.ep.inboundMTLSProfiles[ref]
return srv.lookupInboundMTLSProfile(ref, fmt.Sprintf("route %q", route.Name()))
}
return nil
return nil, nil
}
func (srv *httpServer) lookupInboundMTLSProfile(ref, owner string) (*x509.CertPool, error) {
pool, ok := srv.ep.inboundMTLSProfiles[ref]
if !ok {
return nil, fmt.Errorf("%s inbound mTLS profile %q not found", owner, ref)
}
return pool, nil
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -48,18 +49,24 @@ func newFakeHTTPRouteAt(t *testing.T, alias, profile, listenURL string) *fakeHTT
}
}
func (r *fakeHTTPRoute) Key() string { return r.key }
func (r *fakeHTTPRoute) Name() string { return r.name }
func (r *fakeHTTPRoute) Start(task.Parent) error { return nil }
func (r *fakeHTTPRoute) Task() *task.Task { return r.task }
func (r *fakeHTTPRoute) Finish(any) {}
func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {}
func (r *fakeHTTPRoute) ProviderName() string { return "" }
func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }
func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL }
func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }
func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }
func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {}
func (r *fakeHTTPRoute) Key() string { return r.key }
func (r *fakeHTTPRoute) Name() string { return r.name }
func (r *fakeHTTPRoute) Start(task.Parent) error { return nil }
func (r *fakeHTTPRoute) Task() *task.Task { return r.task }
func (r *fakeHTTPRoute) Finish(any) {
// no-op: test stub
}
func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {
// no-op: test stub
}
func (r *fakeHTTPRoute) ProviderName() string { return "" }
func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }
func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL }
func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }
func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }
func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {
// no-op: test stub
}
func (r *fakeHTTPRoute) References() []string { return nil }
func (r *fakeHTTPRoute) ShouldExclude() bool { return false }
func (r *fakeHTTPRoute) Started() <-chan struct{} { return nil }
@@ -169,6 +176,36 @@ func TestSetInboundMTLSProfilesRejectsBadCAFile(t *testing.T) {
require.ErrorContains(t, err, "missing.pem")
}
func TestMutateServerTLSConfigRejectsUnknownRouteProfile(t *testing.T) {
ep := NewTestEntrypoint(t, nil)
ep.SetFindRouteDomains([]string{".example.com"})
srv := newTestHTTPServer(t, ep)
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))
base := &tls.Config{MinVersion: tls.VersionTLS12}
mutated := srv.mutateServerTLSConfig(base)
_, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
require.Error(t, err)
require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
}
func TestResolveRequestRouteRejectsUnknownRouteProfile(t *testing.T) {
ep := NewTestEntrypoint(t, nil)
ep.SetFindRouteDomains([]string{".example.com"})
srv := newTestHTTPServer(t, ep)
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))
req := httptest.NewRequest(http.MethodGet, "https://secure-app.example.com", nil)
req.Host = "secure-app.example.com"
req.TLS = &tls.ConnectionState{ServerName: "secure-app.example.com"}
route, err := srv.resolveRequestRoute(req)
require.Nil(t, route)
require.Error(t, err)
require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
}
func TestInboundMTLSGlobalHandshake(t *testing.T) {
ca, srv, client, err := agentcert.NewAgent()
require.NoError(t, err)
@@ -182,13 +219,18 @@ func TestInboundMTLSGlobalHandshake(t *testing.T) {
provider := &staticCertProvider{cert: serverCert}
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
t.Cleanup(func() {
closeTestServers(t, ep)
})
autocert.SetCtx(task.GetTestTask(t), provider)
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
"global": {CAFiles: []string{caPath}},
}))
listenAddr := reserveTCPAddr(t)
addHTTPRouteAt(t, ep, "app1", "", listenAddr)
listener, releaseListener := reserveTCPAddr(t)
listenAddr := listener.Addr().String()
addHTTPRouteAt(t, ep, "app1", "", listenAddr, listener)
releaseListener()
t.Run("trusted client succeeds", func(t *testing.T) {
resp, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
@@ -219,8 +261,6 @@ func TestInboundMTLSGlobalHandshake(t *testing.T) {
})
require.Error(t, err)
})
closeTestServers(t, ep)
}
func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
@@ -236,15 +276,20 @@ func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
provider := &staticCertProvider{cert: serverCert}
ep := NewTestEntrypoint(t, nil)
t.Cleanup(func() {
closeTestServers(t, ep)
})
ep.SetFindRouteDomains([]string{".example.com"})
autocert.SetCtx(task.GetTestTask(t), provider)
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
"route": {CAFiles: []string{caPath}},
}))
listenAddr := reserveTCPAddr(t)
addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr)
addHTTPRouteAt(t, ep, "open-app", "", listenAddr)
listener, releaseListener := reserveTCPAddr(t)
listenAddr := listener.Addr().String()
addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr, listener)
releaseListener()
addHTTPRouteAt(t, ep, "open-app", "", listenAddr, nil)
t.Run("secure route requires client cert when sni matches", func(t *testing.T) {
_, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
@@ -302,14 +347,17 @@ func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
defer func() { _ = resp.Body.Close() }()
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
})
closeTestServers(t, ep)
}
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string) {
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string, listener net.Listener) {
t.Helper()
require.NoError(t, ep.StartAddRoute(newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)))
route := newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)
if listener == nil {
require.NoError(t, ep.StartAddRoute(route))
return
}
require.NoError(t, ep.addHTTPRouteWithListener(route, listenAddr, HTTPProtoHTTPS, listener))
}
func closeTestServers(t *testing.T, ep *Entrypoint) {
@@ -319,13 +367,21 @@ func closeTestServers(t *testing.T, ep *Entrypoint) {
}
}
func reserveTCPAddr(t *testing.T) string {
func reserveTCPAddr(t *testing.T) (net.Listener, func()) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := ln.Addr().String()
require.NoError(t, ln.Close())
return addr
owned := true
t.Cleanup(func() {
if owned {
_ = ln.Close()
}
})
return ln, func() {
owned = false
}
}
func writeTempFile(t *testing.T, name string, data []byte) string {
@@ -395,7 +451,9 @@ func (p *staticCertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, er
return p.cert, nil
}
func (p *staticCertProvider) GetCertInfos() ([]autocert.CertInfo, error) { return nil, nil }
func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {}
func (p *staticCertProvider) ObtainCertAll() error { return nil }
func (p *staticCertProvider) ForceExpiryAll() bool { return false }
func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }
func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {
// no-op: test stub
}
func (p *staticCertProvider) ObtainCertAll() error { return nil }
func (p *staticCertProvider) ForceExpiryAll() bool { return false }
func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }

View File

@@ -113,10 +113,14 @@ func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
}
func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error {
return ep.addHTTPRouteWithListener(route, addr, proto, nil)
}
func (ep *Entrypoint) addHTTPRouteWithListener(route types.HTTPRoute, addr string, proto HTTPProto, listener net.Listener) error {
var err error
srv, _ := ep.servers.LoadOrCompute(addr, func() (newSrv *httpServer, cancel bool) {
newSrv = newHTTPServer(ep)
err = newSrv.Listen(addr, proto)
err = newSrv.listen(addr, proto, listener)
cancel = err != nil
return
})