From 0d4f2293ffb8c9f1bcc394b68881e4a015e64783 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 9 Apr 2026 17:27:42 +0000 Subject: [PATCH] state: replace zcache with bounded LRU for auth cache Replace zcache with golang-lru/v2/expirable for both the state auth cache and the OIDC state cache. Add tuning.register_cache_max_entries (default 1024) to cap the number of pending registration entries. Introduce types.RegistrationData to replace caching a full *Node; only the fields the registration callback path reads are retained. Remove the dead HSDatabase.regCache field. Drop zgo.at/zcache/v2 from go.mod. --- cmd/headscale/cli/policy.go | 2 +- config-example.yaml | 6 +- flake.nix | 2 +- go.mod | 2 +- go.sum | 3 - hscontrol/auth.go | 90 ++++++-------- hscontrol/auth_tags_test.go | 8 +- hscontrol/auth_test.go | 18 +-- hscontrol/db/db.go | 16 +-- hscontrol/db/db_test.go | 8 -- hscontrol/db/suite_test.go | 2 - hscontrol/grpcv1.go | 36 +++--- hscontrol/mapper/batcher_test.go | 11 +- hscontrol/oidc.go | 22 ++-- hscontrol/servertest/ephemeral_test.go | 8 +- hscontrol/state/auth_cache_test.go | 64 ++++++++++ hscontrol/state/debug.go | 14 +-- hscontrol/state/state.go | 161 ++++++++++--------------- hscontrol/types/common.go | 59 ++++++--- hscontrol/types/config.go | 14 ++- hscontrol/types/registration.go | 55 +++++++++ 21 files changed, 343 insertions(+), 258 deletions(-) create mode 100644 hscontrol/state/auth_cache_test.go create mode 100644 hscontrol/types/registration.go diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index e5dfba3d..02e91c93 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -28,7 +28,7 @@ func bypassDatabase() (*db.HSDatabase, error) { return nil, fmt.Errorf("loading config: %w", err) } - d, err := db.NewHeadscaleDatabase(cfg, nil) + d, err := db.NewHeadscaleDatabase(cfg) if err != nil { return nil, fmt.Errorf("opening database: %w", err) } diff --git a/config-example.yaml b/config-example.yaml index 751e9517..d3ca2f09 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -445,12 +445,16 @@ taildrop: # When enabled, nodes can send files to other nodes owned by the same user. # Tagged devices and cross-user transfers are not permitted by Tailscale clients. enabled: true - # Advanced performance tuning parameters. # The defaults are carefully chosen and should rarely need adjustment. # Only modify these if you have identified a specific performance issue. # # tuning: +# # Maximum number of pending registration entries in the auth cache. +# # Oldest entries are evicted when the cap is reached. +# # +# # register_cache_max_entries: 1024 +# # # NodeStore write batching configuration. # # The NodeStore batches write operations before rebuilding peer relationships, # # which is computationally expensive. Batching reduces rebuild frequency. diff --git a/flake.nix b/flake.nix index adf702c2..86f8b1e1 100644 --- a/flake.nix +++ b/flake.nix @@ -27,7 +27,7 @@ let pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system}; buildGo = pkgs.buildGo126Module; - vendorHash = "sha256-G+yhItFhlp2XP/Zd9N4nMQf96YMQLuYd069H+Quewtk="; + vendorHash = "sha256-x0xXxa7sjyDwWLq8fO0Z/pbPefctzctK3TAdBea7FtY="; in { headscale = buildGo { diff --git a/go.mod b/go.mod index 2e53b5e5..6754d07c 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/google/go-cmp v0.7.0 github.com/gorilla/mux v1.8.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jagottsicher/termcolor v1.0.2 github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 github.com/ory/dockertest/v3 v3.12.0 @@ -54,7 +55,6 @@ require ( gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.31.1 tailscale.com v1.96.5 - zgo.at/zcache/v2 v2.4.1 zombiezen.com/go/postgrestest v1.0.1 ) diff --git a/go.sum b/go.sum index 0709c257..69b0ba90 100644 --- a/go.sum +++ b/go.sum @@ -252,7 +252,6 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4= github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= @@ -682,7 +681,5 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= tailscale.com v1.96.5 h1:gNkfA/KSZAl6jCH9cj8urq00HRWItDDTtGsyATI89jA= tailscale.com v1.96.5/go.mod h1:/3lnZBYb2UEwnN0MNu2SDXUtT06AGd5k0s+OWx3WmcY= -zgo.at/zcache/v2 v2.4.1 h1:Dfjoi8yI0Uq7NCc4lo2kaQJJmp9Mijo21gef+oJstbY= -zgo.at/zcache/v2 v2.4.1/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ= diff --git a/hscontrol/auth.go b/hscontrol/auth.go index f24cfc0b..34066d0c 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -1,7 +1,6 @@ package hscontrol import ( - "cmp" "context" "errors" "fmt" @@ -302,31 +301,10 @@ func (h *Headscale) reqToNewRegisterResponse( return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) } - // Ensure we have a valid hostname - hostname := util.EnsureHostname( - req.Hostinfo.View(), - machineKey.String(), - req.NodeKey.String(), + authRegReq := types.NewRegisterAuthRequest( + registrationDataFromRequest(req, machineKey), ) - // Ensure we have valid hostinfo - hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) - hostinfo.Hostname = hostname - - nodeToRegister := types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - } - - if !req.Expiry.IsZero() { - nodeToRegister.Expiry = &req.Expiry - } - - authRegReq := types.NewRegisterAuthRequest(nodeToRegister) - log.Info().Msgf("new followup node registration using auth id: %s", newAuthID) h.state.SetAuthCacheEntry(newAuthID, authRegReq) @@ -335,6 +313,36 @@ func (h *Headscale) reqToNewRegisterResponse( }, nil } +// registrationDataFromRequest builds the RegistrationData payload stored +// in the auth cache for a pending registration. The original Hostinfo is +// retained so that consumers (auth callback, observability) see the +// fields the client originally announced; the bounded-LRU cap on the +// cache is what bounds the unauthenticated cache-fill DoS surface. +func registrationDataFromRequest( + req tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) *types.RegistrationData { + hostname := util.EnsureHostname( + req.Hostinfo.View(), + machineKey.String(), + req.NodeKey.String(), + ) + + regData := &types.RegistrationData{ + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostname: hostname, + Hostinfo: req.Hostinfo, + } + + if !req.Expiry.IsZero() { + expiry := req.Expiry + regData.Expiry = &expiry + } + + return regData +} + func (h *Headscale) handleRegisterWithAuthKey( req tailcfg.RegisterRequest, machineKey key.MachinePublic, @@ -408,50 +416,24 @@ func (h *Headscale) handleRegisterInteractive( return nil, fmt.Errorf("generating registration ID: %w", err) } - // Ensure we have a valid hostname - hostname := util.EnsureHostname( - req.Hostinfo.View(), - machineKey.String(), - req.NodeKey.String(), - ) - - // Ensure we have valid hostinfo - hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) if req.Hostinfo == nil { log.Warn(). Str("machine.key", machineKey.ShortString()). Str("node.key", req.NodeKey.ShortString()). - Str("generated.hostname", hostname). Msg("Received registration request with nil hostinfo, generated default hostname") } else if req.Hostinfo.Hostname == "" { log.Warn(). Str("machine.key", machineKey.ShortString()). Str("node.key", req.NodeKey.ShortString()). - Str("generated.hostname", hostname). Msg("Received registration request with empty hostname, generated default") } - hostinfo.Hostname = hostname - - nodeToRegister := types.Node{ - Hostname: hostname, - MachineKey: machineKey, - NodeKey: req.NodeKey, - Hostinfo: hostinfo, - LastSeen: new(time.Now()), - } - - if !req.Expiry.IsZero() { - nodeToRegister.Expiry = &req.Expiry - } - - authRegReq := types.NewRegisterAuthRequest(nodeToRegister) - - h.state.SetAuthCacheEntry( - authID, - authRegReq, + authRegReq := types.NewRegisterAuthRequest( + registrationDataFromRequest(req, machineKey), ) + h.state.SetAuthCacheEntry(authID, authRegReq) + log.Info().Msgf("starting node registration using auth id: %s", authID) return &tailcfg.RegisterResponse{ diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go index d9184f4f..b5fd448c 100644 --- a/hscontrol/auth_tags_test.go +++ b/hscontrol/auth_tags_test.go @@ -696,7 +696,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 1: Create user-owned node WITH expiry set clientExpiry := time.Now().Add(24 * time.Hour) registrationID1 := types.MustAuthID() - regEntry1 := types.NewRegisterAuthRequest(types.Node{ + regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "personal-to-tagged", @@ -718,7 +718,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) { // Step 2: Re-auth with tags (Personal → Tagged conversion) nodeKey2 := key.NewNode() registrationID2 := types.MustAuthID() - regEntry2 := types.NewRegisterAuthRequest(types.Node{ + regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "personal-to-tagged", @@ -768,7 +768,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { // Step 1: Create tagged node (expiry should be nil) registrationID1 := types.MustAuthID() - regEntry1 := types.NewRegisterAuthRequest(types.Node{ + regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "tagged-to-personal", @@ -790,7 +790,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) { nodeKey2 := key.NewNode() clientExpiry := time.Now().Add(48 * time.Hour) registrationID2 := types.MustAuthID() - regEntry2 := types.NewRegisterAuthRequest(types.Node{ + regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), NodeKey: nodeKey2.Public(), Hostname: "tagged-to-personal", diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 58dad28b..7fc333f3 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -681,7 +681,7 @@ func TestAuthenticationFlows(t *testing.T) { return "", err } - nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{ Hostname: "followup-success-node", }) app.state.SetAuthCacheEntry(regID, nodeToRegister) @@ -723,7 +723,7 @@ func TestAuthenticationFlows(t *testing.T) { return "", err } - nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{ Hostname: "followup-timeout-node", }) app.state.SetAuthCacheEntry(regID, nodeToRegister) @@ -1341,7 +1341,7 @@ func TestAuthenticationFlows(t *testing.T) { return "", err } - nodeToRegister := types.NewRegisterAuthRequest(types.Node{ + nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{ Hostname: "nil-response-node", }) app.state.SetAuthCacheEntry(regID, nodeToRegister) @@ -2618,7 +2618,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { cacheEntry, found := app.state.GetAuthCacheEntry(registrationID) require.True(t, found, "registration cache entry should exist") require.NotNil(t, cacheEntry, "cache entry should not be nil") - require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key") + require.Equal(t, req.NodeKey, cacheEntry.RegistrationData().NodeKey, "cache entry should have correct node key") } case stepTypeAuthCompletion: @@ -3570,7 +3570,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { // Simulate a registration cache entry (as would be created during web auth) registrationID := types.MustAuthID() - regEntry := types.NewRegisterAuthRequest(types.Node{ + regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "webauth-tags-node", @@ -3633,7 +3633,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { // Step 1: Initial registration with tags registrationID1 := types.MustAuthID() - regEntry1 := types.NewRegisterAuthRequest(types.Node{ + regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), NodeKey: nodeKey1.Public(), Hostname: "reauth-untag-node", @@ -3660,7 +3660,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { // Step 2: Reauth with EMPTY tags to untag nodeKey2 := key.NewNode() // New node key for reauth registrationID2 := types.MustAuthID() - regEntry2 := types.NewRegisterAuthRequest(types.Node{ + regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "reauth-untag-node", @@ -3746,7 +3746,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned nodeKey2 := key.NewNode() // New node key for reauth registrationID := types.MustAuthID() - regEntry := types.NewRegisterAuthRequest(types.Node{ + regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), // Same machine key NodeKey: nodeKey2.Public(), // Different node key (rotation) Hostname: "authkey-tagged-node", @@ -3945,7 +3945,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) { // This is what happens when running: headscale auth register --auth-id --user alice nodeKey2 := key.NewNode() registrationID := types.MustAuthID() - regEntry := types.NewRegisterAuthRequest(types.Node{ + regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{ MachineKey: machineKey.Public(), // Same machine key as the tagged node NodeKey: nodeKey2.Public(), Hostname: "tagged-orphan-node", diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 6e5f73d5..d95b7496 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -24,7 +24,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" - "zgo.at/zcache/v2" ) //go:embed schema.sql @@ -45,19 +44,15 @@ const ( ) type HSDatabase struct { - DB *gorm.DB - cfg *types.Config - regCache *zcache.Cache[types.AuthID, types.AuthRequest] + DB *gorm.DB + cfg *types.Config } // NewHeadscaleDatabase creates a new database connection and runs migrations. // It accepts the full configuration to allow migrations access to policy settings. // //nolint:gocyclo // complex database initialization with many migrations -func NewHeadscaleDatabase( - cfg *types.Config, - regCache *zcache.Cache[types.AuthID, types.AuthRequest], -) (*HSDatabase, error) { +func NewHeadscaleDatabase(cfg *types.Config) (*HSDatabase, error) { dbConn, err := openDB(cfg.Database) if err != nil { return nil, err @@ -838,9 +833,8 @@ WHERE tags IS NOT NULL AND tags != '[]' AND tags != ''; } db := HSDatabase{ - DB: dbConn, - cfg: cfg, - regCache: regCache, + DB: dbConn, + cfg: cfg, } return &db, err diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 151d9966..ee88d4ae 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -8,13 +8,11 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" - "zgo.at/zcache/v2" ) // TestSQLiteMigrationAndDataValidation tests specific SQLite migration scenarios @@ -162,10 +160,6 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { } } -func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { - return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) -} - func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { db, err := sql.Open("sqlite", dbPath) if err != nil { @@ -379,7 +373,6 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { Mode: types.PolicyModeDB, }, }, - emptyCache(), ) if err != nil { t.Fatalf("setting up database: %s", err) @@ -439,7 +432,6 @@ func TestSQLiteAllTestdataMigrations(t *testing.T) { Mode: types.PolicyModeDB, }, }, - emptyCache(), ) require.NoError(t, err) }) diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index a0e0bc6b..0bf0e69e 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -34,7 +34,6 @@ func newSQLiteTestDB() (*HSDatabase, error) { Mode: types.PolicyModeDB, }, }, - emptyCache(), ) if err != nil { return nil, err @@ -95,7 +94,6 @@ func newHeadscaleDBFromPostgresURL(t *testing.T, pu *url.URL) *HSDatabase { Mode: types.PolicyModeDB, }, }, - emptyCache(), ) if err != nil { t.Fatal(err) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 567efc8e..48d6e2d1 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -802,27 +802,16 @@ func (api headscaleV1APIServer) DebugCreateNode( Interface("route-str", request.GetRoutes()). Msg("Creating routes for node") - hostinfo := tailcfg.Hostinfo{ - RoutableIPs: routes, - OS: "TestOS", - Hostname: request.GetName(), - } - registrationId, err := types.AuthIDFromString(request.GetKey()) if err != nil { return nil, err } - newNode := types.Node{ + regData := &types.RegistrationData{ NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(), Hostname: request.GetName(), - User: user, - - Expiry: &time.Time{}, - LastSeen: &time.Time{}, - - Hostinfo: &hostinfo, + Expiry: &time.Time{}, // zero time, not nil — preserves proto JSON round-trip semantics } log.Debug(). @@ -830,10 +819,27 @@ func (api headscaleV1APIServer) DebugCreateNode( Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") - authRegReq := types.NewRegisterAuthRequest(newNode) + authRegReq := types.NewRegisterAuthRequest(regData) api.h.state.SetAuthCacheEntry(registrationId, authRegReq) - return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil + // Echo back a synthetic Node so the debug response surface stays + // stable. The actual node is created later by AuthApprove via + // HandleNodeFromAuthPath using the cached RegistrationData. + echoNode := types.Node{ + NodeKey: regData.NodeKey, + MachineKey: regData.MachineKey, + Hostname: regData.Hostname, + User: user, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: request.GetName(), + OS: "TestOS", + RoutableIPs: routes, + }, + } + + return &v1.DebugCreateNodeResponse{Node: echoNode.Proto()}, nil } func (api headscaleV1APIServer) Health( diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 1b89c214..e9fa1f5e 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "zgo.at/zcache/v2" ) var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher") @@ -109,11 +108,6 @@ var allBatcherFunctions = []batcherTestCase{ {"Default", NewBatcherAndMapper}, } -// emptyCache creates an empty registration cache for testing. -func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] { - return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour) -} - // Test configuration constants. const ( // Test data configuration. @@ -211,10 +205,7 @@ func setupBatcherWithTestData( } // Create database and populate it with test data - database, err := db.NewHeadscaleDatabase( - cfg, - emptyCache(), - ) + database, err := db.NewHeadscaleDatabase(cfg) if err != nil { t.Fatalf("setting up database: %s", err) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 304245de..010995bc 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -12,6 +12,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/hashicorp/golang-lru/v2/expirable" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" @@ -19,14 +20,17 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "zgo.at/zcache/v2" ) const ( randomByteSize = 16 defaultOAuthOptionsCount = 3 authCacheExpiration = time.Minute * 15 - authCacheCleanup = time.Minute * 20 + + // authCacheMaxEntries bounds the OIDC state→AuthInfo cache to prevent + // unauthenticated cache-fill DoS via repeated /register/{auth_id} or + // /auth/{auth_id} GETs that mint OIDC state cookies. + authCacheMaxEntries = 1024 ) var ( @@ -55,9 +59,10 @@ type AuthProviderOIDC struct { serverURL string cfg *types.OIDCConfig - // authCache holds auth information between - // the auth and the callback steps. - authCache *zcache.Cache[string, AuthInfo] + // authCache holds auth information between the auth and the callback + // steps. It is a bounded LRU keyed by OIDC state, evicting oldest + // entries to keep the cache footprint constant under attack. + authCache *expirable.LRU[string, AuthInfo] oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -84,9 +89,10 @@ func NewAuthProviderOIDC( Scopes: cfg.Scope, } - authCache := zcache.New[string, AuthInfo]( + authCache := expirable.NewLRU[string, AuthInfo]( + authCacheMaxEntries, + nil, authCacheExpiration, - authCacheCleanup, ) return &AuthProviderOIDC{ @@ -188,7 +194,7 @@ func (a *AuthProviderOIDC) authHandler( extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info - a.authCache.Set(state, registrationInfo) + a.authCache.Add(state, registrationInfo) authURL := a.oauth2Config.AuthCodeURL(state, extras...) log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL) diff --git a/hscontrol/servertest/ephemeral_test.go b/hscontrol/servertest/ephemeral_test.go index d5a49832..e65cd1d0 100644 --- a/hscontrol/servertest/ephemeral_test.go +++ b/hscontrol/servertest/ephemeral_test.go @@ -18,10 +18,10 @@ import ( // fake-clock advancement, but three blockers prevent adoption // as of Go 1.26: // -// 1. zcache janitor goroutine: No Close() method; stopped only via -// runtime.SetFinalizer which runs outside synctest bubbles. -// - https://github.com/patrickmn/go-cache/issues/185 -// - https://github.com/golang/go/issues/75113 (Go1.27: finalizers inside bubble) +// 1. golang-lru/v2/expirable janitor goroutine: No Close() method; +// the deleteExpired ticker goroutine never exits because the done +// channel is never closed (documented as a v3 TODO upstream). +// - https://github.com/hashicorp/golang-lru/blob/v2.0.7/expirable/expirable_lru.go#L78-L81 // // 2. database/sql internal goroutines: Uses sync.RWMutex which is not // durably blocking in synctest, causing hangs. diff --git a/hscontrol/state/auth_cache_test.go b/hscontrol/state/auth_cache_test.go new file mode 100644 index 00000000..313db948 --- /dev/null +++ b/hscontrol/state/auth_cache_test.go @@ -0,0 +1,64 @@ +package state + +import ( + "testing" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestAuthCacheBoundedLRU verifies that the registration auth cache is +// bounded by a maximum entry count, that exceeding the maxEntries evicts the +// oldest entry, and that the eviction callback resolves the parked +// AuthRequest with ErrRegistrationExpired so any waiting goroutine wakes. +func TestAuthCacheBoundedLRU(t *testing.T) { + const maxEntries = 4 + + cache := expirable.NewLRU[types.AuthID, *types.AuthRequest]( + maxEntries, + func(_ types.AuthID, rn *types.AuthRequest) { + rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired}) + }, + time.Hour, // long TTL — we test eviction by size, not by time + ) + + entries := make([]*types.AuthRequest, 0, maxEntries+1) + ids := make([]types.AuthID, 0, maxEntries+1) + + for range maxEntries + 1 { + id := types.MustAuthID() + entry := types.NewAuthRequest() + cache.Add(id, entry) + ids = append(ids, id) + entries = append(entries, entry) + } + + // Cap should be respected. + assert.Equal(t, maxEntries, cache.Len(), "cache must not exceed the configured maxEntries") + + // The oldest entry must have been evicted. + _, ok := cache.Get(ids[0]) + assert.False(t, ok, "oldest entry must be evicted when maxEntries is exceeded") + + // The eviction callback must have woken the parked AuthRequest. + select { + case verdict := <-entries[0].WaitForAuth(): + require.False(t, verdict.Accept(), "evicted entry must not signal Accept") + require.ErrorIs(t, + verdict.Err, ErrRegistrationExpired, + "evicted entry must surface ErrRegistrationExpired, got: %v", + verdict.Err, + ) + case <-time.After(time.Second): + t.Fatal("eviction callback did not wake the parked AuthRequest") + } + + // All non-evicted entries must still be retrievable. + for i := 1; i <= maxEntries; i++ { + _, ok := cache.Get(ids[i]) + assert.True(t, ok, "non-evicted entry %d should still be in the cache", i) + } +} diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index abb34eb0..d0705d9c 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -211,15 +211,13 @@ func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy { // DebugRegistrationCache returns debug information about the registration cache. func (s *State) DebugRegistrationCache() map[string]any { - // The cache doesn't expose internal statistics, so we provide basic info - result := map[string]any{ - "type": "zcache", - "expiration": registerCacheExpiration.String(), - "cleanup": registerCacheCleanup.String(), - "status": "active", + return map[string]any{ + "type": "expirable-lru", + "expiration": registerCacheExpiration.String(), + "max_entries": defaultRegisterCacheMaxEntries, + "current_len": s.authCache.Len(), + "status": "active", } - - return result } // DebugConfig returns debug information about the current configuration. diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 5091dda9..d96223f5 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/hashicorp/golang-lru/v2/expirable" hsdb "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy/matcher" @@ -30,15 +31,18 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/views" - zcache "zgo.at/zcache/v2" ) const ( // registerCacheExpiration defines how long node registration entries remain in cache. registerCacheExpiration = time.Minute * 15 - // registerCacheCleanup defines the interval for cleaning up expired cache entries. - registerCacheCleanup = time.Minute * 20 + // defaultRegisterCacheMaxEntries is the default upper bound on the number + // of pending registration entries the auth cache will hold. With a 15-minute + // TTL and a stripped-down RegistrationData payload (~200 bytes per entry), + // 1024 entries cap the worst-case cache footprint at well under 1 MiB even + // under sustained unauthenticated cache-fill attempts. + defaultRegisterCacheMaxEntries = 1024 // defaultNodeStoreBatchSize is the default number of write operations to batch // before rebuilding the in-memory node snapshot. @@ -126,8 +130,12 @@ type State struct { // polMan handles policy evaluation and management polMan policy.PolicyManager - // authCache caches any pending authentication requests, from either auth type (Web and OIDC). - authCache *zcache.Cache[types.AuthID, types.AuthRequest] + // authCache holds any pending authentication requests from either auth + // type (Web and OIDC). It is a bounded LRU keyed by AuthID; oldest + // entries are evicted once the size cap is reached, and entries that + // time out have their auth verdict resolved with ErrRegistrationExpired + // via the eviction callback so any waiting goroutines wake. + authCache *expirable.LRU[types.AuthID, *types.AuthRequest] // primaryRoutes tracks primary route assignments for nodes primaryRoutes *routes.PrimaryRoutes @@ -166,26 +174,20 @@ func NewState(cfg *types.Config) (*State, error) { cacheExpiration = cfg.Tuning.RegisterCacheExpiration } - cacheCleanup := registerCacheCleanup - if cfg.Tuning.RegisterCacheCleanup != 0 { - cacheCleanup = cfg.Tuning.RegisterCacheCleanup + cacheMaxEntries := defaultRegisterCacheMaxEntries + if cfg.Tuning.RegisterCacheMaxEntries > 0 { + cacheMaxEntries = cfg.Tuning.RegisterCacheMaxEntries } - authCache := zcache.New[types.AuthID, types.AuthRequest]( - cacheExpiration, - cacheCleanup, - ) - - authCache.OnEvicted( - func(id types.AuthID, rn types.AuthRequest) { + authCache := expirable.NewLRU[types.AuthID, *types.AuthRequest]( + cacheMaxEntries, + func(id types.AuthID, rn *types.AuthRequest) { rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired}) }, + cacheExpiration, ) - db, err := hsdb.NewHeadscaleDatabase( - cfg, - authCache, - ) + db, err := hsdb.NewHeadscaleDatabase(cfg) if err != nil { return nil, fmt.Errorf("initializing database: %w", err) } @@ -1252,19 +1254,14 @@ func (s *State) DeletePreAuthKey(id uint64) error { return s.db.DeletePreAuthKey(id) } -// GetAuthCacheEntry retrieves a node registration from cache. +// GetAuthCacheEntry retrieves a pending auth request from the cache. func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) { - entry, found := s.authCache.Get(id) - if !found { - return nil, false - } - - return &entry, true + return s.authCache.Get(id) } -// SetAuthCacheEntry stores a node registration in cache. -func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) { - s.authCache.Set(id, entry) +// SetAuthCacheEntry stores a pending auth request in the cache. +func (s *State) SetAuthCacheEntry(id types.AuthID, entry *types.AuthRequest) { + s.authCache.Add(id, entry) } // SetLastSSHAuth records a successful SSH check authentication @@ -1296,25 +1293,6 @@ func (s *State) ClearSSHCheckAuth() { s.sshCheckAuth = make(map[sshCheckPair]time.Time) } -// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. -func logHostinfoValidation(nv types.NodeView, username, hostname string) { - if !nv.Hostinfo().Valid() { - log.Warn(). - Caller(). - EmbedObject(nv). - Str(zf.UserName, username). - Str(zf.GeneratedHostname, hostname). - Msg("Registration had nil hostinfo, generated default hostname") - } else if nv.Hostinfo().Hostname() == "" { - log.Warn(). - Caller(). - EmbedObject(nv). - Str(zf.UserName, username). - Str(zf.GeneratedHostname, hostname). - Msg("Registration had empty hostname, generated default") - } -} - // preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity. // If no existing node is provided, it creates new netinfo from the provided hostinfo. func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo { @@ -1349,15 +1327,15 @@ type newNodeParams struct { type authNodeUpdateParams struct { // Node to update; must be valid and in NodeStore. ExistingNode types.NodeView - // Client data: keys, hostinfo, endpoints. - RegEntry *types.AuthRequest + // Cached registration payload from the originating client request. + RegData *types.RegistrationData // Pre-validated hostinfo; NetInfo preserved from ExistingNode. ValidHostinfo *tailcfg.Hostinfo // Hostname from hostinfo, or generated from keys if client omits it. Hostname string // Auth user; may differ from ExistingNode.User() on conversion. User *types.User - // Overrides RegEntry.Node.Expiry; ignored for tagged nodes. + // Overrides RegData.Expiry; ignored for tagged nodes. Expiry *time.Time // Only used when IsConvertFromTag=true. RegisterMethod string @@ -1369,7 +1347,7 @@ type authNodeUpdateParams struct { // an existing node. It updates the node in NodeStore, processes RequestTags, and // persists changes to the database. func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { - regNv := params.RegEntry.Node() + regData := params.RegData // Log the operation type if params.IsConvertFromTag { log.Info(). @@ -1379,15 +1357,17 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView } else { log.Info(). Object("existing", params.ExistingNode). - Object("incoming", regNv). + Str("incoming.hostname", regData.Hostname). + Str("incoming.machine_key", regData.MachineKey.ShortString()). Msg("Updating existing node registration via reauth") } - // Process RequestTags during reauth (#2979) - // Due to json:",omitempty", we treat empty/nil as "clear tags" + // Process RequestTags during reauth (#2979). + // Due to json:",omitempty", empty/nil from the cached Hostinfo + // means "clear tags". var requestTags []string - if regNv.Hostinfo().Valid() { - requestTags = regNv.Hostinfo().RequestTags().AsSlice() + if regData.Hostinfo != nil { + requestTags = regData.Hostinfo.RequestTags } oldTags := params.ExistingNode.Tags().AsSlice() @@ -1405,8 +1385,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView // Update existing node in NodeStore - validation passed, safe to mutate updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { - node.NodeKey = regNv.NodeKey() - node.DiscoKey = regNv.DiscoKey() + node.NodeKey = regData.NodeKey + node.DiscoKey = regData.DiscoKey node.Hostname = params.Hostname // Preserve NetInfo from existing node when re-registering @@ -1417,7 +1397,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView params.ValidHostinfo, ) - node.Endpoints = regNv.Endpoints().AsSlice() + node.Endpoints = regData.Endpoints // Do NOT reset IsOnline here. Online status is managed exclusively by // Connect()/Disconnect() in the poll session lifecycle. Resetting it // during re-registration causes a false offline blip: the change @@ -1425,12 +1405,12 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView // to peers, even though Connect() will immediately set it back to true. node.LastSeen = new(time.Now()) - // Set RegisterMethod - for conversion this is the new method, - // for reauth we preserve the existing one from regEntry + // On conversion (tagged → user) we set the new register method. + // On plain reauth we preserve the existing node.RegisterMethod; + // the cached RegistrationData no longer carries it because the + // producer never populated it. if params.IsConvertFromTag { node.RegisterMethod = params.RegisterMethod - } else { - node.RegisterMethod = regNv.RegisterMethod() } // Track tagged status BEFORE processing tags @@ -1450,7 +1430,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = regNv.Expiry().Clone() + node.Expiry = regData.Expiry } case !wasTagged && isTagged: // Personal → Tagged: clear expiry (tagged nodes don't expire) @@ -1460,14 +1440,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = regNv.Expiry().Clone() + node.Expiry = regData.Expiry } case !isTagged: // Personal → Personal: update expiry from client if params.Expiry != nil { node.Expiry = params.Expiry } else { - node.Expiry = regNv.Expiry().Clone() + node.Expiry = regData.Expiry } } // Tagged → Tagged: keep existing expiry (nil) - no action needed @@ -1795,29 +1775,20 @@ func (s *State) HandleNodeFromAuthPath( return types.NodeView{}, change.Change{}, fmt.Errorf("finding user: %w", err) } - // Ensure we have a valid hostname from the registration cache entry - hostname := util.EnsureHostname( - regEntry.Node().Hostinfo(), - regEntry.Node().MachineKey().String(), - regEntry.Node().NodeKey().String(), - ) + regData := regEntry.RegistrationData() - // Ensure we have valid hostinfo + // Hostname was already validated/normalised at producer time. Build + // the initial Hostinfo from the cached client-supplied Hostinfo (or + // an empty stub if the client did not send one). + hostname := regData.Hostname hostinfo := &tailcfg.Hostinfo{} - if regEntry.Node().Hostinfo().Valid() { - hostinfo = regEntry.Node().Hostinfo().AsStruct() + if regData.Hostinfo != nil { + hostinfo = regData.Hostinfo.Clone() } - hostinfo.Hostname = hostname - logHostinfoValidation( - regEntry.Node(), - user.Name, - hostname, - ) - // Lookup existing nodes - machineKey := regEntry.Node().MachineKey() + machineKey := regData.MachineKey existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) @@ -1839,7 +1810,7 @@ func (s *State) HandleNodeFromAuthPath( // Common params for update operations updateParams := authNodeUpdateParams{ - RegEntry: regEntry, + RegData: regData, ValidHostinfo: hostinfo, Hostname: hostname, User: user, @@ -1874,7 +1845,7 @@ func (s *State) HandleNodeFromAuthPath( Msg("Creating new node for different user (same machine key exists for another user)") finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, hostinfo, + logger, user, regData, hostname, hostinfo, expiry, registrationMethod, existingNodeAnyUser, ) if err != nil { @@ -1882,7 +1853,7 @@ func (s *State) HandleNodeFromAuthPath( } } else { finalNode, err = s.createNewNodeFromAuth( - logger, user, regEntry, hostname, hostinfo, + logger, user, regData, hostname, hostinfo, expiry, registrationMethod, types.NodeView{}, ) if err != nil { @@ -1893,8 +1864,8 @@ func (s *State) HandleNodeFromAuthPath( // Signal to waiting clients regEntry.FinishAuth(types.AuthVerdict{Node: finalNode}) - // Delete from registration cache - s.authCache.Delete(authID) + // Remove from registration cache + s.authCache.Remove(authID) // Update policy managers usersChange, err := s.updatePolicyManagerUsers() @@ -1923,7 +1894,7 @@ func (s *State) HandleNodeFromAuthPath( func (s *State) createNewNodeFromAuth( logger zerolog.Logger, user *types.User, - regEntry *types.AuthRequest, + regData *types.RegistrationData, hostname string, validHostinfo *tailcfg.Hostinfo, expiry *time.Time, @@ -1936,13 +1907,13 @@ func (s *State) createNewNodeFromAuth( return s.createAndSaveNewNode(newNodeParams{ User: *user, - MachineKey: regEntry.Node().MachineKey(), - NodeKey: regEntry.Node().NodeKey(), - DiscoKey: regEntry.Node().DiscoKey(), + MachineKey: regData.MachineKey, + NodeKey: regData.NodeKey, + DiscoKey: regData.DiscoKey, Hostname: hostname, Hostinfo: validHostinfo, - Endpoints: regEntry.Node().Endpoints().AsSlice(), - Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()), + Endpoints: regData.Endpoints, + Expiry: cmp.Or(expiry, regData.Expiry), RegisterMethod: registrationMethod, ExistingNodeForNetinfo: existingNodeForNetinfo, }) diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 5992a3a1..1cc1b08a 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -221,40 +221,65 @@ func (r AuthID) Validate() error { return nil } -// AuthRequest represent a pending authentication request from a user or a node. -// If it is a registration request, the node field will be populate with the node that is trying to register. -// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel. -// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed. +// AuthRequest represents a pending authentication request from a user or a +// node. It carries the minimum data needed to either complete a node +// registration (regData populated) or signal the verdict of an interactive +// auth flow (no payload). Verdict delivery is via the finished channel; the +// closed flag guards FinishAuth against double-close. +// +// AuthRequest is always handled by pointer so the channel and atomic flag +// have a single canonical instance even when stored in caches that +// internally copy values. type AuthRequest struct { - node *Node + // regData is populated for node-registration flows (interactive web + // or OIDC). It carries only the minimal subset of registration data + // the auth callback needs to promote this request into a real node; + // see RegistrationData for the rationale behind keeping the payload + // small. + // + // nil for non-registration flows (e.g. SSH check). Use + // RegistrationData() to read it safely. + regData *RegistrationData + finished chan AuthVerdict closed *atomic.Bool } -func NewAuthRequest() AuthRequest { - return AuthRequest{ +// NewAuthRequest creates a pending auth request with no payload, suitable +// for non-registration flows that only need a verdict channel. +func NewAuthRequest() *AuthRequest { + return &AuthRequest{ finished: make(chan AuthVerdict, 1), closed: &atomic.Bool{}, } } -func NewRegisterAuthRequest(node Node) AuthRequest { - return AuthRequest{ - node: &node, +// NewRegisterAuthRequest creates a pending auth request carrying the +// minimal RegistrationData for a node-registration flow. The data is +// stored by pointer; callers must not mutate it after handing it off. +func NewRegisterAuthRequest(data *RegistrationData) *AuthRequest { + return &AuthRequest{ + regData: data, finished: make(chan AuthVerdict, 1), closed: &atomic.Bool{}, } } -// Node returns the node that is trying to register. -// It will panic if the AuthRequest is not a registration request. -// Can _only_ be used in the registration path. -func (rn *AuthRequest) Node() NodeView { - if rn.node == nil { - panic("Node can only be used in registration requests") +// RegistrationData returns the cached registration payload. It panics if +// called on an AuthRequest that was not created via +// NewRegisterAuthRequest, mirroring the previous Node() contract. +func (rn *AuthRequest) RegistrationData() *RegistrationData { + if rn.regData == nil { + panic("RegistrationData can only be used in registration requests") } - return rn.node.View() + return rn.regData +} + +// IsRegistration reports whether this auth request carries registration +// data (i.e. it was created via NewRegisterAuthRequest). +func (rn *AuthRequest) IsRegistration() bool { + return rn.regData != nil } func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) { diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 22abdd5e..bf50d308 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -278,14 +278,16 @@ type Tuning struct { // updates for connected clients. BatcherWorkers int - // RegisterCacheCleanup is the interval between cleanup operations for - // expired registration cache entries. - RegisterCacheCleanup time.Duration - // RegisterCacheExpiration is how long registration cache entries remain - // valid before being eligible for cleanup. + // valid before being eligible for eviction. RegisterCacheExpiration time.Duration + // RegisterCacheMaxEntries bounds the number of pending registration + // entries the auth cache will hold. Older entries are evicted (LRU) + // when the cap is reached, preventing unauthenticated cache-fill DoS. + // A value of 0 falls back to defaultRegisterCacheMaxEntries (1024). + RegisterCacheMaxEntries int + // NodeStoreBatchSize controls how many write operations are accumulated // before rebuilding the in-memory node snapshot. // @@ -1192,8 +1194,8 @@ func LoadServerConfig() (*Config, error) { return DefaultBatcherWorkers() }(), - RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"), + RegisterCacheMaxEntries: viper.GetInt("tuning.register_cache_max_entries"), NodeStoreBatchSize: viper.GetInt("tuning.node_store_batch_size"), NodeStoreBatchTimeout: viper.GetDuration("tuning.node_store_batch_timeout"), }, diff --git a/hscontrol/types/registration.go b/hscontrol/types/registration.go new file mode 100644 index 00000000..e6991b4d --- /dev/null +++ b/hscontrol/types/registration.go @@ -0,0 +1,55 @@ +package types + +import ( + "net/netip" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// RegistrationData is the payload cached for a pending node registration. +// It replaces the previous practice of caching a full *Node and carries +// only the fields the registration callback path actually consumes when +// promoting a pending registration to a real node. +// +// Combined with the bounded-LRU cache that holds these entries, this caps +// the worst-case memory footprint of unauthenticated cache-fill attempts +// at (max_entries × per_entry_size). The cache is sized so that the +// product is bounded to a few MiB even with attacker-supplied 1 MiB +// Hostinfos (the Noise body limit). +type RegistrationData struct { + // MachineKey is the cryptographic identity of the machine being + // registered. Required. + MachineKey key.MachinePublic + + // NodeKey is the cryptographic identity of the node session. + // Required. + NodeKey key.NodePublic + + // DiscoKey is the disco public key for peer-to-peer connections. + DiscoKey key.DiscoPublic + + // Hostname is the resolved hostname for the registering node. + // Already validated/normalised by EnsureHostname at producer time. + Hostname string + + // Hostinfo is the original Hostinfo from the RegisterRequest, + // stored so that the auth callback can populate the new node's + // initial Hostinfo (and so that observability/CLI consumers see + // fields like OS, OSVersion, and IPNVersion before the first + // MapRequest restores the live set). + // + // May be nil if the client did not send Hostinfo in the original + // RegisterRequest. + Hostinfo *tailcfg.Hostinfo + + // Endpoints is the initial set of WireGuard endpoints the node + // reported. The first MapRequest after registration overwrites + // this with the live set. + Endpoints []netip.AddrPort + + // Expiry is the optional client-requested expiry for this node. + // May be nil if the client did not request a specific expiry. + Expiry *time.Time +}