mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-13 20:19:41 +02:00
fix(entrypoint): reject missing inbound mTLS profile references
Add lookupInboundMTLSProfile so global and route-scoped refs must exist in the loaded profile map. Propagate resolver errors through TLS GetConfigForClient; in HTTP dispatch, return 421 only for SNI and misdirected secure-route cases and log 500 for other resolution failures. Support adding routes with an existing listener for tests, reserve the port via net.Listen without a race, and use t.Cleanup for server teardown. Move relay_proxy_protocol_header documentation to per-route TCP config in config.example.yml.
This commit is contained in:
@@ -160,6 +160,11 @@ providers:
|
||||
# secret: aaaa-bbbb-cccc-dddd
|
||||
# no_tls_verify: true
|
||||
|
||||
# To relay the downstream client address to a TCP upstream, set
|
||||
# `relay_proxy_protocol_header: true` on that specific TCP route in route
|
||||
# configuration (for example, see providers.example.yml). UDP relay is not
|
||||
# supported yet.
|
||||
|
||||
# Match domains
|
||||
# See https://docs.godoxy.dev/Certificates-and-domain-matching
|
||||
#
|
||||
|
||||
@@ -3,6 +3,7 @@ package entrypoint
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -54,6 +55,10 @@ func newHTTPServer(ep *Entrypoint) *httpServer {
|
||||
|
||||
// Listen starts the server and stop when entrypoint is stopped.
|
||||
func (srv *httpServer) Listen(addr string, proto HTTPProto) error {
|
||||
return srv.listen(addr, proto, nil)
|
||||
}
|
||||
|
||||
func (srv *httpServer) listen(addr string, proto HTTPProto, listener net.Listener) error {
|
||||
if srv.addr != "" {
|
||||
return errors.New("server already started")
|
||||
}
|
||||
@@ -68,8 +73,10 @@ func (srv *httpServer) Listen(addr string, proto HTTPProto) error {
|
||||
switch proto {
|
||||
case HTTPProtoHTTP:
|
||||
opts.HTTPAddr = addr
|
||||
opts.HTTPListener = listener
|
||||
case HTTPProtoHTTPS:
|
||||
opts.HTTPSAddr = addr
|
||||
opts.HTTPSListener = listener
|
||||
opts.CertProvider = autocert.FromCtx(srv.ep.task.Context())
|
||||
opts.TLSConfigMutator = srv.mutateServerTLSConfig
|
||||
}
|
||||
@@ -119,9 +126,13 @@ func (srv *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
route, err := srv.resolveRequestRoute(r)
|
||||
switch {
|
||||
case err != nil:
|
||||
case errors.Is(err, errSecureRouteRequiresSNI), errors.Is(err, errSecureRouteMisdirected):
|
||||
http.Error(w, err.Error(), http.StatusMisdirectedRequest)
|
||||
return
|
||||
case err != nil:
|
||||
log.Err(err).Msg("failed to resolve HTTP route")
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
case route != nil:
|
||||
r = routes.WithRouteContext(r, route)
|
||||
if srv.ep.middleware != nil {
|
||||
@@ -145,27 +156,39 @@ var (
|
||||
|
||||
func (srv *httpServer) resolveRequestRoute(req *http.Request) (types.HTTPRoute, error) {
|
||||
hostRoute := srv.FindRoute(req.Host)
|
||||
if req.TLS == nil || srv.ep.cfg.InboundMTLSProfile != "" || len(srv.ep.inboundMTLSProfiles) == 0 {
|
||||
if req.TLS == nil || srv.ep.cfg.InboundMTLSProfile != "" {
|
||||
return hostRoute, nil
|
||||
}
|
||||
|
||||
serverName := req.TLS.ServerName
|
||||
if serverName == "" {
|
||||
if pool := srv.resolveInboundMTLSProfileForRoute(hostRoute); pool != nil {
|
||||
pool, err := srv.resolveInboundMTLSProfileForRoute(hostRoute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pool != nil {
|
||||
return nil, errSecureRouteRequiresSNI
|
||||
}
|
||||
return hostRoute, nil
|
||||
}
|
||||
|
||||
sniRoute := srv.FindRoute(serverName)
|
||||
if pool := srv.resolveInboundMTLSProfileForRoute(sniRoute); pool != nil {
|
||||
pool, err := srv.resolveInboundMTLSProfileForRoute(sniRoute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pool != nil {
|
||||
if !sameHTTPRoute(hostRoute, sniRoute) {
|
||||
return nil, errSecureRouteMisdirected
|
||||
}
|
||||
return sniRoute, nil
|
||||
}
|
||||
|
||||
if pool := srv.resolveInboundMTLSProfileForRoute(hostRoute); pool != nil {
|
||||
pool, err = srv.resolveInboundMTLSProfileForRoute(hostRoute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pool != nil {
|
||||
return nil, errSecureRouteMisdirected
|
||||
}
|
||||
return hostRoute, nil
|
||||
|
||||
@@ -83,16 +83,21 @@ func (srv *httpServer) mutateServerTLSConfig(base *tls.Config) *tls.Config {
|
||||
if base == nil {
|
||||
return base
|
||||
}
|
||||
if pool := srv.resolveInboundMTLSProfileForRoute(nil); pool != nil {
|
||||
return applyInboundMTLSProfile(base, pool)
|
||||
pool, err := srv.resolveInboundMTLSProfileForRoute(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(srv.ep.inboundMTLSProfiles) == 0 {
|
||||
return base
|
||||
if pool != nil {
|
||||
return applyInboundMTLSProfile(base, pool)
|
||||
}
|
||||
|
||||
cfg := base.Clone()
|
||||
cfg.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
if pool := srv.resolveInboundMTLSProfileForServerName(hello.ServerName); pool != nil {
|
||||
pool, err := srv.resolveInboundMTLSProfileForServerName(hello.ServerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pool != nil {
|
||||
return applyInboundMTLSProfile(base, pool), nil
|
||||
}
|
||||
return cloneTLSConfig(base), nil
|
||||
@@ -126,29 +131,37 @@ func ValidateInboundMTLSProfileRef(profileRef, globalProfile string, profiles ma
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *httpServer) resolveInboundMTLSProfileForServerName(serverName string) *x509.CertPool {
|
||||
func (srv *httpServer) resolveInboundMTLSProfileForServerName(serverName string) (*x509.CertPool, error) {
|
||||
if serverName == "" || srv.ep.inboundMTLSProfiles == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
route := srv.FindRoute(serverName)
|
||||
if route == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
return srv.resolveInboundMTLSProfileForRoute(route)
|
||||
}
|
||||
|
||||
func (srv *httpServer) resolveInboundMTLSProfileForRoute(route types.HTTPRoute) *x509.CertPool {
|
||||
func (srv *httpServer) resolveInboundMTLSProfileForRoute(route types.HTTPRoute) (*x509.CertPool, error) {
|
||||
if srv.ep.inboundMTLSProfiles == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
if globalRef := srv.ep.cfg.InboundMTLSProfile; globalRef != "" {
|
||||
return srv.ep.inboundMTLSProfiles[globalRef]
|
||||
return srv.lookupInboundMTLSProfile(globalRef, "entrypoint")
|
||||
}
|
||||
if route == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
if ref := route.InboundMTLSProfileRef(); ref != "" {
|
||||
return srv.ep.inboundMTLSProfiles[ref]
|
||||
return srv.lookupInboundMTLSProfile(ref, fmt.Sprintf("route %q", route.Name()))
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (srv *httpServer) lookupInboundMTLSProfile(ref, owner string) (*x509.CertPool, error) {
|
||||
pool, ok := srv.ep.inboundMTLSProfiles[ref]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s inbound mTLS profile %q not found", owner, ref)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -48,18 +49,24 @@ func newFakeHTTPRouteAt(t *testing.T, alias, profile, listenURL string) *fakeHTT
|
||||
}
|
||||
}
|
||||
|
||||
func (r *fakeHTTPRoute) Key() string { return r.key }
|
||||
func (r *fakeHTTPRoute) Name() string { return r.name }
|
||||
func (r *fakeHTTPRoute) Start(task.Parent) error { return nil }
|
||||
func (r *fakeHTTPRoute) Task() *task.Task { return r.task }
|
||||
func (r *fakeHTTPRoute) Finish(any) {}
|
||||
func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {}
|
||||
func (r *fakeHTTPRoute) ProviderName() string { return "" }
|
||||
func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }
|
||||
func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL }
|
||||
func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }
|
||||
func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }
|
||||
func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {}
|
||||
func (r *fakeHTTPRoute) Key() string { return r.key }
|
||||
func (r *fakeHTTPRoute) Name() string { return r.name }
|
||||
func (r *fakeHTTPRoute) Start(task.Parent) error { return nil }
|
||||
func (r *fakeHTTPRoute) Task() *task.Task { return r.task }
|
||||
func (r *fakeHTTPRoute) Finish(any) {
|
||||
// no-op: test stub
|
||||
}
|
||||
func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {
|
||||
// no-op: test stub
|
||||
}
|
||||
func (r *fakeHTTPRoute) ProviderName() string { return "" }
|
||||
func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil }
|
||||
func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL }
|
||||
func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil }
|
||||
func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil }
|
||||
func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {
|
||||
// no-op: test stub
|
||||
}
|
||||
func (r *fakeHTTPRoute) References() []string { return nil }
|
||||
func (r *fakeHTTPRoute) ShouldExclude() bool { return false }
|
||||
func (r *fakeHTTPRoute) Started() <-chan struct{} { return nil }
|
||||
@@ -169,6 +176,36 @@ func TestSetInboundMTLSProfilesRejectsBadCAFile(t *testing.T) {
|
||||
require.ErrorContains(t, err, "missing.pem")
|
||||
}
|
||||
|
||||
func TestMutateServerTLSConfigRejectsUnknownRouteProfile(t *testing.T) {
|
||||
ep := NewTestEntrypoint(t, nil)
|
||||
ep.SetFindRouteDomains([]string{".example.com"})
|
||||
srv := newTestHTTPServer(t, ep)
|
||||
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))
|
||||
|
||||
base := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
mutated := srv.mutateServerTLSConfig(base)
|
||||
|
||||
_, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"})
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
|
||||
}
|
||||
|
||||
func TestResolveRequestRouteRejectsUnknownRouteProfile(t *testing.T) {
|
||||
ep := NewTestEntrypoint(t, nil)
|
||||
ep.SetFindRouteDomains([]string{".example.com"})
|
||||
srv := newTestHTTPServer(t, ep)
|
||||
srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://secure-app.example.com", nil)
|
||||
req.Host = "secure-app.example.com"
|
||||
req.TLS = &tls.ConnectionState{ServerName: "secure-app.example.com"}
|
||||
|
||||
route, err := srv.resolveRequestRoute(req)
|
||||
require.Nil(t, route)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`)
|
||||
}
|
||||
|
||||
func TestInboundMTLSGlobalHandshake(t *testing.T) {
|
||||
ca, srv, client, err := agentcert.NewAgent()
|
||||
require.NoError(t, err)
|
||||
@@ -182,13 +219,18 @@ func TestInboundMTLSGlobalHandshake(t *testing.T) {
|
||||
provider := &staticCertProvider{cert: serverCert}
|
||||
|
||||
ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"})
|
||||
t.Cleanup(func() {
|
||||
closeTestServers(t, ep)
|
||||
})
|
||||
autocert.SetCtx(task.GetTestTask(t), provider)
|
||||
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
||||
"global": {CAFiles: []string{caPath}},
|
||||
}))
|
||||
|
||||
listenAddr := reserveTCPAddr(t)
|
||||
addHTTPRouteAt(t, ep, "app1", "", listenAddr)
|
||||
listener, releaseListener := reserveTCPAddr(t)
|
||||
listenAddr := listener.Addr().String()
|
||||
addHTTPRouteAt(t, ep, "app1", "", listenAddr, listener)
|
||||
releaseListener()
|
||||
|
||||
t.Run("trusted client succeeds", func(t *testing.T) {
|
||||
resp, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{
|
||||
@@ -219,8 +261,6 @@ func TestInboundMTLSGlobalHandshake(t *testing.T) {
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
closeTestServers(t, ep)
|
||||
}
|
||||
|
||||
func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
|
||||
@@ -236,15 +276,20 @@ func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
|
||||
provider := &staticCertProvider{cert: serverCert}
|
||||
|
||||
ep := NewTestEntrypoint(t, nil)
|
||||
t.Cleanup(func() {
|
||||
closeTestServers(t, ep)
|
||||
})
|
||||
ep.SetFindRouteDomains([]string{".example.com"})
|
||||
autocert.SetCtx(task.GetTestTask(t), provider)
|
||||
require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{
|
||||
"route": {CAFiles: []string{caPath}},
|
||||
}))
|
||||
|
||||
listenAddr := reserveTCPAddr(t)
|
||||
addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr)
|
||||
addHTTPRouteAt(t, ep, "open-app", "", listenAddr)
|
||||
listener, releaseListener := reserveTCPAddr(t)
|
||||
listenAddr := listener.Addr().String()
|
||||
addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr, listener)
|
||||
releaseListener()
|
||||
addHTTPRouteAt(t, ep, "open-app", "", listenAddr, nil)
|
||||
|
||||
t.Run("secure route requires client cert when sni matches", func(t *testing.T) {
|
||||
_, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{
|
||||
@@ -302,14 +347,17 @@ func TestInboundMTLSRouteScopedHandshake(t *testing.T) {
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
closeTestServers(t, ep)
|
||||
}
|
||||
|
||||
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string) {
|
||||
func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string, listener net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
require.NoError(t, ep.StartAddRoute(newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)))
|
||||
route := newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr)
|
||||
if listener == nil {
|
||||
require.NoError(t, ep.StartAddRoute(route))
|
||||
return
|
||||
}
|
||||
require.NoError(t, ep.addHTTPRouteWithListener(route, listenAddr, HTTPProtoHTTPS, listener))
|
||||
}
|
||||
|
||||
func closeTestServers(t *testing.T, ep *Entrypoint) {
|
||||
@@ -319,13 +367,21 @@ func closeTestServers(t *testing.T, ep *Entrypoint) {
|
||||
}
|
||||
}
|
||||
|
||||
func reserveTCPAddr(t *testing.T) string {
|
||||
func reserveTCPAddr(t *testing.T) (net.Listener, func()) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := ln.Addr().String()
|
||||
require.NoError(t, ln.Close())
|
||||
return addr
|
||||
|
||||
owned := true
|
||||
t.Cleanup(func() {
|
||||
if owned {
|
||||
_ = ln.Close()
|
||||
}
|
||||
})
|
||||
return ln, func() {
|
||||
owned = false
|
||||
}
|
||||
}
|
||||
|
||||
func writeTempFile(t *testing.T, name string, data []byte) string {
|
||||
@@ -395,7 +451,9 @@ func (p *staticCertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, er
|
||||
return p.cert, nil
|
||||
}
|
||||
func (p *staticCertProvider) GetCertInfos() ([]autocert.CertInfo, error) { return nil, nil }
|
||||
func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {}
|
||||
func (p *staticCertProvider) ObtainCertAll() error { return nil }
|
||||
func (p *staticCertProvider) ForceExpiryAll() bool { return false }
|
||||
func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }
|
||||
func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {
|
||||
// no-op: test stub
|
||||
}
|
||||
func (p *staticCertProvider) ObtainCertAll() error { return nil }
|
||||
func (p *staticCertProvider) ForceExpiryAll() bool { return false }
|
||||
func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true }
|
||||
|
||||
@@ -113,10 +113,14 @@ func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error {
|
||||
}
|
||||
|
||||
func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error {
|
||||
return ep.addHTTPRouteWithListener(route, addr, proto, nil)
|
||||
}
|
||||
|
||||
func (ep *Entrypoint) addHTTPRouteWithListener(route types.HTTPRoute, addr string, proto HTTPProto, listener net.Listener) error {
|
||||
var err error
|
||||
srv, _ := ep.servers.LoadOrCompute(addr, func() (newSrv *httpServer, cancel bool) {
|
||||
newSrv = newHTTPServer(ep)
|
||||
err = newSrv.Listen(addr, proto)
|
||||
err = newSrv.listen(addr, proto, listener)
|
||||
cancel = err != nil
|
||||
return
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user