mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-10 19:17:25 +02:00
state: replace zcache with bounded LRU for auth cache
Replace zcache with golang-lru/v2/expirable for both the state auth cache and the OIDC state cache. Add tuning.register_cache_max_entries (default 1024) to cap the number of pending registration entries. Introduce types.RegistrationData to replace caching a full *Node; only the fields the registration callback path reads are retained. Remove the dead HSDatabase.regCache field. Drop zgo.at/zcache/v2 from go.mod.
This commit is contained in:
@@ -28,7 +28,7 @@ func bypassDatabase() (*db.HSDatabase, error) {
|
||||
return nil, fmt.Errorf("loading config: %w", err)
|
||||
}
|
||||
|
||||
d, err := db.NewHeadscaleDatabase(cfg, nil)
|
||||
d, err := db.NewHeadscaleDatabase(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening database: %w", err)
|
||||
}
|
||||
|
||||
@@ -445,12 +445,16 @@ taildrop:
|
||||
# When enabled, nodes can send files to other nodes owned by the same user.
|
||||
# Tagged devices and cross-user transfers are not permitted by Tailscale clients.
|
||||
enabled: true
|
||||
|
||||
# Advanced performance tuning parameters.
|
||||
# The defaults are carefully chosen and should rarely need adjustment.
|
||||
# Only modify these if you have identified a specific performance issue.
|
||||
#
|
||||
# tuning:
|
||||
# # Maximum number of pending registration entries in the auth cache.
|
||||
# # Oldest entries are evicted when the cap is reached.
|
||||
# #
|
||||
# # register_cache_max_entries: 1024
|
||||
#
|
||||
# # NodeStore write batching configuration.
|
||||
# # The NodeStore batches write operations before rebuilding peer relationships,
|
||||
# # which is computationally expensive. Batching reduces rebuild frequency.
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
let
|
||||
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
|
||||
buildGo = pkgs.buildGo126Module;
|
||||
vendorHash = "sha256-G+yhItFhlp2XP/Zd9N4nMQf96YMQLuYd069H+Quewtk=";
|
||||
vendorHash = "sha256-x0xXxa7sjyDwWLq8fO0Z/pbPefctzctK3TAdBea7FtY=";
|
||||
in
|
||||
{
|
||||
headscale = buildGo {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -22,6 +22,7 @@ require (
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/jagottsicher/termcolor v1.0.2
|
||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||
github.com/ory/dockertest/v3 v3.12.0
|
||||
@@ -54,7 +55,6 @@ require (
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.31.1
|
||||
tailscale.com v1.96.5
|
||||
zgo.at/zcache/v2 v2.4.1
|
||||
zombiezen.com/go/postgrestest v1.0.1
|
||||
)
|
||||
|
||||
|
||||
3
go.sum
3
go.sum
@@ -252,7 +252,6 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||
github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4=
|
||||
github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU=
|
||||
@@ -682,7 +681,5 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB
|
||||
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
tailscale.com v1.96.5 h1:gNkfA/KSZAl6jCH9cj8urq00HRWItDDTtGsyATI89jA=
|
||||
tailscale.com v1.96.5/go.mod h1:/3lnZBYb2UEwnN0MNu2SDXUtT06AGd5k0s+OWx3WmcY=
|
||||
zgo.at/zcache/v2 v2.4.1 h1:Dfjoi8yI0Uq7NCc4lo2kaQJJmp9Mijo21gef+oJstbY=
|
||||
zgo.at/zcache/v2 v2.4.1/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
||||
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
|
||||
zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ=
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package hscontrol
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -302,31 +301,10 @@ func (h *Headscale) reqToNewRegisterResponse(
|
||||
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname
|
||||
hostname := util.EnsureHostname(
|
||||
req.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
req.NodeKey.String(),
|
||||
authRegReq := types.NewRegisterAuthRequest(
|
||||
registrationDataFromRequest(req, machineKey),
|
||||
)
|
||||
|
||||
// Ensure we have valid hostinfo
|
||||
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: new(time.Now()),
|
||||
}
|
||||
|
||||
if !req.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = &req.Expiry
|
||||
}
|
||||
|
||||
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
||||
|
||||
log.Info().Msgf("new followup node registration using auth id: %s", newAuthID)
|
||||
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
|
||||
|
||||
@@ -335,6 +313,36 @@ func (h *Headscale) reqToNewRegisterResponse(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// registrationDataFromRequest builds the RegistrationData payload stored
|
||||
// in the auth cache for a pending registration. The original Hostinfo is
|
||||
// retained so that consumers (auth callback, observability) see the
|
||||
// fields the client originally announced; the bounded-LRU cap on the
|
||||
// cache is what bounds the unauthenticated cache-fill DoS surface.
|
||||
func registrationDataFromRequest(
|
||||
req tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) *types.RegistrationData {
|
||||
hostname := util.EnsureHostname(
|
||||
req.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
req.NodeKey.String(),
|
||||
)
|
||||
|
||||
regData := &types.RegistrationData{
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostname: hostname,
|
||||
Hostinfo: req.Hostinfo,
|
||||
}
|
||||
|
||||
if !req.Expiry.IsZero() {
|
||||
expiry := req.Expiry
|
||||
regData.Expiry = &expiry
|
||||
}
|
||||
|
||||
return regData
|
||||
}
|
||||
|
||||
func (h *Headscale) handleRegisterWithAuthKey(
|
||||
req tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
@@ -408,50 +416,24 @@ func (h *Headscale) handleRegisterInteractive(
|
||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname
|
||||
hostname := util.EnsureHostname(
|
||||
req.Hostinfo.View(),
|
||||
machineKey.String(),
|
||||
req.NodeKey.String(),
|
||||
)
|
||||
|
||||
// Ensure we have valid hostinfo
|
||||
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
||||
if req.Hostinfo == nil {
|
||||
log.Warn().
|
||||
Str("machine.key", machineKey.ShortString()).
|
||||
Str("node.key", req.NodeKey.ShortString()).
|
||||
Str("generated.hostname", hostname).
|
||||
Msg("Received registration request with nil hostinfo, generated default hostname")
|
||||
} else if req.Hostinfo.Hostname == "" {
|
||||
log.Warn().
|
||||
Str("machine.key", machineKey.ShortString()).
|
||||
Str("node.key", req.NodeKey.ShortString()).
|
||||
Str("generated.hostname", hostname).
|
||||
Msg("Received registration request with empty hostname, generated default")
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: new(time.Now()),
|
||||
}
|
||||
|
||||
if !req.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = &req.Expiry
|
||||
}
|
||||
|
||||
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
||||
|
||||
h.state.SetAuthCacheEntry(
|
||||
authID,
|
||||
authRegReq,
|
||||
authRegReq := types.NewRegisterAuthRequest(
|
||||
registrationDataFromRequest(req, machineKey),
|
||||
)
|
||||
|
||||
h.state.SetAuthCacheEntry(authID, authRegReq)
|
||||
|
||||
log.Info().Msgf("starting node registration using auth id: %s", authID)
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
|
||||
@@ -696,7 +696,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
||||
// Step 1: Create user-owned node WITH expiry set
|
||||
clientExpiry := time.Now().Add(24 * time.Hour)
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "personal-to-tagged",
|
||||
@@ -718,7 +718,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
||||
// Step 2: Re-auth with tags (Personal → Tagged conversion)
|
||||
nodeKey2 := key.NewNode()
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "personal-to-tagged",
|
||||
@@ -768,7 +768,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
||||
|
||||
// Step 1: Create tagged node (expiry should be nil)
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "tagged-to-personal",
|
||||
@@ -790,7 +790,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
||||
nodeKey2 := key.NewNode()
|
||||
clientExpiry := time.Now().Add(48 * time.Hour)
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "tagged-to-personal",
|
||||
|
||||
@@ -681,7 +681,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
Hostname: "followup-success-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
@@ -723,7 +723,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
Hostname: "followup-timeout-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
@@ -1341,7 +1341,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
Hostname: "nil-response-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
@@ -2618,7 +2618,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
require.True(t, found, "registration cache entry should exist")
|
||||
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
||||
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
|
||||
require.Equal(t, req.NodeKey, cacheEntry.RegistrationData().NodeKey, "cache entry should have correct node key")
|
||||
}
|
||||
|
||||
case stepTypeAuthCompletion:
|
||||
@@ -3570,7 +3570,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
||||
|
||||
// Simulate a registration cache entry (as would be created during web auth)
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "webauth-tags-node",
|
||||
@@ -3633,7 +3633,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
|
||||
// Step 1: Initial registration with tags
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "reauth-untag-node",
|
||||
@@ -3660,7 +3660,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
// Step 2: Reauth with EMPTY tags to untag
|
||||
nodeKey2 := key.NewNode() // New node key for reauth
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(), // Same machine key
|
||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||
Hostname: "reauth-untag-node",
|
||||
@@ -3746,7 +3746,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
||||
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
||||
nodeKey2 := key.NewNode() // New node key for reauth
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(), // Same machine key
|
||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||
Hostname: "authkey-tagged-node",
|
||||
@@ -3945,7 +3945,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
||||
// This is what happens when running: headscale auth register --auth-id <id> --user alice
|
||||
nodeKey2 := key.NewNode()
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "tagged-orphan-node",
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
@@ -45,19 +44,15 @@ const (
|
||||
)
|
||||
|
||||
type HSDatabase struct {
|
||||
DB *gorm.DB
|
||||
cfg *types.Config
|
||||
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||
DB *gorm.DB
|
||||
cfg *types.Config
|
||||
}
|
||||
|
||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||
// It accepts the full configuration to allow migrations access to policy settings.
|
||||
//
|
||||
//nolint:gocyclo // complex database initialization with many migrations
|
||||
func NewHeadscaleDatabase(
|
||||
cfg *types.Config,
|
||||
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
|
||||
) (*HSDatabase, error) {
|
||||
func NewHeadscaleDatabase(cfg *types.Config) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(cfg.Database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -838,9 +833,8 @@ WHERE tags IS NOT NULL AND tags != '[]' AND tags != '';
|
||||
}
|
||||
|
||||
db := HSDatabase{
|
||||
DB: dbConn,
|
||||
cfg: cfg,
|
||||
regCache: regCache,
|
||||
DB: dbConn,
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
return &db, err
|
||||
|
||||
@@ -8,13 +8,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
// TestSQLiteMigrationAndDataValidation tests specific SQLite migration scenarios
|
||||
@@ -162,10 +160,6 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||||
}
|
||||
|
||||
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
@@ -379,7 +373,6 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
||||
Mode: types.PolicyModeDB,
|
||||
},
|
||||
},
|
||||
emptyCache(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("setting up database: %s", err)
|
||||
@@ -439,7 +432,6 @@ func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||
Mode: types.PolicyModeDB,
|
||||
},
|
||||
},
|
||||
emptyCache(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
@@ -34,7 +34,6 @@ func newSQLiteTestDB() (*HSDatabase, error) {
|
||||
Mode: types.PolicyModeDB,
|
||||
},
|
||||
},
|
||||
emptyCache(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -95,7 +94,6 @@ func newHeadscaleDBFromPostgresURL(t *testing.T, pu *url.URL) *HSDatabase {
|
||||
Mode: types.PolicyModeDB,
|
||||
},
|
||||
},
|
||||
emptyCache(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -802,27 +802,16 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
Interface("route-str", request.GetRoutes()).
|
||||
Msg("Creating routes for node")
|
||||
|
||||
hostinfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: routes,
|
||||
OS: "TestOS",
|
||||
Hostname: request.GetName(),
|
||||
}
|
||||
|
||||
registrationId, err := types.AuthIDFromString(request.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newNode := types.Node{
|
||||
regData := &types.RegistrationData{
|
||||
NodeKey: key.NewNode().Public(),
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
Hostname: request.GetName(),
|
||||
User: user,
|
||||
|
||||
Expiry: &time.Time{},
|
||||
LastSeen: &time.Time{},
|
||||
|
||||
Hostinfo: &hostinfo,
|
||||
Expiry: &time.Time{}, // zero time, not nil — preserves proto JSON round-trip semantics
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
@@ -830,10 +819,27 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
||||
Str("registration_id", registrationId.String()).
|
||||
Msg("adding debug machine via CLI, appending to registration cache")
|
||||
|
||||
authRegReq := types.NewRegisterAuthRequest(newNode)
|
||||
authRegReq := types.NewRegisterAuthRequest(regData)
|
||||
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
|
||||
|
||||
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
|
||||
// Echo back a synthetic Node so the debug response surface stays
|
||||
// stable. The actual node is created later by AuthApprove via
|
||||
// HandleNodeFromAuthPath using the cached RegistrationData.
|
||||
echoNode := types.Node{
|
||||
NodeKey: regData.NodeKey,
|
||||
MachineKey: regData.MachineKey,
|
||||
Hostname: regData.Hostname,
|
||||
User: user,
|
||||
Expiry: &time.Time{},
|
||||
LastSeen: &time.Time{},
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
Hostname: request.GetName(),
|
||||
OS: "TestOS",
|
||||
RoutableIPs: routes,
|
||||
},
|
||||
}
|
||||
|
||||
return &v1.DebugCreateNodeResponse{Node: echoNode.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) Health(
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher")
|
||||
@@ -109,11 +108,6 @@ var allBatcherFunctions = []batcherTestCase{
|
||||
{"Default", NewBatcherAndMapper},
|
||||
}
|
||||
|
||||
// emptyCache creates an empty registration cache for testing.
|
||||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||||
}
|
||||
|
||||
// Test configuration constants.
|
||||
const (
|
||||
// Test data configuration.
|
||||
@@ -211,10 +205,7 @@ func setupBatcherWithTestData(
|
||||
}
|
||||
|
||||
// Create database and populate it with test data
|
||||
database, err := db.NewHeadscaleDatabase(
|
||||
cfg,
|
||||
emptyCache(),
|
||||
)
|
||||
database, err := db.NewHeadscaleDatabase(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("setting up database: %s", err)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/templates"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@@ -19,14 +20,17 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
randomByteSize = 16
|
||||
defaultOAuthOptionsCount = 3
|
||||
authCacheExpiration = time.Minute * 15
|
||||
authCacheCleanup = time.Minute * 20
|
||||
|
||||
// authCacheMaxEntries bounds the OIDC state→AuthInfo cache to prevent
|
||||
// unauthenticated cache-fill DoS via repeated /register/{auth_id} or
|
||||
// /auth/{auth_id} GETs that mint OIDC state cookies.
|
||||
authCacheMaxEntries = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,9 +59,10 @@ type AuthProviderOIDC struct {
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
|
||||
// authCache holds auth information between
|
||||
// the auth and the callback steps.
|
||||
authCache *zcache.Cache[string, AuthInfo]
|
||||
// authCache holds auth information between the auth and the callback
|
||||
// steps. It is a bounded LRU keyed by OIDC state, evicting oldest
|
||||
// entries to keep the cache footprint constant under attack.
|
||||
authCache *expirable.LRU[string, AuthInfo]
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
@@ -84,9 +89,10 @@ func NewAuthProviderOIDC(
|
||||
Scopes: cfg.Scope,
|
||||
}
|
||||
|
||||
authCache := zcache.New[string, AuthInfo](
|
||||
authCache := expirable.NewLRU[string, AuthInfo](
|
||||
authCacheMaxEntries,
|
||||
nil,
|
||||
authCacheExpiration,
|
||||
authCacheCleanup,
|
||||
)
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
@@ -188,7 +194,7 @@ func (a *AuthProviderOIDC) authHandler(
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
a.authCache.Set(state, registrationInfo)
|
||||
a.authCache.Add(state, registrationInfo)
|
||||
|
||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||
|
||||
@@ -18,10 +18,10 @@ import (
|
||||
// fake-clock advancement, but three blockers prevent adoption
|
||||
// as of Go 1.26:
|
||||
//
|
||||
// 1. zcache janitor goroutine: No Close() method; stopped only via
|
||||
// runtime.SetFinalizer which runs outside synctest bubbles.
|
||||
// - https://github.com/patrickmn/go-cache/issues/185
|
||||
// - https://github.com/golang/go/issues/75113 (Go1.27: finalizers inside bubble)
|
||||
// 1. golang-lru/v2/expirable janitor goroutine: No Close() method;
|
||||
// the deleteExpired ticker goroutine never exits because the done
|
||||
// channel is never closed (documented as a v3 TODO upstream).
|
||||
// - https://github.com/hashicorp/golang-lru/blob/v2.0.7/expirable/expirable_lru.go#L78-L81
|
||||
//
|
||||
// 2. database/sql internal goroutines: Uses sync.RWMutex which is not
|
||||
// durably blocking in synctest, causing hangs.
|
||||
|
||||
64
hscontrol/state/auth_cache_test.go
Normal file
64
hscontrol/state/auth_cache_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAuthCacheBoundedLRU verifies that the registration auth cache is
|
||||
// bounded by a maximum entry count, that exceeding maxEntries evicts the
|
||||
// oldest entry, and that the eviction callback resolves the parked
|
||||
// AuthRequest with ErrRegistrationExpired so any waiting goroutine wakes.
|
||||
func TestAuthCacheBoundedLRU(t *testing.T) {
|
||||
const maxEntries = 4
|
||||
|
||||
cache := expirable.NewLRU[types.AuthID, *types.AuthRequest](
|
||||
maxEntries,
|
||||
func(_ types.AuthID, rn *types.AuthRequest) {
|
||||
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||
},
|
||||
time.Hour, // long TTL — we test eviction by size, not by time
|
||||
)
|
||||
|
||||
entries := make([]*types.AuthRequest, 0, maxEntries+1)
|
||||
ids := make([]types.AuthID, 0, maxEntries+1)
|
||||
|
||||
for range maxEntries + 1 {
|
||||
id := types.MustAuthID()
|
||||
entry := types.NewAuthRequest()
|
||||
cache.Add(id, entry)
|
||||
ids = append(ids, id)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
// Cap should be respected.
|
||||
assert.Equal(t, maxEntries, cache.Len(), "cache must not exceed the configured maxEntries")
|
||||
|
||||
// The oldest entry must have been evicted.
|
||||
_, ok := cache.Get(ids[0])
|
||||
assert.False(t, ok, "oldest entry must be evicted when maxEntries is exceeded")
|
||||
|
||||
// The eviction callback must have woken the parked AuthRequest.
|
||||
select {
|
||||
case verdict := <-entries[0].WaitForAuth():
|
||||
require.False(t, verdict.Accept(), "evicted entry must not signal Accept")
|
||||
require.ErrorIs(t,
|
||||
verdict.Err, ErrRegistrationExpired,
|
||||
"evicted entry must surface ErrRegistrationExpired, got: %v",
|
||||
verdict.Err,
|
||||
)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("eviction callback did not wake the parked AuthRequest")
|
||||
}
|
||||
|
||||
// All non-evicted entries must still be retrievable.
|
||||
for i := 1; i <= maxEntries; i++ {
|
||||
_, ok := cache.Get(ids[i])
|
||||
assert.True(t, ok, "non-evicted entry %d should still be in the cache", i)
|
||||
}
|
||||
}
|
||||
@@ -211,15 +211,13 @@ func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy {
|
||||
|
||||
// DebugRegistrationCache returns debug information about the registration cache.
|
||||
func (s *State) DebugRegistrationCache() map[string]any {
|
||||
// The cache doesn't expose internal statistics, so we provide basic info
|
||||
result := map[string]any{
|
||||
"type": "zcache",
|
||||
"expiration": registerCacheExpiration.String(),
|
||||
"cleanup": registerCacheCleanup.String(),
|
||||
"status": "active",
|
||||
return map[string]any{
|
||||
"type": "expirable-lru",
|
||||
"expiration": registerCacheExpiration.String(),
|
||||
"max_entries": defaultRegisterCacheMaxEntries,
|
||||
"current_len": s.authCache.Len(),
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// DebugConfig returns debug information about the current configuration.
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
hsdb "github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
@@ -30,15 +31,18 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
zcache "zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
// registerCacheExpiration defines how long node registration entries remain in cache.
|
||||
registerCacheExpiration = time.Minute * 15
|
||||
|
||||
// registerCacheCleanup defines the interval for cleaning up expired cache entries.
|
||||
registerCacheCleanup = time.Minute * 20
|
||||
// defaultRegisterCacheMaxEntries is the default upper bound on the number
|
||||
// of pending registration entries the auth cache will hold. With a 15-minute
|
||||
// TTL and a stripped-down RegistrationData payload (~200 bytes per entry,
|
||||
// excluding the client-supplied Hostinfo it references), 1024 entries cap the
|
||||
// worst-case fixed cache footprint at well under 1 MiB even under sustained
|
||||
// unauthenticated cache-fill attempts.
|
||||
defaultRegisterCacheMaxEntries = 1024
|
||||
|
||||
// defaultNodeStoreBatchSize is the default number of write operations to batch
|
||||
// before rebuilding the in-memory node snapshot.
|
||||
@@ -126,8 +130,12 @@ type State struct {
|
||||
// polMan handles policy evaluation and management
|
||||
polMan policy.PolicyManager
|
||||
|
||||
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
|
||||
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
||||
// authCache holds any pending authentication requests from either auth
|
||||
// type (Web and OIDC). It is a bounded LRU keyed by AuthID; oldest
|
||||
// entries are evicted once the size cap is reached. Entries that are
|
||||
// evicted or time out have their auth verdict resolved with
|
||||
// ErrRegistrationExpired via the eviction callback so any waiting goroutines wake.
|
||||
authCache *expirable.LRU[types.AuthID, *types.AuthRequest]
|
||||
|
||||
// primaryRoutes tracks primary route assignments for nodes
|
||||
primaryRoutes *routes.PrimaryRoutes
|
||||
@@ -166,26 +174,20 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
|
||||
}
|
||||
|
||||
cacheCleanup := registerCacheCleanup
|
||||
if cfg.Tuning.RegisterCacheCleanup != 0 {
|
||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
||||
cacheMaxEntries := defaultRegisterCacheMaxEntries
|
||||
if cfg.Tuning.RegisterCacheMaxEntries > 0 {
|
||||
cacheMaxEntries = cfg.Tuning.RegisterCacheMaxEntries
|
||||
}
|
||||
|
||||
authCache := zcache.New[types.AuthID, types.AuthRequest](
|
||||
cacheExpiration,
|
||||
cacheCleanup,
|
||||
)
|
||||
|
||||
authCache.OnEvicted(
|
||||
func(id types.AuthID, rn types.AuthRequest) {
|
||||
authCache := expirable.NewLRU[types.AuthID, *types.AuthRequest](
|
||||
cacheMaxEntries,
|
||||
func(id types.AuthID, rn *types.AuthRequest) {
|
||||
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||
},
|
||||
cacheExpiration,
|
||||
)
|
||||
|
||||
db, err := hsdb.NewHeadscaleDatabase(
|
||||
cfg,
|
||||
authCache,
|
||||
)
|
||||
db, err := hsdb.NewHeadscaleDatabase(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing database: %w", err)
|
||||
}
|
||||
@@ -1252,19 +1254,14 @@ func (s *State) DeletePreAuthKey(id uint64) error {
|
||||
return s.db.DeletePreAuthKey(id)
|
||||
}
|
||||
|
||||
// GetAuthCacheEntry retrieves a node registration from cache.
|
||||
// GetAuthCacheEntry retrieves a pending auth request from the cache.
|
||||
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
||||
entry, found := s.authCache.Get(id)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &entry, true
|
||||
return s.authCache.Get(id)
|
||||
}
|
||||
|
||||
// SetAuthCacheEntry stores a node registration in cache.
|
||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
||||
s.authCache.Set(id, entry)
|
||||
// SetAuthCacheEntry stores a pending auth request in the cache.
|
||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry *types.AuthRequest) {
|
||||
s.authCache.Add(id, entry)
|
||||
}
|
||||
|
||||
// SetLastSSHAuth records a successful SSH check authentication
|
||||
@@ -1296,25 +1293,6 @@ func (s *State) ClearSSHCheckAuth() {
|
||||
s.sshCheckAuth = make(map[sshCheckPair]time.Time)
|
||||
}
|
||||
|
||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
||||
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
||||
if !nv.Hostinfo().Valid() {
|
||||
log.Warn().
|
||||
Caller().
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had nil hostinfo, generated default hostname")
|
||||
} else if nv.Hostinfo().Hostname() == "" {
|
||||
log.Warn().
|
||||
Caller().
|
||||
EmbedObject(nv).
|
||||
Str(zf.UserName, username).
|
||||
Str(zf.GeneratedHostname, hostname).
|
||||
Msg("Registration had empty hostname, generated default")
|
||||
}
|
||||
}
|
||||
|
||||
// preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity.
|
||||
// If no existing node is provided, it creates new netinfo from the provided hostinfo.
|
||||
func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo {
|
||||
@@ -1349,15 +1327,15 @@ type newNodeParams struct {
|
||||
type authNodeUpdateParams struct {
|
||||
// Node to update; must be valid and in NodeStore.
|
||||
ExistingNode types.NodeView
|
||||
// Client data: keys, hostinfo, endpoints.
|
||||
RegEntry *types.AuthRequest
|
||||
// Cached registration payload from the originating client request.
|
||||
RegData *types.RegistrationData
|
||||
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
||||
ValidHostinfo *tailcfg.Hostinfo
|
||||
// Hostname from hostinfo, or generated from keys if client omits it.
|
||||
Hostname string
|
||||
// Auth user; may differ from ExistingNode.User() on conversion.
|
||||
User *types.User
|
||||
// Overrides RegEntry.Node.Expiry; ignored for tagged nodes.
|
||||
// Overrides RegData.Expiry; ignored for tagged nodes.
|
||||
Expiry *time.Time
|
||||
// Only used when IsConvertFromTag=true.
|
||||
RegisterMethod string
|
||||
@@ -1369,7 +1347,7 @@ type authNodeUpdateParams struct {
|
||||
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
||||
// persists changes to the database.
|
||||
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
||||
regNv := params.RegEntry.Node()
|
||||
regData := params.RegData
|
||||
// Log the operation type
|
||||
if params.IsConvertFromTag {
|
||||
log.Info().
|
||||
@@ -1379,15 +1357,17 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
} else {
|
||||
log.Info().
|
||||
Object("existing", params.ExistingNode).
|
||||
Object("incoming", regNv).
|
||||
Str("incoming.hostname", regData.Hostname).
|
||||
Str("incoming.machine_key", regData.MachineKey.ShortString()).
|
||||
Msg("Updating existing node registration via reauth")
|
||||
}
|
||||
|
||||
// Process RequestTags during reauth (#2979)
|
||||
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
||||
// Process RequestTags during reauth (#2979).
|
||||
// Due to json:",omitempty", empty/nil from the cached Hostinfo
|
||||
// means "clear tags".
|
||||
var requestTags []string
|
||||
if regNv.Hostinfo().Valid() {
|
||||
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
|
||||
if regData.Hostinfo != nil {
|
||||
requestTags = regData.Hostinfo.RequestTags
|
||||
}
|
||||
|
||||
oldTags := params.ExistingNode.Tags().AsSlice()
|
||||
@@ -1405,8 +1385,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
|
||||
// Update existing node in NodeStore - validation passed, safe to mutate
|
||||
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
||||
node.NodeKey = regNv.NodeKey()
|
||||
node.DiscoKey = regNv.DiscoKey()
|
||||
node.NodeKey = regData.NodeKey
|
||||
node.DiscoKey = regData.DiscoKey
|
||||
node.Hostname = params.Hostname
|
||||
|
||||
// Preserve NetInfo from existing node when re-registering
|
||||
@@ -1417,7 +1397,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
params.ValidHostinfo,
|
||||
)
|
||||
|
||||
node.Endpoints = regNv.Endpoints().AsSlice()
|
||||
node.Endpoints = regData.Endpoints
|
||||
// Do NOT reset IsOnline here. Online status is managed exclusively by
|
||||
// Connect()/Disconnect() in the poll session lifecycle. Resetting it
|
||||
// during re-registration causes a false offline blip: the change
|
||||
@@ -1425,12 +1405,12 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
// to peers, even though Connect() will immediately set it back to true.
|
||||
node.LastSeen = new(time.Now())
|
||||
|
||||
// Set RegisterMethod - for conversion this is the new method,
|
||||
// for reauth we preserve the existing one from regEntry
|
||||
// On conversion (tagged → user) we set the new register method.
|
||||
// On plain reauth we preserve the existing node.RegisterMethod;
|
||||
// the cached RegistrationData no longer carries it because the
|
||||
// producer never populated it.
|
||||
if params.IsConvertFromTag {
|
||||
node.RegisterMethod = params.RegisterMethod
|
||||
} else {
|
||||
node.RegisterMethod = regNv.RegisterMethod()
|
||||
}
|
||||
|
||||
// Track tagged status BEFORE processing tags
|
||||
@@ -1450,7 +1430,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
node.Expiry = regData.Expiry
|
||||
}
|
||||
case !wasTagged && isTagged:
|
||||
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
||||
@@ -1460,14 +1440,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
node.Expiry = regData.Expiry
|
||||
}
|
||||
case !isTagged:
|
||||
// Personal → Personal: update expiry from client
|
||||
if params.Expiry != nil {
|
||||
node.Expiry = params.Expiry
|
||||
} else {
|
||||
node.Expiry = regNv.Expiry().Clone()
|
||||
node.Expiry = regData.Expiry
|
||||
}
|
||||
}
|
||||
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
||||
@@ -1795,29 +1775,20 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("finding user: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we have a valid hostname from the registration cache entry
|
||||
hostname := util.EnsureHostname(
|
||||
regEntry.Node().Hostinfo(),
|
||||
regEntry.Node().MachineKey().String(),
|
||||
regEntry.Node().NodeKey().String(),
|
||||
)
|
||||
regData := regEntry.RegistrationData()
|
||||
|
||||
// Ensure we have valid hostinfo
|
||||
// Hostname was already validated/normalised at producer time. Build
|
||||
// the initial Hostinfo from the cached client-supplied Hostinfo (or
|
||||
// an empty stub if the client did not send one).
|
||||
hostname := regData.Hostname
|
||||
hostinfo := &tailcfg.Hostinfo{}
|
||||
if regEntry.Node().Hostinfo().Valid() {
|
||||
hostinfo = regEntry.Node().Hostinfo().AsStruct()
|
||||
if regData.Hostinfo != nil {
|
||||
hostinfo = regData.Hostinfo.Clone()
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
logHostinfoValidation(
|
||||
regEntry.Node(),
|
||||
user.Name,
|
||||
hostname,
|
||||
)
|
||||
|
||||
// Lookup existing nodes
|
||||
machineKey := regEntry.Node().MachineKey()
|
||||
machineKey := regData.MachineKey
|
||||
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
||||
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
||||
|
||||
@@ -1839,7 +1810,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
|
||||
// Common params for update operations
|
||||
updateParams := authNodeUpdateParams{
|
||||
RegEntry: regEntry,
|
||||
RegData: regData,
|
||||
ValidHostinfo: hostinfo,
|
||||
Hostname: hostname,
|
||||
User: user,
|
||||
@@ -1874,7 +1845,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
Msg("Creating new node for different user (same machine key exists for another user)")
|
||||
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
logger, user, regData, hostname, hostinfo,
|
||||
expiry, registrationMethod, existingNodeAnyUser,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1882,7 +1853,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
}
|
||||
} else {
|
||||
finalNode, err = s.createNewNodeFromAuth(
|
||||
logger, user, regEntry, hostname, hostinfo,
|
||||
logger, user, regData, hostname, hostinfo,
|
||||
expiry, registrationMethod, types.NodeView{},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -1893,8 +1864,8 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
// Signal to waiting clients
|
||||
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
|
||||
|
||||
// Delete from registration cache
|
||||
s.authCache.Delete(authID)
|
||||
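// Remove from registration cache. Keep this after FinishAuth: Remove also fires the eviction callback, which reports ErrRegistrationExpired.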
|
||||
s.authCache.Remove(authID)
|
||||
|
||||
// Update policy managers
|
||||
usersChange, err := s.updatePolicyManagerUsers()
|
||||
@@ -1923,7 +1894,7 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) createNewNodeFromAuth(
|
||||
logger zerolog.Logger,
|
||||
user *types.User,
|
||||
regEntry *types.AuthRequest,
|
||||
regData *types.RegistrationData,
|
||||
hostname string,
|
||||
validHostinfo *tailcfg.Hostinfo,
|
||||
expiry *time.Time,
|
||||
@@ -1936,13 +1907,13 @@ func (s *State) createNewNodeFromAuth(
|
||||
|
||||
return s.createAndSaveNewNode(newNodeParams{
|
||||
User: *user,
|
||||
MachineKey: regEntry.Node().MachineKey(),
|
||||
NodeKey: regEntry.Node().NodeKey(),
|
||||
DiscoKey: regEntry.Node().DiscoKey(),
|
||||
MachineKey: regData.MachineKey,
|
||||
NodeKey: regData.NodeKey,
|
||||
DiscoKey: regData.DiscoKey,
|
||||
Hostname: hostname,
|
||||
Hostinfo: validHostinfo,
|
||||
Endpoints: regEntry.Node().Endpoints().AsSlice(),
|
||||
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
|
||||
Endpoints: regData.Endpoints,
|
||||
Expiry: cmp.Or(expiry, regData.Expiry),
|
||||
RegisterMethod: registrationMethod,
|
||||
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
||||
})
|
||||
|
||||
@@ -221,40 +221,65 @@ func (r AuthID) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthRequest represent a pending authentication request from a user or a node.
|
||||
// If it is a registration request, the node field will be populate with the node that is trying to register.
|
||||
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
|
||||
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
|
||||
// AuthRequest represents a pending authentication request from a user or a
|
||||
// node. It carries the minimum data needed to either complete a node
|
||||
// registration (regData populated) or signal the verdict of an interactive
|
||||
// auth flow (no payload). Verdict delivery is via the finished channel; the
|
||||
// closed flag guards FinishAuth against double-close.
|
||||
//
|
||||
// AuthRequest is always handled by pointer so the channel and atomic flag
|
||||
// have a single canonical instance even when stored in caches that
|
||||
// internally copy values.
|
||||
type AuthRequest struct {
|
||||
node *Node
|
||||
// regData is populated for node-registration flows (interactive web
|
||||
// or OIDC). It carries only the minimal subset of registration data
|
||||
// the auth callback needs to promote this request into a real node;
|
||||
// see RegistrationData for the rationale behind keeping the payload
|
||||
// small.
|
||||
//
|
||||
// nil for non-registration flows (e.g. SSH check). Use
|
||||
// RegistrationData() to read it safely.
|
||||
regData *RegistrationData
|
||||
|
||||
finished chan AuthVerdict
|
||||
closed *atomic.Bool
|
||||
}
|
||||
|
||||
func NewAuthRequest() AuthRequest {
|
||||
return AuthRequest{
|
||||
// NewAuthRequest creates a pending auth request with no payload, suitable
|
||||
// for non-registration flows that only need a verdict channel.
|
||||
func NewAuthRequest() *AuthRequest {
|
||||
return &AuthRequest{
|
||||
finished: make(chan AuthVerdict, 1),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func NewRegisterAuthRequest(node Node) AuthRequest {
|
||||
return AuthRequest{
|
||||
node: &node,
|
||||
// NewRegisterAuthRequest creates a pending auth request carrying the
|
||||
// minimal RegistrationData for a node-registration flow. The data is
|
||||
// stored by pointer; callers must not mutate it after handing it off.
|
||||
func NewRegisterAuthRequest(data *RegistrationData) *AuthRequest {
|
||||
return &AuthRequest{
|
||||
regData: data,
|
||||
finished: make(chan AuthVerdict, 1),
|
||||
closed: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// Node returns the node that is trying to register.
|
||||
// It will panic if the AuthRequest is not a registration request.
|
||||
// Can _only_ be used in the registration path.
|
||||
func (rn *AuthRequest) Node() NodeView {
|
||||
if rn.node == nil {
|
||||
panic("Node can only be used in registration requests")
|
||||
// RegistrationData returns the cached registration payload. It panics if
|
||||
// called on an AuthRequest that was not created via
|
||||
// NewRegisterAuthRequest (or was given nil data), mirroring the previous Node() contract.
|
||||
func (rn *AuthRequest) RegistrationData() *RegistrationData {
|
||||
if rn.regData == nil {
|
||||
panic("RegistrationData can only be used in registration requests")
|
||||
}
|
||||
|
||||
return rn.node.View()
|
||||
return rn.regData
|
||||
}
|
||||
|
||||
// IsRegistration reports whether this auth request carries registration
|
||||
// data (i.e. it was created via NewRegisterAuthRequest).
|
||||
func (rn *AuthRequest) IsRegistration() bool {
|
||||
return rn.regData != nil
|
||||
}
|
||||
|
||||
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
|
||||
|
||||
@@ -278,14 +278,16 @@ type Tuning struct {
|
||||
// updates for connected clients.
|
||||
BatcherWorkers int
|
||||
|
||||
// RegisterCacheCleanup is the interval between cleanup operations for
|
||||
// expired registration cache entries.
|
||||
RegisterCacheCleanup time.Duration
|
||||
|
||||
// RegisterCacheExpiration is how long registration cache entries remain
|
||||
// valid before being eligible for cleanup.
|
||||
// valid before being eligible for eviction.
|
||||
RegisterCacheExpiration time.Duration
|
||||
|
||||
// RegisterCacheMaxEntries bounds the number of pending registration
|
||||
// entries the auth cache will hold. Older entries are evicted (LRU)
|
||||
// when the cap is reached, preventing unauthenticated cache-fill DoS.
|
||||
// A value of 0 or less falls back to defaultRegisterCacheMaxEntries (1024).
|
||||
RegisterCacheMaxEntries int
|
||||
|
||||
// NodeStoreBatchSize controls how many write operations are accumulated
|
||||
// before rebuilding the in-memory node snapshot.
|
||||
//
|
||||
@@ -1192,8 +1194,8 @@ func LoadServerConfig() (*Config, error) {
|
||||
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
||||
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
|
||||
RegisterCacheMaxEntries: viper.GetInt("tuning.register_cache_max_entries"),
|
||||
NodeStoreBatchSize: viper.GetInt("tuning.node_store_batch_size"),
|
||||
NodeStoreBatchTimeout: viper.GetDuration("tuning.node_store_batch_timeout"),
|
||||
},
|
||||
|
||||
55
hscontrol/types/registration.go
Normal file
55
hscontrol/types/registration.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// RegistrationData is the payload cached for a pending node registration.
|
||||
// It replaces the previous practice of caching a full *Node and carries
|
||||
// only the fields the registration callback path actually consumes when
|
||||
// promoting a pending registration to a real node.
|
||||
//
|
||||
// Combined with the bounded-LRU cache that holds these entries, this caps
|
||||
// the worst-case memory footprint of unauthenticated cache-fill attempts
|
||||
// at (max_entries × per_entry_size). The bound scales with the cap: with
|
||||
// attacker-supplied 1 MiB Hostinfos (the Noise body limit), the default
|
||||
// of 1024 entries permits up to roughly 1 GiB; lower the cap to tighten it.
|
||||
type RegistrationData struct {
|
||||
// MachineKey is the cryptographic identity of the machine being
|
||||
// registered. Required.
|
||||
MachineKey key.MachinePublic
|
||||
|
||||
// NodeKey is the cryptographic identity of the node session.
|
||||
// Required.
|
||||
NodeKey key.NodePublic
|
||||
|
||||
// DiscoKey is the disco public key for peer-to-peer connections.
|
||||
DiscoKey key.DiscoPublic
|
||||
|
||||
// Hostname is the resolved hostname for the registering node.
|
||||
// Already validated/normalised by EnsureHostname at producer time.
|
||||
Hostname string
|
||||
|
||||
// Hostinfo is the original Hostinfo from the RegisterRequest,
|
||||
// stored so that the auth callback can populate the new node's
|
||||
// initial Hostinfo (and so that observability/CLI consumers see
|
||||
// fields like OS, OSVersion, and IPNVersion before the first
|
||||
// MapRequest restores the live set).
|
||||
//
|
||||
// May be nil if the client did not send Hostinfo in the original
|
||||
// RegisterRequest.
|
||||
Hostinfo *tailcfg.Hostinfo
|
||||
|
||||
// Endpoints is the initial set of WireGuard endpoints the node
|
||||
// reported. The first MapRequest after registration overwrites
|
||||
// this with the live set.
|
||||
Endpoints []netip.AddrPort
|
||||
|
||||
// Expiry is the optional client-requested expiry for this node.
|
||||
// May be nil if the client did not request a specific expiry.
|
||||
Expiry *time.Time
|
||||
}
|
||||
Reference in New Issue
Block a user