diff --git a/config.example.yml b/config.example.yml index bcebc234..8e7047c6 100644 --- a/config.example.yml +++ b/config.example.yml @@ -160,6 +160,11 @@ providers: # secret: aaaa-bbbb-cccc-dddd # no_tls_verify: true +# To relay the downstream client address to a TCP upstream, set +# `relay_proxy_protocol_header: true` on that specific TCP route in route +# configuration (for example, see providers.example.yml). UDP relay is not +# supported yet. + # Match domains # See https://docs.godoxy.dev/Certificates-and-domain-matching # diff --git a/internal/entrypoint/http_server.go b/internal/entrypoint/http_server.go index 8e176583..c22f32c5 100644 --- a/internal/entrypoint/http_server.go +++ b/internal/entrypoint/http_server.go @@ -3,6 +3,7 @@ package entrypoint import ( "errors" "fmt" + "net" "net/http" "strings" @@ -54,6 +55,10 @@ func newHTTPServer(ep *Entrypoint) *httpServer { // Listen starts the server and stop when entrypoint is stopped. func (srv *httpServer) Listen(addr string, proto HTTPProto) error { + return srv.listen(addr, proto, nil) +} + +func (srv *httpServer) listen(addr string, proto HTTPProto, listener net.Listener) error { if srv.addr != "" { return errors.New("server already started") } @@ -68,8 +73,10 @@ func (srv *httpServer) Listen(addr string, proto HTTPProto) error { switch proto { case HTTPProtoHTTP: opts.HTTPAddr = addr + opts.HTTPListener = listener case HTTPProtoHTTPS: opts.HTTPSAddr = addr + opts.HTTPSListener = listener opts.CertProvider = autocert.FromCtx(srv.ep.task.Context()) opts.TLSConfigMutator = srv.mutateServerTLSConfig } @@ -119,9 +126,13 @@ func (srv *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { route, err := srv.resolveRequestRoute(r) switch { - case err != nil: + case errors.Is(err, errSecureRouteRequiresSNI), errors.Is(err, errSecureRouteMisdirected): http.Error(w, err.Error(), http.StatusMisdirectedRequest) return + case err != nil: + log.Err(err).Msg("failed to resolve HTTP route") + http.Error(w, "internal server error", http.StatusInternalServerError) + return case route != nil: r = routes.WithRouteContext(r, route) if srv.ep.middleware != nil { @@ -145,27 +156,39 @@ var ( func (srv *httpServer) resolveRequestRoute(req *http.Request) (types.HTTPRoute, error) { hostRoute := srv.FindRoute(req.Host) - if req.TLS == nil || srv.ep.cfg.InboundMTLSProfile != "" || len(srv.ep.inboundMTLSProfiles) == 0 { + if req.TLS == nil || srv.ep.cfg.InboundMTLSProfile != "" { return hostRoute, nil } serverName := req.TLS.ServerName if serverName == "" { - if pool := srv.resolveInboundMTLSProfileForRoute(hostRoute); pool != nil { + pool, err := srv.resolveInboundMTLSProfileForRoute(hostRoute) + if err != nil { + return nil, err + } + if pool != nil { return nil, errSecureRouteRequiresSNI } return hostRoute, nil } sniRoute := srv.FindRoute(serverName) - if pool := srv.resolveInboundMTLSProfileForRoute(sniRoute); pool != nil { + pool, err := srv.resolveInboundMTLSProfileForRoute(sniRoute) + if err != nil { + return nil, err + } + if pool != nil { if !sameHTTPRoute(hostRoute, sniRoute) { return nil, errSecureRouteMisdirected } return sniRoute, nil } - if pool := srv.resolveInboundMTLSProfileForRoute(hostRoute); pool != nil { + pool, err = srv.resolveInboundMTLSProfileForRoute(hostRoute) + if err != nil { + return nil, err + } + if pool != nil { return nil, errSecureRouteMisdirected } return hostRoute, nil diff --git a/internal/entrypoint/inbound_mtls.go b/internal/entrypoint/inbound_mtls.go index 68130341..40348f5f 100644 --- a/internal/entrypoint/inbound_mtls.go +++ b/internal/entrypoint/inbound_mtls.go @@ -83,16 +83,21 @@ func (srv *httpServer) mutateServerTLSConfig(base *tls.Config) *tls.Config { if base == nil { return base } - if pool := srv.resolveInboundMTLSProfileForRoute(nil); pool != nil { - return applyInboundMTLSProfile(base, pool) + pool, err := srv.resolveInboundMTLSProfileForRoute(nil) + if err != nil { + panic(err) } - if len(srv.ep.inboundMTLSProfiles) == 0 { - return base + if pool != nil { + return applyInboundMTLSProfile(base, pool) } cfg := base.Clone() cfg.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { - if pool := srv.resolveInboundMTLSProfileForServerName(hello.ServerName); pool != nil { + pool, err := srv.resolveInboundMTLSProfileForServerName(hello.ServerName) + if err != nil { + return nil, err + } + if pool != nil { return applyInboundMTLSProfile(base, pool), nil } return cloneTLSConfig(base), nil @@ -126,29 +131,37 @@ func ValidateInboundMTLSProfileRef(profileRef, globalProfile string, profiles ma return nil } -func (srv *httpServer) resolveInboundMTLSProfileForServerName(serverName string) *x509.CertPool { +func (srv *httpServer) resolveInboundMTLSProfileForServerName(serverName string) (*x509.CertPool, error) { if serverName == "" || srv.ep.inboundMTLSProfiles == nil { - return nil + return nil, nil } route := srv.FindRoute(serverName) if route == nil { - return nil + return nil, nil } return srv.resolveInboundMTLSProfileForRoute(route) } -func (srv *httpServer) resolveInboundMTLSProfileForRoute(route types.HTTPRoute) *x509.CertPool { +func (srv *httpServer) resolveInboundMTLSProfileForRoute(route types.HTTPRoute) (*x509.CertPool, error) { if srv.ep.inboundMTLSProfiles == nil { - return nil + return nil, nil } if globalRef := srv.ep.cfg.InboundMTLSProfile; globalRef != "" { - return srv.ep.inboundMTLSProfiles[globalRef] + return srv.lookupInboundMTLSProfile(globalRef, "entrypoint") } if route == nil { - return nil + return nil, nil } if ref := route.InboundMTLSProfileRef(); ref != "" { - return srv.ep.inboundMTLSProfiles[ref] + return srv.lookupInboundMTLSProfile(ref, fmt.Sprintf("route %q", route.Name())) } - return nil + return nil, nil +} + +func (srv *httpServer) lookupInboundMTLSProfile(ref, owner string) (*x509.CertPool, error) { + pool, ok := srv.ep.inboundMTLSProfiles[ref] + if !ok { + return nil, fmt.Errorf("%s inbound mTLS profile %q not found", owner, ref) + } + return pool, nil } diff --git a/internal/entrypoint/inbound_mtls_test.go b/internal/entrypoint/inbound_mtls_test.go index 4424ef05..18c9fdfd 100644 --- a/internal/entrypoint/inbound_mtls_test.go +++ b/internal/entrypoint/inbound_mtls_test.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "net/http/httptest" "os" "path/filepath" "testing" @@ -48,18 +49,24 @@ func newFakeHTTPRouteAt(t *testing.T, alias, profile, listenURL string) *fakeHTT } } -func (r *fakeHTTPRoute) Key() string { return r.key } -func (r *fakeHTTPRoute) Name() string { return r.name } -func (r *fakeHTTPRoute) Start(task.Parent) error { return nil } -func (r *fakeHTTPRoute) Task() *task.Task { return r.task } -func (r *fakeHTTPRoute) Finish(any) {} -func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) {} -func (r *fakeHTTPRoute) ProviderName() string { return "" } -func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil } -func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL } -func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil } -func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil } -func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) {} +func (r *fakeHTTPRoute) Key() string { return r.key } +func (r *fakeHTTPRoute) Name() string { return r.name } +func (r *fakeHTTPRoute) Start(task.Parent) error { return nil } +func (r *fakeHTTPRoute) Task() *task.Task { return r.task } +func (r *fakeHTTPRoute) Finish(any) { + // no-op: test stub +} +func (r *fakeHTTPRoute) MarshalZerologObject(*zerolog.Event) { + // no-op: test stub +} +func (r *fakeHTTPRoute) ProviderName() string { return "" } +func (r *fakeHTTPRoute) GetProvider() types.RouteProvider { return nil } +func (r *fakeHTTPRoute) ListenURL() *nettypes.URL { return r.listenURL } +func (r *fakeHTTPRoute) TargetURL() *nettypes.URL { return nil } +func (r *fakeHTTPRoute) HealthMonitor() types.HealthMonitor { return nil } +func (r *fakeHTTPRoute) SetHealthMonitor(types.HealthMonitor) { + // no-op: test stub +} func (r *fakeHTTPRoute) References() []string { return nil } func (r *fakeHTTPRoute) ShouldExclude() bool { return false } func (r *fakeHTTPRoute) Started() <-chan struct{} { return nil } @@ -169,6 +176,36 @@ func TestSetInboundMTLSProfilesRejectsBadCAFile(t *testing.T) { require.ErrorContains(t, err, "missing.pem") } +func TestMutateServerTLSConfigRejectsUnknownRouteProfile(t *testing.T) { + ep := NewTestEntrypoint(t, nil) + ep.SetFindRouteDomains([]string{".example.com"}) + srv := newTestHTTPServer(t, ep) + srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing")) + + base := &tls.Config{MinVersion: tls.VersionTLS12} + mutated := srv.mutateServerTLSConfig(base) + + _, err := mutated.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "secure-app.example.com"}) + require.Error(t, err) + require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`) +} + +func TestResolveRequestRouteRejectsUnknownRouteProfile(t *testing.T) { + ep := NewTestEntrypoint(t, nil) + ep.SetFindRouteDomains([]string{".example.com"}) + srv := newTestHTTPServer(t, ep) + srv.AddRoute(newFakeHTTPRoute(t, "secure-app", "missing")) + + req := httptest.NewRequest(http.MethodGet, "https://secure-app.example.com", nil) + req.Host = "secure-app.example.com" + req.TLS = &tls.ConnectionState{ServerName: "secure-app.example.com"} + + route, err := srv.resolveRequestRoute(req) + require.Nil(t, route) + require.Error(t, err) + require.ErrorContains(t, err, `route "secure-app" inbound mTLS profile "missing" not found`) +} + func TestInboundMTLSGlobalHandshake(t *testing.T) { ca, srv, client, err := agentcert.NewAgent() require.NoError(t, err) @@ -182,13 +219,18 @@ func TestInboundMTLSGlobalHandshake(t *testing.T) { provider := &staticCertProvider{cert: serverCert} ep := NewTestEntrypoint(t, &Config{InboundMTLSProfile: "global"}) + t.Cleanup(func() { + closeTestServers(t, ep) + }) autocert.SetCtx(task.GetTestTask(t), provider) require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{ "global": {CAFiles: []string{caPath}}, })) - listenAddr := reserveTCPAddr(t) - addHTTPRouteAt(t, ep, "app1", "", listenAddr) + listener, releaseListener := reserveTCPAddr(t) + listenAddr := listener.Addr().String() + addHTTPRouteAt(t, ep, "app1", "", listenAddr, listener) + releaseListener() t.Run("trusted client succeeds", func(t *testing.T) { resp, err := doHTTPSRequest(listenAddr, "app1.example.com", &tls.Config{ @@ -219,8 +261,6 @@ func TestInboundMTLSGlobalHandshake(t *testing.T) { }) require.Error(t, err) }) - - closeTestServers(t, ep) } func TestInboundMTLSRouteScopedHandshake(t *testing.T) { @@ -236,15 +276,20 @@ func TestInboundMTLSRouteScopedHandshake(t *testing.T) { provider := &staticCertProvider{cert: serverCert} ep := NewTestEntrypoint(t, nil) + t.Cleanup(func() { + closeTestServers(t, ep) + }) ep.SetFindRouteDomains([]string{".example.com"}) autocert.SetCtx(task.GetTestTask(t), provider) require.NoError(t, ep.SetInboundMTLSProfiles(map[string]types.InboundMTLSProfile{ "route": {CAFiles: []string{caPath}}, })) - listenAddr := reserveTCPAddr(t) - addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr) - addHTTPRouteAt(t, ep, "open-app", "", listenAddr) + listener, releaseListener := reserveTCPAddr(t) + listenAddr := listener.Addr().String() + addHTTPRouteAt(t, ep, "secure-app", "route", listenAddr, listener) + releaseListener() + addHTTPRouteAt(t, ep, "open-app", "", listenAddr, nil) t.Run("secure route requires client cert when sni matches", func(t *testing.T) { _, err := doHTTPSRequest(listenAddr, "secure-app.example.com", &tls.Config{ @@ -302,14 +347,17 @@ func TestInboundMTLSRouteScopedHandshake(t *testing.T) { defer func() { _ = resp.Body.Close() }() require.Equal(t, http.StatusMisdirectedRequest, resp.StatusCode) }) - - closeTestServers(t, ep) } -func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string) { +func addHTTPRouteAt(t *testing.T, ep *Entrypoint, alias, profile, listenAddr string, listener net.Listener) { t.Helper() - require.NoError(t, ep.StartAddRoute(newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr))) + route := newFakeHTTPRouteAt(t, alias, profile, "https://"+listenAddr) + if listener == nil { + require.NoError(t, ep.StartAddRoute(route)) + return + } + require.NoError(t, ep.addHTTPRouteWithListener(route, listenAddr, HTTPProtoHTTPS, listener)) } func closeTestServers(t *testing.T, ep *Entrypoint) { @@ -319,13 +367,21 @@ func closeTestServers(t *testing.T, ep *Entrypoint) { } } -func reserveTCPAddr(t *testing.T) string { +func reserveTCPAddr(t *testing.T) (net.Listener, func()) { t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - addr := ln.Addr().String() - require.NoError(t, ln.Close()) - return addr + + owned := true + t.Cleanup(func() { + if owned { + _ = ln.Close() + } + }) + return ln, func() { + owned = false + } } func writeTempFile(t *testing.T, name string, data []byte) string { @@ -395,7 +451,9 @@ func (p *staticCertProvider) GetCert(*tls.ClientHelloInfo) (*tls.Certificate, er return p.cert, nil } func (p *staticCertProvider) GetCertInfos() ([]autocert.CertInfo, error) { return nil, nil } -func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) {} -func (p *staticCertProvider) ObtainCertAll() error { return nil } -func (p *staticCertProvider) ForceExpiryAll() bool { return false } -func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true } +func (p *staticCertProvider) ScheduleRenewalAll(task.Parent) { + // no-op: test stub +} +func (p *staticCertProvider) ObtainCertAll() error { return nil } +func (p *staticCertProvider) ForceExpiryAll() bool { return false } +func (p *staticCertProvider) WaitRenewalDone(context.Context) bool { return true } diff --git a/internal/entrypoint/routes.go b/internal/entrypoint/routes.go index 53219e88..9346c3f8 100644 --- a/internal/entrypoint/routes.go +++ b/internal/entrypoint/routes.go @@ -113,10 +113,14 @@ func (ep *Entrypoint) AddHTTPRoute(route types.HTTPRoute) error { } func (ep *Entrypoint) addHTTPRoute(route types.HTTPRoute, addr string, proto HTTPProto) error { + return ep.addHTTPRouteWithListener(route, addr, proto, nil) +} + +func (ep *Entrypoint) addHTTPRouteWithListener(route types.HTTPRoute, addr string, proto HTTPProto, listener net.Listener) error { var err error srv, _ := ep.servers.LoadOrCompute(addr, func() (newSrv *httpServer, cancel bool) { newSrv = newHTTPServer(ep) - err = newSrv.Listen(addr, proto) + err = newSrv.listen(addr, proto, listener) cancel = err != nil return })