mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-23 17:18:50 +02:00
state: replace zcache with bounded LRU for auth cache
Replace zcache with golang-lru/v2/expirable for both the state auth cache and the OIDC state cache. Add tuning.register_cache_max_entries (default 1024) to cap the number of pending registration entries. Introduce types.RegistrationData to replace caching a full *Node; only the fields the registration callback path reads are retained. Remove the dead HSDatabase.regCache field. Drop zgo.at/zcache/v2 from go.mod.
This commit is contained in:
@@ -28,7 +28,7 @@ func bypassDatabase() (*db.HSDatabase, error) {
|
|||||||
return nil, fmt.Errorf("loading config: %w", err)
|
return nil, fmt.Errorf("loading config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := db.NewHeadscaleDatabase(cfg, nil)
|
d, err := db.NewHeadscaleDatabase(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("opening database: %w", err)
|
return nil, fmt.Errorf("opening database: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -445,12 +445,16 @@ taildrop:
|
|||||||
# When enabled, nodes can send files to other nodes owned by the same user.
|
# When enabled, nodes can send files to other nodes owned by the same user.
|
||||||
# Tagged devices and cross-user transfers are not permitted by Tailscale clients.
|
# Tagged devices and cross-user transfers are not permitted by Tailscale clients.
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|
||||||
# Advanced performance tuning parameters.
|
# Advanced performance tuning parameters.
|
||||||
# The defaults are carefully chosen and should rarely need adjustment.
|
# The defaults are carefully chosen and should rarely need adjustment.
|
||||||
# Only modify these if you have identified a specific performance issue.
|
# Only modify these if you have identified a specific performance issue.
|
||||||
#
|
#
|
||||||
# tuning:
|
# tuning:
|
||||||
|
# # Maximum number of pending registration entries in the auth cache.
|
||||||
|
# # Oldest entries are evicted when the cap is reached.
|
||||||
|
# #
|
||||||
|
# # register_cache_max_entries: 1024
|
||||||
|
#
|
||||||
# # NodeStore write batching configuration.
|
# # NodeStore write batching configuration.
|
||||||
# # The NodeStore batches write operations before rebuilding peer relationships,
|
# # The NodeStore batches write operations before rebuilding peer relationships,
|
||||||
# # which is computationally expensive. Batching reduces rebuild frequency.
|
# # which is computationally expensive. Batching reduces rebuild frequency.
|
||||||
|
|||||||
@@ -27,7 +27,7 @@
|
|||||||
let
|
let
|
||||||
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
|
pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
|
||||||
buildGo = pkgs.buildGo126Module;
|
buildGo = pkgs.buildGo126Module;
|
||||||
vendorHash = "sha256-G+yhItFhlp2XP/Zd9N4nMQf96YMQLuYd069H+Quewtk=";
|
vendorHash = "sha256-x0xXxa7sjyDwWLq8fO0Z/pbPefctzctK3TAdBea7FtY=";
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
headscale = buildGo {
|
headscale = buildGo {
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -22,6 +22,7 @@ require (
|
|||||||
github.com/google/go-cmp v0.7.0
|
github.com/google/go-cmp v0.7.0
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||||
github.com/jagottsicher/termcolor v1.0.2
|
github.com/jagottsicher/termcolor v1.0.2
|
||||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||||
github.com/ory/dockertest/v3 v3.12.0
|
github.com/ory/dockertest/v3 v3.12.0
|
||||||
@@ -54,7 +55,6 @@ require (
|
|||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/gorm v1.31.1
|
gorm.io/gorm v1.31.1
|
||||||
tailscale.com v1.96.5
|
tailscale.com v1.96.5
|
||||||
zgo.at/zcache/v2 v2.4.1
|
|
||||||
zombiezen.com/go/postgrestest v1.0.1
|
zombiezen.com/go/postgrestest v1.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
3
go.sum
3
go.sum
@@ -252,7 +252,6 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz
|
|||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
|
||||||
github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4=
|
github.com/hashicorp/go-version v1.8.0 h1:KAkNb1HAiZd1ukkxDFGmokVZe1Xy9HG6NUp+bPle2i4=
|
||||||
github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
github.com/hashicorp/go-version v1.8.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||||
github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU=
|
github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU=
|
||||||
@@ -682,7 +681,5 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB
|
|||||||
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||||
tailscale.com v1.96.5 h1:gNkfA/KSZAl6jCH9cj8urq00HRWItDDTtGsyATI89jA=
|
tailscale.com v1.96.5 h1:gNkfA/KSZAl6jCH9cj8urq00HRWItDDTtGsyATI89jA=
|
||||||
tailscale.com v1.96.5/go.mod h1:/3lnZBYb2UEwnN0MNu2SDXUtT06AGd5k0s+OWx3WmcY=
|
tailscale.com v1.96.5/go.mod h1:/3lnZBYb2UEwnN0MNu2SDXUtT06AGd5k0s+OWx3WmcY=
|
||||||
zgo.at/zcache/v2 v2.4.1 h1:Dfjoi8yI0Uq7NCc4lo2kaQJJmp9Mijo21gef+oJstbY=
|
|
||||||
zgo.at/zcache/v2 v2.4.1/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
|
||||||
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
|
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
|
||||||
zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ=
|
zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ=
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package hscontrol
|
package hscontrol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -302,31 +301,10 @@ func (h *Headscale) reqToNewRegisterResponse(
|
|||||||
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have a valid hostname
|
authRegReq := types.NewRegisterAuthRequest(
|
||||||
hostname := util.EnsureHostname(
|
registrationDataFromRequest(req, machineKey),
|
||||||
req.Hostinfo.View(),
|
|
||||||
machineKey.String(),
|
|
||||||
req.NodeKey.String(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Ensure we have valid hostinfo
|
|
||||||
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
|
||||||
hostinfo.Hostname = hostname
|
|
||||||
|
|
||||||
nodeToRegister := types.Node{
|
|
||||||
Hostname: hostname,
|
|
||||||
MachineKey: machineKey,
|
|
||||||
NodeKey: req.NodeKey,
|
|
||||||
Hostinfo: hostinfo,
|
|
||||||
LastSeen: new(time.Now()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if !req.Expiry.IsZero() {
|
|
||||||
nodeToRegister.Expiry = &req.Expiry
|
|
||||||
}
|
|
||||||
|
|
||||||
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
|
||||||
|
|
||||||
log.Info().Msgf("new followup node registration using auth id: %s", newAuthID)
|
log.Info().Msgf("new followup node registration using auth id: %s", newAuthID)
|
||||||
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
|
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
|
||||||
|
|
||||||
@@ -335,6 +313,36 @@ func (h *Headscale) reqToNewRegisterResponse(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registrationDataFromRequest builds the RegistrationData payload stored
|
||||||
|
// in the auth cache for a pending registration. The original Hostinfo is
|
||||||
|
// retained so that consumers (auth callback, observability) see the
|
||||||
|
// fields the client originally announced; the bounded-LRU cap on the
|
||||||
|
// cache is what bounds the unauthenticated cache-fill DoS surface.
|
||||||
|
func registrationDataFromRequest(
|
||||||
|
req tailcfg.RegisterRequest,
|
||||||
|
machineKey key.MachinePublic,
|
||||||
|
) *types.RegistrationData {
|
||||||
|
hostname := util.EnsureHostname(
|
||||||
|
req.Hostinfo.View(),
|
||||||
|
machineKey.String(),
|
||||||
|
req.NodeKey.String(),
|
||||||
|
)
|
||||||
|
|
||||||
|
regData := &types.RegistrationData{
|
||||||
|
MachineKey: machineKey,
|
||||||
|
NodeKey: req.NodeKey,
|
||||||
|
Hostname: hostname,
|
||||||
|
Hostinfo: req.Hostinfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !req.Expiry.IsZero() {
|
||||||
|
expiry := req.Expiry
|
||||||
|
regData.Expiry = &expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
return regData
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleRegisterWithAuthKey(
|
func (h *Headscale) handleRegisterWithAuthKey(
|
||||||
req tailcfg.RegisterRequest,
|
req tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
@@ -408,50 +416,24 @@ func (h *Headscale) handleRegisterInteractive(
|
|||||||
return nil, fmt.Errorf("generating registration ID: %w", err)
|
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have a valid hostname
|
|
||||||
hostname := util.EnsureHostname(
|
|
||||||
req.Hostinfo.View(),
|
|
||||||
machineKey.String(),
|
|
||||||
req.NodeKey.String(),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Ensure we have valid hostinfo
|
|
||||||
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
|
|
||||||
if req.Hostinfo == nil {
|
if req.Hostinfo == nil {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("machine.key", machineKey.ShortString()).
|
Str("machine.key", machineKey.ShortString()).
|
||||||
Str("node.key", req.NodeKey.ShortString()).
|
Str("node.key", req.NodeKey.ShortString()).
|
||||||
Str("generated.hostname", hostname).
|
|
||||||
Msg("Received registration request with nil hostinfo, generated default hostname")
|
Msg("Received registration request with nil hostinfo, generated default hostname")
|
||||||
} else if req.Hostinfo.Hostname == "" {
|
} else if req.Hostinfo.Hostname == "" {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("machine.key", machineKey.ShortString()).
|
Str("machine.key", machineKey.ShortString()).
|
||||||
Str("node.key", req.NodeKey.ShortString()).
|
Str("node.key", req.NodeKey.ShortString()).
|
||||||
Str("generated.hostname", hostname).
|
|
||||||
Msg("Received registration request with empty hostname, generated default")
|
Msg("Received registration request with empty hostname, generated default")
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.Hostname = hostname
|
authRegReq := types.NewRegisterAuthRequest(
|
||||||
|
registrationDataFromRequest(req, machineKey),
|
||||||
nodeToRegister := types.Node{
|
|
||||||
Hostname: hostname,
|
|
||||||
MachineKey: machineKey,
|
|
||||||
NodeKey: req.NodeKey,
|
|
||||||
Hostinfo: hostinfo,
|
|
||||||
LastSeen: new(time.Now()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if !req.Expiry.IsZero() {
|
|
||||||
nodeToRegister.Expiry = &req.Expiry
|
|
||||||
}
|
|
||||||
|
|
||||||
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
|
|
||||||
|
|
||||||
h.state.SetAuthCacheEntry(
|
|
||||||
authID,
|
|
||||||
authRegReq,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
h.state.SetAuthCacheEntry(authID, authRegReq)
|
||||||
|
|
||||||
log.Info().Msgf("starting node registration using auth id: %s", authID)
|
log.Info().Msgf("starting node registration using auth id: %s", authID)
|
||||||
|
|
||||||
return &tailcfg.RegisterResponse{
|
return &tailcfg.RegisterResponse{
|
||||||
|
|||||||
@@ -696,7 +696,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
|||||||
// Step 1: Create user-owned node WITH expiry set
|
// Step 1: Create user-owned node WITH expiry set
|
||||||
clientExpiry := time.Now().Add(24 * time.Hour)
|
clientExpiry := time.Now().Add(24 * time.Hour)
|
||||||
registrationID1 := types.MustAuthID()
|
registrationID1 := types.MustAuthID()
|
||||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey1.Public(),
|
NodeKey: nodeKey1.Public(),
|
||||||
Hostname: "personal-to-tagged",
|
Hostname: "personal-to-tagged",
|
||||||
@@ -718,7 +718,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
|
|||||||
// Step 2: Re-auth with tags (Personal → Tagged conversion)
|
// Step 2: Re-auth with tags (Personal → Tagged conversion)
|
||||||
nodeKey2 := key.NewNode()
|
nodeKey2 := key.NewNode()
|
||||||
registrationID2 := types.MustAuthID()
|
registrationID2 := types.MustAuthID()
|
||||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey2.Public(),
|
NodeKey: nodeKey2.Public(),
|
||||||
Hostname: "personal-to-tagged",
|
Hostname: "personal-to-tagged",
|
||||||
@@ -768,7 +768,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
|||||||
|
|
||||||
// Step 1: Create tagged node (expiry should be nil)
|
// Step 1: Create tagged node (expiry should be nil)
|
||||||
registrationID1 := types.MustAuthID()
|
registrationID1 := types.MustAuthID()
|
||||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey1.Public(),
|
NodeKey: nodeKey1.Public(),
|
||||||
Hostname: "tagged-to-personal",
|
Hostname: "tagged-to-personal",
|
||||||
@@ -790,7 +790,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
|
|||||||
nodeKey2 := key.NewNode()
|
nodeKey2 := key.NewNode()
|
||||||
clientExpiry := time.Now().Add(48 * time.Hour)
|
clientExpiry := time.Now().Add(48 * time.Hour)
|
||||||
registrationID2 := types.MustAuthID()
|
registrationID2 := types.MustAuthID()
|
||||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey2.Public(),
|
NodeKey: nodeKey2.Public(),
|
||||||
Hostname: "tagged-to-personal",
|
Hostname: "tagged-to-personal",
|
||||||
|
|||||||
@@ -681,7 +681,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
Hostname: "followup-success-node",
|
Hostname: "followup-success-node",
|
||||||
})
|
})
|
||||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||||
@@ -723,7 +723,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
Hostname: "followup-timeout-node",
|
Hostname: "followup-timeout-node",
|
||||||
})
|
})
|
||||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||||
@@ -1341,7 +1341,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
nodeToRegister := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
Hostname: "nil-response-node",
|
Hostname: "nil-response-node",
|
||||||
})
|
})
|
||||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||||
@@ -2618,7 +2618,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||||||
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||||
require.True(t, found, "registration cache entry should exist")
|
require.True(t, found, "registration cache entry should exist")
|
||||||
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
||||||
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
|
require.Equal(t, req.NodeKey, cacheEntry.RegistrationData().NodeKey, "cache entry should have correct node key")
|
||||||
}
|
}
|
||||||
|
|
||||||
case stepTypeAuthCompletion:
|
case stepTypeAuthCompletion:
|
||||||
@@ -3570,7 +3570,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
|||||||
|
|
||||||
// Simulate a registration cache entry (as would be created during web auth)
|
// Simulate a registration cache entry (as would be created during web auth)
|
||||||
registrationID := types.MustAuthID()
|
registrationID := types.MustAuthID()
|
||||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey.Public(),
|
NodeKey: nodeKey.Public(),
|
||||||
Hostname: "webauth-tags-node",
|
Hostname: "webauth-tags-node",
|
||||||
@@ -3633,7 +3633,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
|||||||
|
|
||||||
// Step 1: Initial registration with tags
|
// Step 1: Initial registration with tags
|
||||||
registrationID1 := types.MustAuthID()
|
registrationID1 := types.MustAuthID()
|
||||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
regEntry1 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(),
|
MachineKey: machineKey.Public(),
|
||||||
NodeKey: nodeKey1.Public(),
|
NodeKey: nodeKey1.Public(),
|
||||||
Hostname: "reauth-untag-node",
|
Hostname: "reauth-untag-node",
|
||||||
@@ -3660,7 +3660,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
|||||||
// Step 2: Reauth with EMPTY tags to untag
|
// Step 2: Reauth with EMPTY tags to untag
|
||||||
nodeKey2 := key.NewNode() // New node key for reauth
|
nodeKey2 := key.NewNode() // New node key for reauth
|
||||||
registrationID2 := types.MustAuthID()
|
registrationID2 := types.MustAuthID()
|
||||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
regEntry2 := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(), // Same machine key
|
MachineKey: machineKey.Public(), // Same machine key
|
||||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||||
Hostname: "reauth-untag-node",
|
Hostname: "reauth-untag-node",
|
||||||
@@ -3746,7 +3746,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
|||||||
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
||||||
nodeKey2 := key.NewNode() // New node key for reauth
|
nodeKey2 := key.NewNode() // New node key for reauth
|
||||||
registrationID := types.MustAuthID()
|
registrationID := types.MustAuthID()
|
||||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(), // Same machine key
|
MachineKey: machineKey.Public(), // Same machine key
|
||||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||||
Hostname: "authkey-tagged-node",
|
Hostname: "authkey-tagged-node",
|
||||||
@@ -3945,7 +3945,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
|||||||
// This is what happens when running: headscale auth register --auth-id <id> --user alice
|
// This is what happens when running: headscale auth register --auth-id <id> --user alice
|
||||||
nodeKey2 := key.NewNode()
|
nodeKey2 := key.NewNode()
|
||||||
registrationID := types.MustAuthID()
|
registrationID := types.MustAuthID()
|
||||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
regEntry := types.NewRegisterAuthRequest(&types.RegistrationData{
|
||||||
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
||||||
NodeKey: nodeKey2.Public(),
|
NodeKey: nodeKey2.Public(),
|
||||||
Hostname: "tagged-orphan-node",
|
Hostname: "tagged-orphan-node",
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
"gorm.io/gorm/schema"
|
"gorm.io/gorm/schema"
|
||||||
"zgo.at/zcache/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed schema.sql
|
//go:embed schema.sql
|
||||||
@@ -45,19 +44,15 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
DB *gorm.DB
|
DB *gorm.DB
|
||||||
cfg *types.Config
|
cfg *types.Config
|
||||||
regCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||||
// It accepts the full configuration to allow migrations access to policy settings.
|
// It accepts the full configuration to allow migrations access to policy settings.
|
||||||
//
|
//
|
||||||
//nolint:gocyclo // complex database initialization with many migrations
|
//nolint:gocyclo // complex database initialization with many migrations
|
||||||
func NewHeadscaleDatabase(
|
func NewHeadscaleDatabase(cfg *types.Config) (*HSDatabase, error) {
|
||||||
cfg *types.Config,
|
|
||||||
regCache *zcache.Cache[types.AuthID, types.AuthRequest],
|
|
||||||
) (*HSDatabase, error) {
|
|
||||||
dbConn, err := openDB(cfg.Database)
|
dbConn, err := openDB(cfg.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -838,9 +833,8 @@ WHERE tags IS NOT NULL AND tags != '[]' AND tags != '';
|
|||||||
}
|
}
|
||||||
|
|
||||||
db := HSDatabase{
|
db := HSDatabase{
|
||||||
DB: dbConn,
|
DB: dbConn,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
regCache: regCache,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &db, err
|
return &db, err
|
||||||
|
|||||||
@@ -8,13 +8,11 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"zgo.at/zcache/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestSQLiteMigrationAndDataValidation tests specific SQLite migration scenarios
|
// TestSQLiteMigrationAndDataValidation tests specific SQLite migration scenarios
|
||||||
@@ -162,10 +160,6 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
|
||||||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||||
db, err := sql.Open("sqlite", dbPath)
|
db, err := sql.Open("sqlite", dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -379,7 +373,6 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
|||||||
Mode: types.PolicyModeDB,
|
Mode: types.PolicyModeDB,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
emptyCache(),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("setting up database: %s", err)
|
t.Fatalf("setting up database: %s", err)
|
||||||
@@ -439,7 +432,6 @@ func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
|||||||
Mode: types.PolicyModeDB,
|
Mode: types.PolicyModeDB,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
emptyCache(),
|
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ func newSQLiteTestDB() (*HSDatabase, error) {
|
|||||||
Mode: types.PolicyModeDB,
|
Mode: types.PolicyModeDB,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
emptyCache(),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -95,7 +94,6 @@ func newHeadscaleDBFromPostgresURL(t *testing.T, pu *url.URL) *HSDatabase {
|
|||||||
Mode: types.PolicyModeDB,
|
Mode: types.PolicyModeDB,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
emptyCache(),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
|||||||
@@ -802,27 +802,16 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||||||
Interface("route-str", request.GetRoutes()).
|
Interface("route-str", request.GetRoutes()).
|
||||||
Msg("Creating routes for node")
|
Msg("Creating routes for node")
|
||||||
|
|
||||||
hostinfo := tailcfg.Hostinfo{
|
|
||||||
RoutableIPs: routes,
|
|
||||||
OS: "TestOS",
|
|
||||||
Hostname: request.GetName(),
|
|
||||||
}
|
|
||||||
|
|
||||||
registrationId, err := types.AuthIDFromString(request.GetKey())
|
registrationId, err := types.AuthIDFromString(request.GetKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newNode := types.Node{
|
regData := &types.RegistrationData{
|
||||||
NodeKey: key.NewNode().Public(),
|
NodeKey: key.NewNode().Public(),
|
||||||
MachineKey: key.NewMachine().Public(),
|
MachineKey: key.NewMachine().Public(),
|
||||||
Hostname: request.GetName(),
|
Hostname: request.GetName(),
|
||||||
User: user,
|
Expiry: &time.Time{}, // zero time, not nil — preserves proto JSON round-trip semantics
|
||||||
|
|
||||||
Expiry: &time.Time{},
|
|
||||||
LastSeen: &time.Time{},
|
|
||||||
|
|
||||||
Hostinfo: &hostinfo,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().
|
log.Debug().
|
||||||
@@ -830,10 +819,27 @@ func (api headscaleV1APIServer) DebugCreateNode(
|
|||||||
Str("registration_id", registrationId.String()).
|
Str("registration_id", registrationId.String()).
|
||||||
Msg("adding debug machine via CLI, appending to registration cache")
|
Msg("adding debug machine via CLI, appending to registration cache")
|
||||||
|
|
||||||
authRegReq := types.NewRegisterAuthRequest(newNode)
|
authRegReq := types.NewRegisterAuthRequest(regData)
|
||||||
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
|
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
|
||||||
|
|
||||||
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
|
// Echo back a synthetic Node so the debug response surface stays
|
||||||
|
// stable. The actual node is created later by AuthApprove via
|
||||||
|
// HandleNodeFromAuthPath using the cached RegistrationData.
|
||||||
|
echoNode := types.Node{
|
||||||
|
NodeKey: regData.NodeKey,
|
||||||
|
MachineKey: regData.MachineKey,
|
||||||
|
Hostname: regData.Hostname,
|
||||||
|
User: user,
|
||||||
|
Expiry: &time.Time{},
|
||||||
|
LastSeen: &time.Time{},
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
|
Hostname: request.GetName(),
|
||||||
|
OS: "TestOS",
|
||||||
|
RoutableIPs: routes,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &v1.DebugCreateNodeResponse{Node: echoNode.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) Health(
|
func (api headscaleV1APIServer) Health(
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"zgo.at/zcache/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher")
|
var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher")
|
||||||
@@ -109,11 +108,6 @@ var allBatcherFunctions = []batcherTestCase{
|
|||||||
{"Default", NewBatcherAndMapper},
|
{"Default", NewBatcherAndMapper},
|
||||||
}
|
}
|
||||||
|
|
||||||
// emptyCache creates an empty registration cache for testing.
|
|
||||||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
|
||||||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test configuration constants.
|
// Test configuration constants.
|
||||||
const (
|
const (
|
||||||
// Test data configuration.
|
// Test data configuration.
|
||||||
@@ -211,10 +205,7 @@ func setupBatcherWithTestData(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create database and populate it with test data
|
// Create database and populate it with test data
|
||||||
database, err := db.NewHeadscaleDatabase(
|
database, err := db.NewHeadscaleDatabase(cfg)
|
||||||
cfg,
|
|
||||||
emptyCache(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("setting up database: %s", err)
|
t.Fatalf("setting up database: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/templates"
|
"github.com/juanfont/headscale/hscontrol/templates"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
@@ -19,14 +20,17 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"zgo.at/zcache/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
randomByteSize = 16
|
randomByteSize = 16
|
||||||
defaultOAuthOptionsCount = 3
|
defaultOAuthOptionsCount = 3
|
||||||
authCacheExpiration = time.Minute * 15
|
authCacheExpiration = time.Minute * 15
|
||||||
authCacheCleanup = time.Minute * 20
|
|
||||||
|
// authCacheMaxEntries bounds the OIDC state→AuthInfo cache to prevent
|
||||||
|
// unauthenticated cache-fill DoS via repeated /register/{auth_id} or
|
||||||
|
// /auth/{auth_id} GETs that mint OIDC state cookies.
|
||||||
|
authCacheMaxEntries = 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -55,9 +59,10 @@ type AuthProviderOIDC struct {
|
|||||||
serverURL string
|
serverURL string
|
||||||
cfg *types.OIDCConfig
|
cfg *types.OIDCConfig
|
||||||
|
|
||||||
// authCache holds auth information between
|
// authCache holds auth information between the auth and the callback
|
||||||
// the auth and the callback steps.
|
// steps. It is a bounded LRU keyed by OIDC state, evicting oldest
|
||||||
authCache *zcache.Cache[string, AuthInfo]
|
// entries to keep the cache footprint constant under attack.
|
||||||
|
authCache *expirable.LRU[string, AuthInfo]
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oauth2Config *oauth2.Config
|
oauth2Config *oauth2.Config
|
||||||
@@ -84,9 +89,10 @@ func NewAuthProviderOIDC(
|
|||||||
Scopes: cfg.Scope,
|
Scopes: cfg.Scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
authCache := zcache.New[string, AuthInfo](
|
authCache := expirable.NewLRU[string, AuthInfo](
|
||||||
|
authCacheMaxEntries,
|
||||||
|
nil,
|
||||||
authCacheExpiration,
|
authCacheExpiration,
|
||||||
authCacheCleanup,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return &AuthProviderOIDC{
|
return &AuthProviderOIDC{
|
||||||
@@ -188,7 +194,7 @@ func (a *AuthProviderOIDC) authHandler(
|
|||||||
extras = append(extras, oidc.Nonce(nonce))
|
extras = append(extras, oidc.Nonce(nonce))
|
||||||
|
|
||||||
// Cache the registration info
|
// Cache the registration info
|
||||||
a.authCache.Set(state, registrationInfo)
|
a.authCache.Add(state, registrationInfo)
|
||||||
|
|
||||||
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
|
||||||
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
|
||||||
|
|||||||
@@ -18,10 +18,10 @@ import (
|
|||||||
// fake-clock advancement, but three blockers prevent adoption
|
// fake-clock advancement, but three blockers prevent adoption
|
||||||
// as of Go 1.26:
|
// as of Go 1.26:
|
||||||
//
|
//
|
||||||
// 1. zcache janitor goroutine: No Close() method; stopped only via
|
// 1. golang-lru/v2/expirable janitor goroutine: No Close() method;
|
||||||
// runtime.SetFinalizer which runs outside synctest bubbles.
|
// the deleteExpired ticker goroutine never exits because the done
|
||||||
// - https://github.com/patrickmn/go-cache/issues/185
|
// channel is never closed (documented as a v3 TODO upstream).
|
||||||
// - https://github.com/golang/go/issues/75113 (Go1.27: finalizers inside bubble)
|
// - https://github.com/hashicorp/golang-lru/blob/v2.0.7/expirable/expirable_lru.go#L78-L81
|
||||||
//
|
//
|
||||||
// 2. database/sql internal goroutines: Uses sync.RWMutex which is not
|
// 2. database/sql internal goroutines: Uses sync.RWMutex which is not
|
||||||
// durably blocking in synctest, causing hangs.
|
// durably blocking in synctest, causing hangs.
|
||||||
|
|||||||
64
hscontrol/state/auth_cache_test.go
Normal file
64
hscontrol/state/auth_cache_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package state
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAuthCacheBoundedLRU verifies that the registration auth cache is
|
||||||
|
// bounded by a maximum entry count, that exceeding the maxEntries evicts the
|
||||||
|
// oldest entry, and that the eviction callback resolves the parked
|
||||||
|
// AuthRequest with ErrRegistrationExpired so any waiting goroutine wakes.
|
||||||
|
func TestAuthCacheBoundedLRU(t *testing.T) {
|
||||||
|
const maxEntries = 4
|
||||||
|
|
||||||
|
cache := expirable.NewLRU[types.AuthID, *types.AuthRequest](
|
||||||
|
maxEntries,
|
||||||
|
func(_ types.AuthID, rn *types.AuthRequest) {
|
||||||
|
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||||
|
},
|
||||||
|
time.Hour, // long TTL — we test eviction by size, not by time
|
||||||
|
)
|
||||||
|
|
||||||
|
entries := make([]*types.AuthRequest, 0, maxEntries+1)
|
||||||
|
ids := make([]types.AuthID, 0, maxEntries+1)
|
||||||
|
|
||||||
|
for range maxEntries + 1 {
|
||||||
|
id := types.MustAuthID()
|
||||||
|
entry := types.NewAuthRequest()
|
||||||
|
cache.Add(id, entry)
|
||||||
|
ids = append(ids, id)
|
||||||
|
entries = append(entries, entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap should be respected.
|
||||||
|
assert.Equal(t, maxEntries, cache.Len(), "cache must not exceed the configured maxEntries")
|
||||||
|
|
||||||
|
// The oldest entry must have been evicted.
|
||||||
|
_, ok := cache.Get(ids[0])
|
||||||
|
assert.False(t, ok, "oldest entry must be evicted when maxEntries is exceeded")
|
||||||
|
|
||||||
|
// The eviction callback must have woken the parked AuthRequest.
|
||||||
|
select {
|
||||||
|
case verdict := <-entries[0].WaitForAuth():
|
||||||
|
require.False(t, verdict.Accept(), "evicted entry must not signal Accept")
|
||||||
|
require.ErrorIs(t,
|
||||||
|
verdict.Err, ErrRegistrationExpired,
|
||||||
|
"evicted entry must surface ErrRegistrationExpired, got: %v",
|
||||||
|
verdict.Err,
|
||||||
|
)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("eviction callback did not wake the parked AuthRequest")
|
||||||
|
}
|
||||||
|
|
||||||
|
// All non-evicted entries must still be retrievable.
|
||||||
|
for i := 1; i <= maxEntries; i++ {
|
||||||
|
_, ok := cache.Get(ids[i])
|
||||||
|
assert.True(t, ok, "non-evicted entry %d should still be in the cache", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -211,15 +211,13 @@ func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy {
|
|||||||
|
|
||||||
// DebugRegistrationCache returns debug information about the registration cache.
|
// DebugRegistrationCache returns debug information about the registration cache.
|
||||||
func (s *State) DebugRegistrationCache() map[string]any {
|
func (s *State) DebugRegistrationCache() map[string]any {
|
||||||
// The cache doesn't expose internal statistics, so we provide basic info
|
return map[string]any{
|
||||||
result := map[string]any{
|
"type": "expirable-lru",
|
||||||
"type": "zcache",
|
"expiration": registerCacheExpiration.String(),
|
||||||
"expiration": registerCacheExpiration.String(),
|
"max_entries": defaultRegisterCacheMaxEntries,
|
||||||
"cleanup": registerCacheCleanup.String(),
|
"current_len": s.authCache.Len(),
|
||||||
"status": "active",
|
"status": "active",
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DebugConfig returns debug information about the current configuration.
|
// DebugConfig returns debug information about the current configuration.
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||||
hsdb "github.com/juanfont/headscale/hscontrol/db"
|
hsdb "github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||||
@@ -30,15 +31,18 @@ import (
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
zcache "zgo.at/zcache/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// registerCacheExpiration defines how long node registration entries remain in cache.
|
// registerCacheExpiration defines how long node registration entries remain in cache.
|
||||||
registerCacheExpiration = time.Minute * 15
|
registerCacheExpiration = time.Minute * 15
|
||||||
|
|
||||||
// registerCacheCleanup defines the interval for cleaning up expired cache entries.
|
// defaultRegisterCacheMaxEntries is the default upper bound on the number
|
||||||
registerCacheCleanup = time.Minute * 20
|
// of pending registration entries the auth cache will hold. With a 15-minute
|
||||||
|
// TTL and a stripped-down RegistrationData payload (~200 bytes per entry),
|
||||||
|
// 1024 entries cap the worst-case cache footprint at well under 1 MiB even
|
||||||
|
// under sustained unauthenticated cache-fill attempts.
|
||||||
|
defaultRegisterCacheMaxEntries = 1024
|
||||||
|
|
||||||
// defaultNodeStoreBatchSize is the default number of write operations to batch
|
// defaultNodeStoreBatchSize is the default number of write operations to batch
|
||||||
// before rebuilding the in-memory node snapshot.
|
// before rebuilding the in-memory node snapshot.
|
||||||
@@ -126,8 +130,12 @@ type State struct {
|
|||||||
// polMan handles policy evaluation and management
|
// polMan handles policy evaluation and management
|
||||||
polMan policy.PolicyManager
|
polMan policy.PolicyManager
|
||||||
|
|
||||||
// authCache caches any pending authentication requests, from either auth type (Web and OIDC).
|
// authCache holds any pending authentication requests from either auth
|
||||||
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
|
// type (Web and OIDC). It is a bounded LRU keyed by AuthID; oldest
|
||||||
|
// entries are evicted once the size cap is reached, and entries that
|
||||||
|
// time out have their auth verdict resolved with ErrRegistrationExpired
|
||||||
|
// via the eviction callback so any waiting goroutines wake.
|
||||||
|
authCache *expirable.LRU[types.AuthID, *types.AuthRequest]
|
||||||
|
|
||||||
// primaryRoutes tracks primary route assignments for nodes
|
// primaryRoutes tracks primary route assignments for nodes
|
||||||
primaryRoutes *routes.PrimaryRoutes
|
primaryRoutes *routes.PrimaryRoutes
|
||||||
@@ -166,26 +174,20 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||||||
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
|
cacheExpiration = cfg.Tuning.RegisterCacheExpiration
|
||||||
}
|
}
|
||||||
|
|
||||||
cacheCleanup := registerCacheCleanup
|
cacheMaxEntries := defaultRegisterCacheMaxEntries
|
||||||
if cfg.Tuning.RegisterCacheCleanup != 0 {
|
if cfg.Tuning.RegisterCacheMaxEntries > 0 {
|
||||||
cacheCleanup = cfg.Tuning.RegisterCacheCleanup
|
cacheMaxEntries = cfg.Tuning.RegisterCacheMaxEntries
|
||||||
}
|
}
|
||||||
|
|
||||||
authCache := zcache.New[types.AuthID, types.AuthRequest](
|
authCache := expirable.NewLRU[types.AuthID, *types.AuthRequest](
|
||||||
cacheExpiration,
|
cacheMaxEntries,
|
||||||
cacheCleanup,
|
func(id types.AuthID, rn *types.AuthRequest) {
|
||||||
)
|
|
||||||
|
|
||||||
authCache.OnEvicted(
|
|
||||||
func(id types.AuthID, rn types.AuthRequest) {
|
|
||||||
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
|
||||||
},
|
},
|
||||||
|
cacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
||||||
db, err := hsdb.NewHeadscaleDatabase(
|
db, err := hsdb.NewHeadscaleDatabase(cfg)
|
||||||
cfg,
|
|
||||||
authCache,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("initializing database: %w", err)
|
return nil, fmt.Errorf("initializing database: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1252,19 +1254,14 @@ func (s *State) DeletePreAuthKey(id uint64) error {
|
|||||||
return s.db.DeletePreAuthKey(id)
|
return s.db.DeletePreAuthKey(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthCacheEntry retrieves a node registration from cache.
|
// GetAuthCacheEntry retrieves a pending auth request from the cache.
|
||||||
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
|
||||||
entry, found := s.authCache.Get(id)
|
return s.authCache.Get(id)
|
||||||
if !found {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return &entry, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAuthCacheEntry stores a node registration in cache.
|
// SetAuthCacheEntry stores a pending auth request in the cache.
|
||||||
func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
|
func (s *State) SetAuthCacheEntry(id types.AuthID, entry *types.AuthRequest) {
|
||||||
s.authCache.Set(id, entry)
|
s.authCache.Add(id, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLastSSHAuth records a successful SSH check authentication
|
// SetLastSSHAuth records a successful SSH check authentication
|
||||||
@@ -1296,25 +1293,6 @@ func (s *State) ClearSSHCheckAuth() {
|
|||||||
s.sshCheckAuth = make(map[sshCheckPair]time.Time)
|
s.sshCheckAuth = make(map[sshCheckPair]time.Time)
|
||||||
}
|
}
|
||||||
|
|
||||||
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
|
|
||||||
func logHostinfoValidation(nv types.NodeView, username, hostname string) {
|
|
||||||
if !nv.Hostinfo().Valid() {
|
|
||||||
log.Warn().
|
|
||||||
Caller().
|
|
||||||
EmbedObject(nv).
|
|
||||||
Str(zf.UserName, username).
|
|
||||||
Str(zf.GeneratedHostname, hostname).
|
|
||||||
Msg("Registration had nil hostinfo, generated default hostname")
|
|
||||||
} else if nv.Hostinfo().Hostname() == "" {
|
|
||||||
log.Warn().
|
|
||||||
Caller().
|
|
||||||
EmbedObject(nv).
|
|
||||||
Str(zf.UserName, username).
|
|
||||||
Str(zf.GeneratedHostname, hostname).
|
|
||||||
Msg("Registration had empty hostname, generated default")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity.
|
// preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity.
|
||||||
// If no existing node is provided, it creates new netinfo from the provided hostinfo.
|
// If no existing node is provided, it creates new netinfo from the provided hostinfo.
|
||||||
func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo {
|
func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo {
|
||||||
@@ -1349,15 +1327,15 @@ type newNodeParams struct {
|
|||||||
type authNodeUpdateParams struct {
|
type authNodeUpdateParams struct {
|
||||||
// Node to update; must be valid and in NodeStore.
|
// Node to update; must be valid and in NodeStore.
|
||||||
ExistingNode types.NodeView
|
ExistingNode types.NodeView
|
||||||
// Client data: keys, hostinfo, endpoints.
|
// Cached registration payload from the originating client request.
|
||||||
RegEntry *types.AuthRequest
|
RegData *types.RegistrationData
|
||||||
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
// Pre-validated hostinfo; NetInfo preserved from ExistingNode.
|
||||||
ValidHostinfo *tailcfg.Hostinfo
|
ValidHostinfo *tailcfg.Hostinfo
|
||||||
// Hostname from hostinfo, or generated from keys if client omits it.
|
// Hostname from hostinfo, or generated from keys if client omits it.
|
||||||
Hostname string
|
Hostname string
|
||||||
// Auth user; may differ from ExistingNode.User() on conversion.
|
// Auth user; may differ from ExistingNode.User() on conversion.
|
||||||
User *types.User
|
User *types.User
|
||||||
// Overrides RegEntry.Node.Expiry; ignored for tagged nodes.
|
// Overrides RegData.Expiry; ignored for tagged nodes.
|
||||||
Expiry *time.Time
|
Expiry *time.Time
|
||||||
// Only used when IsConvertFromTag=true.
|
// Only used when IsConvertFromTag=true.
|
||||||
RegisterMethod string
|
RegisterMethod string
|
||||||
@@ -1369,7 +1347,7 @@ type authNodeUpdateParams struct {
|
|||||||
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
// an existing node. It updates the node in NodeStore, processes RequestTags, and
|
||||||
// persists changes to the database.
|
// persists changes to the database.
|
||||||
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
|
||||||
regNv := params.RegEntry.Node()
|
regData := params.RegData
|
||||||
// Log the operation type
|
// Log the operation type
|
||||||
if params.IsConvertFromTag {
|
if params.IsConvertFromTag {
|
||||||
log.Info().
|
log.Info().
|
||||||
@@ -1379,15 +1357,17 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
} else {
|
} else {
|
||||||
log.Info().
|
log.Info().
|
||||||
Object("existing", params.ExistingNode).
|
Object("existing", params.ExistingNode).
|
||||||
Object("incoming", regNv).
|
Str("incoming.hostname", regData.Hostname).
|
||||||
|
Str("incoming.machine_key", regData.MachineKey.ShortString()).
|
||||||
Msg("Updating existing node registration via reauth")
|
Msg("Updating existing node registration via reauth")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process RequestTags during reauth (#2979)
|
// Process RequestTags during reauth (#2979).
|
||||||
// Due to json:",omitempty", we treat empty/nil as "clear tags"
|
// Due to json:",omitempty", empty/nil from the cached Hostinfo
|
||||||
|
// means "clear tags".
|
||||||
var requestTags []string
|
var requestTags []string
|
||||||
if regNv.Hostinfo().Valid() {
|
if regData.Hostinfo != nil {
|
||||||
requestTags = regNv.Hostinfo().RequestTags().AsSlice()
|
requestTags = regData.Hostinfo.RequestTags
|
||||||
}
|
}
|
||||||
|
|
||||||
oldTags := params.ExistingNode.Tags().AsSlice()
|
oldTags := params.ExistingNode.Tags().AsSlice()
|
||||||
@@ -1405,8 +1385,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
|
|
||||||
// Update existing node in NodeStore - validation passed, safe to mutate
|
// Update existing node in NodeStore - validation passed, safe to mutate
|
||||||
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
|
||||||
node.NodeKey = regNv.NodeKey()
|
node.NodeKey = regData.NodeKey
|
||||||
node.DiscoKey = regNv.DiscoKey()
|
node.DiscoKey = regData.DiscoKey
|
||||||
node.Hostname = params.Hostname
|
node.Hostname = params.Hostname
|
||||||
|
|
||||||
// Preserve NetInfo from existing node when re-registering
|
// Preserve NetInfo from existing node when re-registering
|
||||||
@@ -1417,7 +1397,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
params.ValidHostinfo,
|
params.ValidHostinfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
node.Endpoints = regNv.Endpoints().AsSlice()
|
node.Endpoints = regData.Endpoints
|
||||||
// Do NOT reset IsOnline here. Online status is managed exclusively by
|
// Do NOT reset IsOnline here. Online status is managed exclusively by
|
||||||
// Connect()/Disconnect() in the poll session lifecycle. Resetting it
|
// Connect()/Disconnect() in the poll session lifecycle. Resetting it
|
||||||
// during re-registration causes a false offline blip: the change
|
// during re-registration causes a false offline blip: the change
|
||||||
@@ -1425,12 +1405,12 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
// to peers, even though Connect() will immediately set it back to true.
|
// to peers, even though Connect() will immediately set it back to true.
|
||||||
node.LastSeen = new(time.Now())
|
node.LastSeen = new(time.Now())
|
||||||
|
|
||||||
// Set RegisterMethod - for conversion this is the new method,
|
// On conversion (tagged → user) we set the new register method.
|
||||||
// for reauth we preserve the existing one from regEntry
|
// On plain reauth we preserve the existing node.RegisterMethod;
|
||||||
|
// the cached RegistrationData no longer carries it because the
|
||||||
|
// producer never populated it.
|
||||||
if params.IsConvertFromTag {
|
if params.IsConvertFromTag {
|
||||||
node.RegisterMethod = params.RegisterMethod
|
node.RegisterMethod = params.RegisterMethod
|
||||||
} else {
|
|
||||||
node.RegisterMethod = regNv.RegisterMethod()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track tagged status BEFORE processing tags
|
// Track tagged status BEFORE processing tags
|
||||||
@@ -1450,7 +1430,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
if params.Expiry != nil {
|
if params.Expiry != nil {
|
||||||
node.Expiry = params.Expiry
|
node.Expiry = params.Expiry
|
||||||
} else {
|
} else {
|
||||||
node.Expiry = regNv.Expiry().Clone()
|
node.Expiry = regData.Expiry
|
||||||
}
|
}
|
||||||
case !wasTagged && isTagged:
|
case !wasTagged && isTagged:
|
||||||
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
// Personal → Tagged: clear expiry (tagged nodes don't expire)
|
||||||
@@ -1460,14 +1440,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
|
|||||||
if params.Expiry != nil {
|
if params.Expiry != nil {
|
||||||
node.Expiry = params.Expiry
|
node.Expiry = params.Expiry
|
||||||
} else {
|
} else {
|
||||||
node.Expiry = regNv.Expiry().Clone()
|
node.Expiry = regData.Expiry
|
||||||
}
|
}
|
||||||
case !isTagged:
|
case !isTagged:
|
||||||
// Personal → Personal: update expiry from client
|
// Personal → Personal: update expiry from client
|
||||||
if params.Expiry != nil {
|
if params.Expiry != nil {
|
||||||
node.Expiry = params.Expiry
|
node.Expiry = params.Expiry
|
||||||
} else {
|
} else {
|
||||||
node.Expiry = regNv.Expiry().Clone()
|
node.Expiry = regData.Expiry
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
// Tagged → Tagged: keep existing expiry (nil) - no action needed
|
||||||
@@ -1795,29 +1775,20 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
return types.NodeView{}, change.Change{}, fmt.Errorf("finding user: %w", err)
|
return types.NodeView{}, change.Change{}, fmt.Errorf("finding user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have a valid hostname from the registration cache entry
|
regData := regEntry.RegistrationData()
|
||||||
hostname := util.EnsureHostname(
|
|
||||||
regEntry.Node().Hostinfo(),
|
|
||||||
regEntry.Node().MachineKey().String(),
|
|
||||||
regEntry.Node().NodeKey().String(),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Ensure we have valid hostinfo
|
// Hostname was already validated/normalised at producer time. Build
|
||||||
|
// the initial Hostinfo from the cached client-supplied Hostinfo (or
|
||||||
|
// an empty stub if the client did not send one).
|
||||||
|
hostname := regData.Hostname
|
||||||
hostinfo := &tailcfg.Hostinfo{}
|
hostinfo := &tailcfg.Hostinfo{}
|
||||||
if regEntry.Node().Hostinfo().Valid() {
|
if regData.Hostinfo != nil {
|
||||||
hostinfo = regEntry.Node().Hostinfo().AsStruct()
|
hostinfo = regData.Hostinfo.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
hostinfo.Hostname = hostname
|
hostinfo.Hostname = hostname
|
||||||
|
|
||||||
logHostinfoValidation(
|
|
||||||
regEntry.Node(),
|
|
||||||
user.Name,
|
|
||||||
hostname,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Lookup existing nodes
|
// Lookup existing nodes
|
||||||
machineKey := regEntry.Node().MachineKey()
|
machineKey := regData.MachineKey
|
||||||
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
|
||||||
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
|
||||||
|
|
||||||
@@ -1839,7 +1810,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
|
|
||||||
// Common params for update operations
|
// Common params for update operations
|
||||||
updateParams := authNodeUpdateParams{
|
updateParams := authNodeUpdateParams{
|
||||||
RegEntry: regEntry,
|
RegData: regData,
|
||||||
ValidHostinfo: hostinfo,
|
ValidHostinfo: hostinfo,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
User: user,
|
User: user,
|
||||||
@@ -1874,7 +1845,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
Msg("Creating new node for different user (same machine key exists for another user)")
|
Msg("Creating new node for different user (same machine key exists for another user)")
|
||||||
|
|
||||||
finalNode, err = s.createNewNodeFromAuth(
|
finalNode, err = s.createNewNodeFromAuth(
|
||||||
logger, user, regEntry, hostname, hostinfo,
|
logger, user, regData, hostname, hostinfo,
|
||||||
expiry, registrationMethod, existingNodeAnyUser,
|
expiry, registrationMethod, existingNodeAnyUser,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1882,7 +1853,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
finalNode, err = s.createNewNodeFromAuth(
|
finalNode, err = s.createNewNodeFromAuth(
|
||||||
logger, user, regEntry, hostname, hostinfo,
|
logger, user, regData, hostname, hostinfo,
|
||||||
expiry, registrationMethod, types.NodeView{},
|
expiry, registrationMethod, types.NodeView{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1893,8 +1864,8 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
// Signal to waiting clients
|
// Signal to waiting clients
|
||||||
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
|
regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
|
||||||
|
|
||||||
// Delete from registration cache
|
// Remove from registration cache
|
||||||
s.authCache.Delete(authID)
|
s.authCache.Remove(authID)
|
||||||
|
|
||||||
// Update policy managers
|
// Update policy managers
|
||||||
usersChange, err := s.updatePolicyManagerUsers()
|
usersChange, err := s.updatePolicyManagerUsers()
|
||||||
@@ -1923,7 +1894,7 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
func (s *State) createNewNodeFromAuth(
|
func (s *State) createNewNodeFromAuth(
|
||||||
logger zerolog.Logger,
|
logger zerolog.Logger,
|
||||||
user *types.User,
|
user *types.User,
|
||||||
regEntry *types.AuthRequest,
|
regData *types.RegistrationData,
|
||||||
hostname string,
|
hostname string,
|
||||||
validHostinfo *tailcfg.Hostinfo,
|
validHostinfo *tailcfg.Hostinfo,
|
||||||
expiry *time.Time,
|
expiry *time.Time,
|
||||||
@@ -1936,13 +1907,13 @@ func (s *State) createNewNodeFromAuth(
|
|||||||
|
|
||||||
return s.createAndSaveNewNode(newNodeParams{
|
return s.createAndSaveNewNode(newNodeParams{
|
||||||
User: *user,
|
User: *user,
|
||||||
MachineKey: regEntry.Node().MachineKey(),
|
MachineKey: regData.MachineKey,
|
||||||
NodeKey: regEntry.Node().NodeKey(),
|
NodeKey: regData.NodeKey,
|
||||||
DiscoKey: regEntry.Node().DiscoKey(),
|
DiscoKey: regData.DiscoKey,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
Hostinfo: validHostinfo,
|
Hostinfo: validHostinfo,
|
||||||
Endpoints: regEntry.Node().Endpoints().AsSlice(),
|
Endpoints: regData.Endpoints,
|
||||||
Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
|
Expiry: cmp.Or(expiry, regData.Expiry),
|
||||||
RegisterMethod: registrationMethod,
|
RegisterMethod: registrationMethod,
|
||||||
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
ExistingNodeForNetinfo: existingNodeForNetinfo,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -221,40 +221,65 @@ func (r AuthID) Validate() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRequest represent a pending authentication request from a user or a node.
|
// AuthRequest represents a pending authentication request from a user or a
|
||||||
// If it is a registration request, the node field will be populate with the node that is trying to register.
|
// node. It carries the minimum data needed to either complete a node
|
||||||
// When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
|
// registration (regData populated) or signal the verdict of an interactive
|
||||||
// The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
|
// auth flow (no payload). Verdict delivery is via the finished channel; the
|
||||||
|
// closed flag guards FinishAuth against double-close.
|
||||||
|
//
|
||||||
|
// AuthRequest is always handled by pointer so the channel and atomic flag
|
||||||
|
// have a single canonical instance even when stored in caches that
|
||||||
|
// internally copy values.
|
||||||
type AuthRequest struct {
|
type AuthRequest struct {
|
||||||
node *Node
|
// regData is populated for node-registration flows (interactive web
|
||||||
|
// or OIDC). It carries only the minimal subset of registration data
|
||||||
|
// the auth callback needs to promote this request into a real node;
|
||||||
|
// see RegistrationData for the rationale behind keeping the payload
|
||||||
|
// small.
|
||||||
|
//
|
||||||
|
// nil for non-registration flows (e.g. SSH check). Use
|
||||||
|
// RegistrationData() to read it safely.
|
||||||
|
regData *RegistrationData
|
||||||
|
|
||||||
finished chan AuthVerdict
|
finished chan AuthVerdict
|
||||||
closed *atomic.Bool
|
closed *atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthRequest() AuthRequest {
|
// NewAuthRequest creates a pending auth request with no payload, suitable
|
||||||
return AuthRequest{
|
// for non-registration flows that only need a verdict channel.
|
||||||
|
func NewAuthRequest() *AuthRequest {
|
||||||
|
return &AuthRequest{
|
||||||
finished: make(chan AuthVerdict, 1),
|
finished: make(chan AuthVerdict, 1),
|
||||||
closed: &atomic.Bool{},
|
closed: &atomic.Bool{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRegisterAuthRequest(node Node) AuthRequest {
|
// NewRegisterAuthRequest creates a pending auth request carrying the
|
||||||
return AuthRequest{
|
// minimal RegistrationData for a node-registration flow. The data is
|
||||||
node: &node,
|
// stored by pointer; callers must not mutate it after handing it off.
|
||||||
|
func NewRegisterAuthRequest(data *RegistrationData) *AuthRequest {
|
||||||
|
return &AuthRequest{
|
||||||
|
regData: data,
|
||||||
finished: make(chan AuthVerdict, 1),
|
finished: make(chan AuthVerdict, 1),
|
||||||
closed: &atomic.Bool{},
|
closed: &atomic.Bool{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Node returns the node that is trying to register.
|
// RegistrationData returns the cached registration payload. It panics if
|
||||||
// It will panic if the AuthRequest is not a registration request.
|
// called on an AuthRequest that was not created via
|
||||||
// Can _only_ be used in the registration path.
|
// NewRegisterAuthRequest, mirroring the previous Node() contract.
|
||||||
func (rn *AuthRequest) Node() NodeView {
|
func (rn *AuthRequest) RegistrationData() *RegistrationData {
|
||||||
if rn.node == nil {
|
if rn.regData == nil {
|
||||||
panic("Node can only be used in registration requests")
|
panic("RegistrationData can only be used in registration requests")
|
||||||
}
|
}
|
||||||
|
|
||||||
return rn.node.View()
|
return rn.regData
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRegistration reports whether this auth request carries registration
|
||||||
|
// data (i.e. it was created via NewRegisterAuthRequest).
|
||||||
|
func (rn *AuthRequest) IsRegistration() bool {
|
||||||
|
return rn.regData != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
|
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
|
||||||
|
|||||||
@@ -278,14 +278,16 @@ type Tuning struct {
|
|||||||
// updates for connected clients.
|
// updates for connected clients.
|
||||||
BatcherWorkers int
|
BatcherWorkers int
|
||||||
|
|
||||||
// RegisterCacheCleanup is the interval between cleanup operations for
|
|
||||||
// expired registration cache entries.
|
|
||||||
RegisterCacheCleanup time.Duration
|
|
||||||
|
|
||||||
// RegisterCacheExpiration is how long registration cache entries remain
|
// RegisterCacheExpiration is how long registration cache entries remain
|
||||||
// valid before being eligible for cleanup.
|
// valid before being eligible for eviction.
|
||||||
RegisterCacheExpiration time.Duration
|
RegisterCacheExpiration time.Duration
|
||||||
|
|
||||||
|
// RegisterCacheMaxEntries bounds the number of pending registration
|
||||||
|
// entries the auth cache will hold. Older entries are evicted (LRU)
|
||||||
|
// when the cap is reached, preventing unauthenticated cache-fill DoS.
|
||||||
|
// A value of 0 falls back to defaultRegisterCacheMaxEntries (1024).
|
||||||
|
RegisterCacheMaxEntries int
|
||||||
|
|
||||||
// NodeStoreBatchSize controls how many write operations are accumulated
|
// NodeStoreBatchSize controls how many write operations are accumulated
|
||||||
// before rebuilding the in-memory node snapshot.
|
// before rebuilding the in-memory node snapshot.
|
||||||
//
|
//
|
||||||
@@ -1192,8 +1194,8 @@ func LoadServerConfig() (*Config, error) {
|
|||||||
|
|
||||||
return DefaultBatcherWorkers()
|
return DefaultBatcherWorkers()
|
||||||
}(),
|
}(),
|
||||||
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
|
||||||
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
|
RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"),
|
||||||
|
RegisterCacheMaxEntries: viper.GetInt("tuning.register_cache_max_entries"),
|
||||||
NodeStoreBatchSize: viper.GetInt("tuning.node_store_batch_size"),
|
NodeStoreBatchSize: viper.GetInt("tuning.node_store_batch_size"),
|
||||||
NodeStoreBatchTimeout: viper.GetDuration("tuning.node_store_batch_timeout"),
|
NodeStoreBatchTimeout: viper.GetDuration("tuning.node_store_batch_timeout"),
|
||||||
},
|
},
|
||||||
|
|||||||
55
hscontrol/types/registration.go
Normal file
55
hscontrol/types/registration.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegistrationData is the payload cached for a pending node registration.
|
||||||
|
// It replaces the previous practice of caching a full *Node and carries
|
||||||
|
// only the fields the registration callback path actually consumes when
|
||||||
|
// promoting a pending registration to a real node.
|
||||||
|
//
|
||||||
|
// Combined with the bounded-LRU cache that holds these entries, this caps
|
||||||
|
// the worst-case memory footprint of unauthenticated cache-fill attempts
|
||||||
|
// at (max_entries × per_entry_size). The cache is sized so that the
|
||||||
|
// product is bounded to a few MiB even with attacker-supplied 1 MiB
|
||||||
|
// Hostinfos (the Noise body limit).
|
||||||
|
type RegistrationData struct {
|
||||||
|
// MachineKey is the cryptographic identity of the machine being
|
||||||
|
// registered. Required.
|
||||||
|
MachineKey key.MachinePublic
|
||||||
|
|
||||||
|
// NodeKey is the cryptographic identity of the node session.
|
||||||
|
// Required.
|
||||||
|
NodeKey key.NodePublic
|
||||||
|
|
||||||
|
// DiscoKey is the disco public key for peer-to-peer connections.
|
||||||
|
DiscoKey key.DiscoPublic
|
||||||
|
|
||||||
|
// Hostname is the resolved hostname for the registering node.
|
||||||
|
// Already validated/normalised by EnsureHostname at producer time.
|
||||||
|
Hostname string
|
||||||
|
|
||||||
|
// Hostinfo is the original Hostinfo from the RegisterRequest,
|
||||||
|
// stored so that the auth callback can populate the new node's
|
||||||
|
// initial Hostinfo (and so that observability/CLI consumers see
|
||||||
|
// fields like OS, OSVersion, and IPNVersion before the first
|
||||||
|
// MapRequest restores the live set).
|
||||||
|
//
|
||||||
|
// May be nil if the client did not send Hostinfo in the original
|
||||||
|
// RegisterRequest.
|
||||||
|
Hostinfo *tailcfg.Hostinfo
|
||||||
|
|
||||||
|
// Endpoints is the initial set of WireGuard endpoints the node
|
||||||
|
// reported. The first MapRequest after registration overwrites
|
||||||
|
// this with the live set.
|
||||||
|
Endpoints []netip.AddrPort
|
||||||
|
|
||||||
|
// Expiry is the optional client-requested expiry for this node.
|
||||||
|
// May be nil if the client did not request a specific expiry.
|
||||||
|
Expiry *time.Time
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user