auth: generalise auth flow and introduce AuthVerdict

Generalise the registration pipeline to a more general auth pipeline
supporting both node registrations and SSH check auth requests.
Rename RegistrationID to AuthID, unexport AuthRequest fields, and
introduce AuthVerdict to unify the auth finish API.

Add the urlParam generic helper for extracting typed URL parameters
from chi routes, used by the new auth request handler.

Updates #1850
This commit is contained in:
Kristoffer Dalby
2026-02-24 18:48:57 +00:00
parent 30338441c1
commit cb3b6949ea
19 changed files with 443 additions and 336 deletions

View File

@@ -37,7 +37,7 @@ var createNodeCmd = &cobra.Command{
name, _ := cmd.Flags().GetString("name") name, _ := cmd.Flags().GetString("name")
registrationID, _ := cmd.Flags().GetString("key") registrationID, _ := cmd.Flags().GetString("key")
_, err := types.RegistrationIDFromString(registrationID) _, err := types.AuthIDFromString(registrationID)
if err != nil { if err != nil {
return fmt.Errorf("parsing machine key: %w", err) return fmt.Errorf("parsing machine key: %w", err)
} }

View File

@@ -479,7 +479,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *chi.Mux {
r.Get("/health", h.HealthHandler) r.Get("/health", h.HealthHandler)
r.Get("/version", h.VersionHandler) r.Get("/version", h.VersionHandler)
r.Get("/key", h.KeyHandler) r.Get("/key", h.KeyHandler)
r.Get("/register/{registration_id}", h.authProvider.RegisterHandler) r.Get("/register/{auth_id}", h.authProvider.RegisterHandler)
r.Get("/auth/{auth_id}", h.authProvider.AuthHandler)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
r.Get("/oidc/callback", provider.OIDCCallbackHandler) r.Get("/oidc/callback", provider.OIDCCallbackHandler)

View File

@@ -20,7 +20,9 @@ import (
type AuthProvider interface { type AuthProvider interface {
RegisterHandler(w http.ResponseWriter, r *http.Request) RegisterHandler(w http.ResponseWriter, r *http.Request)
AuthURL(regID types.RegistrationID) string AuthHandler(w http.ResponseWriter, r *http.Request)
RegisterURL(authID types.AuthID) string
AuthURL(authID types.AuthID) string
} }
func (h *Headscale) handleRegister( func (h *Headscale) handleRegister(
@@ -263,22 +265,24 @@ func (h *Headscale) waitForFollowup(
return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err)
} }
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) followupReg, err := types.AuthIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
if err != nil { if err != nil {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err) return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err)
} }
if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok { if reg, ok := h.state.GetAuthCacheEntry(followupReg); ok {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err)
case node := <-reg.Registered: case verdict := <-reg.WaitForAuth():
if node == nil { if verdict.Accept() {
// registration is expired in the cache, instruct the client to try a new registration if !verdict.Node.Valid() {
return h.reqToNewRegisterResponse(req, machineKey) // registration is expired in the cache, instruct the client to try a new registration
} return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil return nodeToRegisterResponse(verdict.Node), nil
}
} }
} }
@@ -293,14 +297,14 @@ func (h *Headscale) reqToNewRegisterResponse(
req tailcfg.RegisterRequest, req tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
newRegID, err := types.NewRegistrationID() newAuthID, err := types.NewAuthID()
if err != nil { if err != nil {
return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err)
} }
// Ensure we have a valid hostname // Ensure we have a valid hostname
hostname := util.EnsureHostname( hostname := util.EnsureHostname(
req.Hostinfo, req.Hostinfo.View(),
machineKey.String(), machineKey.String(),
req.NodeKey.String(), req.NodeKey.String(),
) )
@@ -309,25 +313,25 @@ func (h *Headscale) reqToNewRegisterResponse(
hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{})
hostinfo.Hostname = hostname hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode( nodeToRegister := types.Node{
types.Node{ Hostname: hostname,
Hostname: hostname, MachineKey: machineKey,
MachineKey: machineKey, NodeKey: req.NodeKey,
NodeKey: req.NodeKey, Hostinfo: hostinfo,
Hostinfo: hostinfo, LastSeen: new(time.Now()),
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
} }
log.Info().Msgf("new followup node registration using key: %s", newRegID) if !req.Expiry.IsZero() {
h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) nodeToRegister.Expiry = &req.Expiry
}
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
log.Info().Msgf("new followup node registration using key: %s", newAuthID)
h.state.SetAuthCacheEntry(newAuthID, authRegReq)
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(newRegID), AuthURL: h.authProvider.RegisterURL(newAuthID),
}, nil }, nil
} }
@@ -378,13 +382,6 @@ func (h *Headscale) handleRegisterWithAuthKey(
// Send both changes. Empty changes are ignored by Change(). // Send both changes. Empty changes are ignored by Change().
h.Change(changed, routesChange) h.Change(changed, routesChange)
// TODO(kradalby): I think this is covered above, but we need to validate that.
// // If policy changed due to node registration, send a separate policy change
// if policyChanged {
// policyChange := change.PolicyChange()
// h.Change(policyChange)
// }
resp := &tailcfg.RegisterResponse{ resp := &tailcfg.RegisterResponse{
MachineAuthorized: true, MachineAuthorized: true,
NodeKeyExpired: node.IsExpired(), NodeKeyExpired: node.IsExpired(),
@@ -406,14 +403,14 @@ func (h *Headscale) handleRegisterInteractive(
req tailcfg.RegisterRequest, req tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) { ) (*tailcfg.RegisterResponse, error) {
registrationId, err := types.NewRegistrationID() authID, err := types.NewAuthID()
if err != nil { if err != nil {
return nil, fmt.Errorf("generating registration ID: %w", err) return nil, fmt.Errorf("generating registration ID: %w", err)
} }
// Ensure we have a valid hostname // Ensure we have a valid hostname
hostname := util.EnsureHostname( hostname := util.EnsureHostname(
req.Hostinfo, req.Hostinfo.View(),
machineKey.String(), machineKey.String(),
req.NodeKey.String(), req.NodeKey.String(),
) )
@@ -436,28 +433,28 @@ func (h *Headscale) handleRegisterInteractive(
hostinfo.Hostname = hostname hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode( nodeToRegister := types.Node{
types.Node{ Hostname: hostname,
Hostname: hostname, MachineKey: machineKey,
MachineKey: machineKey, NodeKey: req.NodeKey,
NodeKey: req.NodeKey, Hostinfo: hostinfo,
Hostinfo: hostinfo, LastSeen: new(time.Now()),
LastSeen: new(time.Now()),
},
)
if !req.Expiry.IsZero() {
nodeToRegister.Node.Expiry = &req.Expiry
} }
h.state.SetRegistrationCacheEntry( if !req.Expiry.IsZero() {
registrationId, nodeToRegister.Expiry = &req.Expiry
nodeToRegister, }
authRegReq := types.NewRegisterAuthRequest(nodeToRegister)
h.state.SetAuthCacheEntry(
authID,
authRegReq,
) )
log.Info().Msgf("starting node registration using key: %s", registrationId) log.Info().Msgf("starting node registration using key: %s", authID)
return &tailcfg.RegisterResponse{ return &tailcfg.RegisterResponse{
AuthURL: h.authProvider.AuthURL(registrationId), AuthURL: h.authProvider.RegisterURL(authID),
}, nil }, nil
} }

View File

@@ -651,8 +651,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 1: Create user-owned node WITH expiry set // Step 1: Create user-owned node WITH expiry set
clientExpiry := time.Now().Add(24 * time.Hour) clientExpiry := time.Now().Add(24 * time.Hour)
registrationID1 := types.MustRegistrationID() registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterNode(types.Node{ regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(), NodeKey: nodeKey1.Public(),
Hostname: "personal-to-tagged", Hostname: "personal-to-tagged",
@@ -662,7 +662,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
}, },
Expiry: &clientExpiry, Expiry: &clientExpiry,
}) })
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) app.state.SetAuthCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath( node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth", registrationID1, types.UserID(user.ID), nil, "webauth",
@@ -673,8 +673,8 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
// Step 2: Re-auth with tags (Personal → Tagged conversion) // Step 2: Re-auth with tags (Personal → Tagged conversion)
nodeKey2 := key.NewNode() nodeKey2 := key.NewNode()
registrationID2 := types.MustRegistrationID() registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterNode(types.Node{ regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(), NodeKey: nodeKey2.Public(),
Hostname: "personal-to-tagged", Hostname: "personal-to-tagged",
@@ -684,7 +684,7 @@ func TestExpiryDuringPersonalToTaggedConversion(t *testing.T) {
}, },
Expiry: &clientExpiry, // Client still sends expiry Expiry: &clientExpiry, // Client still sends expiry
}) })
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) app.state.SetAuthCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath( nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth", registrationID2, types.UserID(user.ID), nil, "webauth",
@@ -723,8 +723,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
nodeKey1 := key.NewNode() nodeKey1 := key.NewNode()
// Step 1: Create tagged node (expiry should be nil) // Step 1: Create tagged node (expiry should be nil)
registrationID1 := types.MustRegistrationID() registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterNode(types.Node{ regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(), NodeKey: nodeKey1.Public(),
Hostname: "tagged-to-personal", Hostname: "tagged-to-personal",
@@ -733,7 +733,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
RequestTags: []string{"tag:server"}, // Tagged node RequestTags: []string{"tag:server"}, // Tagged node
}, },
}) })
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) app.state.SetAuthCacheEntry(registrationID1, regEntry1)
node, _, err := app.state.HandleNodeFromAuthPath( node, _, err := app.state.HandleNodeFromAuthPath(
registrationID1, types.UserID(user.ID), nil, "webauth", registrationID1, types.UserID(user.ID), nil, "webauth",
@@ -745,8 +745,8 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
// Step 2: Re-auth with empty tags (Tagged → Personal conversion) // Step 2: Re-auth with empty tags (Tagged → Personal conversion)
nodeKey2 := key.NewNode() nodeKey2 := key.NewNode()
clientExpiry := time.Now().Add(48 * time.Hour) clientExpiry := time.Now().Add(48 * time.Hour)
registrationID2 := types.MustRegistrationID() registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterNode(types.Node{ regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey2.Public(), NodeKey: nodeKey2.Public(),
Hostname: "tagged-to-personal", Hostname: "tagged-to-personal",
@@ -756,7 +756,7 @@ func TestExpiryDuringTaggedToPersonalConversion(t *testing.T) {
}, },
Expiry: &clientExpiry, // Client requests expiry Expiry: &clientExpiry, // Client requests expiry
}) })
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) app.state.SetAuthCacheEntry(registrationID2, regEntry2)
nodeAfter, _, err := app.state.HandleNodeFromAuthPath( nodeAfter, _, err := app.state.HandleNodeFromAuthPath(
registrationID2, types.UserID(user.ID), nil, "webauth", registrationID2, types.UserID(user.ID), nil, "webauth",

View File

@@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) {
{ {
name: "followup_registration_success", name: "followup_registration_success",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewRegistrationID() regID, err := types.NewAuthID()
if err != nil { if err != nil {
return "", err return "", err
} }
registered := make(chan *types.Node, 1) nodeToRegister := types.NewRegisterAuthRequest(types.Node{
nodeToRegister := types.RegisterNode{ Hostname: "followup-success-node",
Node: types.Node{ })
Hostname: "followup-success-node", app.state.SetAuthCacheEntry(regID, nodeToRegister)
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Simulate successful registration - send to buffered channel // Simulate successful registration
// The channel is buffered (size 1), so this can complete immediately // handleRegister will receive the value when it starts waiting
// and handleRegister will receive the value when it starts waiting
go func() { go func() {
user := app.state.CreateUserForTest("followup-user") user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node") node := app.state.CreateNodeForTest(user, "followup-success-node")
registered <- node nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()})
}() }()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) {
{ {
name: "followup_registration_timeout", name: "followup_registration_timeout",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
regID, err := types.NewRegistrationID() regID, err := types.NewAuthID()
if err != nil { if err != nil {
return "", err return "", err
} }
registered := make(chan *types.Node, 1) nodeToRegister := types.NewRegisterAuthRequest(types.Node{
nodeToRegister := types.RegisterNode{ Hostname: "followup-timeout-node",
Node: types.Node{ })
Hostname: "followup-timeout-node", app.state.SetAuthCacheEntry(regID, nodeToRegister)
}, // Don't call FinishRegistration - will timeout
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Don't send anything on channel - will timeout
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
}, },
@@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) {
{ {
name: "followup_registration_node_nil_response", name: "followup_registration_node_nil_response",
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
regID, err := types.NewRegistrationID() regID, err := types.NewAuthID()
if err != nil { if err != nil {
return "", err return "", err
} }
registered := make(chan *types.Node, 1) nodeToRegister := types.NewRegisterAuthRequest(types.Node{
nodeToRegister := types.RegisterNode{ Hostname: "nil-response-node",
Node: types.Node{ })
Hostname: "nil-response-node", app.state.SetAuthCacheEntry(regID, nodeToRegister)
},
Registered: registered,
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Simulate registration that returns nil (cache expired during auth) // Simulate registration that returns empty NodeView (cache expired during auth)
// The channel is buffered (size 1), so this can complete immediately
go func() { go func() {
registered <- nil // Nil indicates cache expiry nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry
}() }()
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
@@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
// Generate a registration ID that doesn't exist in cache // Generate a registration ID that doesn't exist in cache
// This simulates an expired/missing cache entry // This simulates an expired/missing cache entry
regID, err := types.NewRegistrationID() regID, err := types.NewAuthID()
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) {
// Extract and validate the new registration ID exists in cache // Extract and validate the new registration ID exists in cache
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
newRegID, err := types.RegistrationIDFromString(newRegIDStr) newRegID, err := types.AuthIDFromString(newRegIDStr)
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
// Verify new registration entry exists in cache // Verify new registration entry exists in cache
_, found := app.state.GetRegistrationCacheEntry(newRegID) _, found := app.state.GetAuthCacheEntry(newRegID)
assert.True(t, found, "new registration should exist in cache") assert.True(t, found, "new registration should exist in cache")
}, },
}, },
@@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify cache entry exists // Verify cache entry exists
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
assert.True(t, found, "registration cache entry should exist initially") assert.True(t, found, "registration cache entry should exist initially")
assert.NotNil(t, cacheEntry) assert.NotNil(t, cacheEntry)
@@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
// Cache entry should still exist after auth error (for retry scenarios) // Cache entry should still exist after auth error (for retry scenarios)
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID) _, stillFound := app.state.GetAuthCacheEntry(registrationID)
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry") assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
}, },
}, },
@@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) {
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
// Both cache entries should exist simultaneously // Both cache entries should exist simultaneously
_, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2) _, found2 := app.state.GetAuthCacheEntry(regID2)
assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found1, "first registration cache entry should exist")
assert.True(t, found2, "second registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist")
@@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify both exist // Verify both exist
_, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found1 := app.state.GetAuthCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2) _, found2 := app.state.GetAuthCacheEntry(regID2)
assert.True(t, found1, "first cache entry should exist") assert.True(t, found1, "first cache entry should exist")
assert.True(t, found2, "second cache entry should exist") assert.True(t, found2, "second cache entry should exist")
@@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) {
} }
// First registration should still be in cache (not completed) // First registration should still be in cache (not completed)
_, stillFound := app.state.GetRegistrationCacheEntry(regID1) _, stillFound := app.state.GetAuthCacheEntry(regID1)
assert.True(t, stillFound, "first registration should still be pending") assert.True(t, stillFound, "first registration should still be pending")
}, },
}, },
@@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
var ( var (
initialResp *tailcfg.RegisterResponse initialResp *tailcfg.RegisterResponse
authURL string authURL string
registrationID types.RegistrationID registrationID types.AuthID
finalResp *tailcfg.RegisterResponse finalResp *tailcfg.RegisterResponse
err error err error
) )
@@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
if step.expectCacheEntry { if step.expectCacheEntry {
// Verify registration cache entry was created // Verify registration cache entry was created
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
require.True(t, found, "registration cache entry should exist") require.True(t, found, "registration cache entry should exist")
require.NotNil(t, cacheEntry, "cache entry should not be nil") require.NotNil(t, cacheEntry, "cache entry should not be nil")
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
} }
case stepTypeAuthCompletion: case stepTypeAuthCompletion:
@@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
// Check cache cleanup expectation for this step // Check cache cleanup expectation for this step
if step.expectCacheEntry == false && registrationID != "" { if step.expectCacheEntry == false && registrationID != "" {
// Verify cache entry was cleaned up // Verify cache entry was cleaned up
_, found := app.state.GetRegistrationCacheEntry(registrationID) _, found := app.state.GetAuthCacheEntry(registrationID)
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
} }
} }
@@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
} }
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. // extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
// AuthURL format: "http://localhost/register/abc123" // AuthURL format: "http://localhost/register/abc123"
const registerPrefix = "/register/" const registerPrefix = "/register/"
@@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err
idStr := authURL[idx+len(registerPrefix):] idStr := authURL[idx+len(registerPrefix):]
return types.RegistrationIDFromString(idStr) return types.AuthIDFromString(idStr)
} }
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response. // validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
@@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
// Simulate a registration cache entry (as would be created during web auth) // Simulate a registration cache entry (as would be created during web auth)
registrationID := types.MustRegistrationID() registrationID := types.MustAuthID()
regEntry := types.NewRegisterNode(types.Node{ regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostname: "webauth-tags-node", Hostname: "webauth-tags-node",
@@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
}, },
}) })
app.state.SetRegistrationCacheEntry(registrationID, regEntry) app.state.SetAuthCacheEntry(registrationID, regEntry)
// Complete the web auth - should fail because tag is unauthorized // Complete the web auth - should fail because tag is unauthorized
_, _, err := app.state.HandleNodeFromAuthPath( _, _, err := app.state.HandleNodeFromAuthPath(
@@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
nodeKey1 := key.NewNode() nodeKey1 := key.NewNode()
// Step 1: Initial registration with tags // Step 1: Initial registration with tags
registrationID1 := types.MustRegistrationID() registrationID1 := types.MustAuthID()
regEntry1 := types.NewRegisterNode(types.Node{ regEntry1 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey1.Public(), NodeKey: nodeKey1.Public(),
Hostname: "reauth-untag-node", Hostname: "reauth-untag-node",
@@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{"tag:valid-owned", "tag:second"}, RequestTags: []string{"tag:valid-owned", "tag:second"},
}, },
}) })
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) app.state.SetAuthCacheEntry(registrationID1, regEntry1)
// Complete initial registration with tags // Complete initial registration with tags
node, _, err := app.state.HandleNodeFromAuthPath( node, _, err := app.state.HandleNodeFromAuthPath(
@@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
// Step 2: Reauth with EMPTY tags to untag // Step 2: Reauth with EMPTY tags to untag
nodeKey2 := key.NewNode() // New node key for reauth nodeKey2 := key.NewNode() // New node key for reauth
registrationID2 := types.MustRegistrationID() registrationID2 := types.MustAuthID()
regEntry2 := types.NewRegisterNode(types.Node{ regEntry2 := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation) NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "reauth-untag-node", Hostname: "reauth-untag-node",
@@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag RequestTags: []string{}, // EMPTY - should untag
}, },
}) })
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) app.state.SetAuthCacheEntry(registrationID2, regEntry2)
// Complete reauth with empty tags // Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
nodeKey2 := key.NewNode() // New node key for reauth nodeKey2 := key.NewNode() // New node key for reauth
registrationID := types.MustRegistrationID() registrationID := types.MustAuthID()
regEntry := types.NewRegisterNode(types.Node{ regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key MachineKey: machineKey.Public(), // Same machine key
NodeKey: nodeKey2.Public(), // Different node key (rotation) NodeKey: nodeKey2.Public(), // Different node key (rotation)
Hostname: "authkey-tagged-node", Hostname: "authkey-tagged-node",
@@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
RequestTags: []string{}, // EMPTY - should untag RequestTags: []string{}, // EMPTY - should untag
}, },
}) })
app.state.SetRegistrationCacheEntry(registrationID, regEntry) app.state.SetAuthCacheEntry(registrationID, regEntry)
// Complete reauth with empty tags // Complete reauth with empty tags
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
@@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath // Step 4: Re-register the node to alice via HandleNodeFromAuthPath
// This is what happens when running: headscale nodes register --user alice --key ... // This is what happens when running: headscale nodes register --user alice --key ...
nodeKey2 := key.NewNode() nodeKey2 := key.NewNode()
registrationID := types.MustRegistrationID() registrationID := types.MustAuthID()
regEntry := types.NewRegisterNode(types.Node{ regEntry := types.NewRegisterAuthRequest(types.Node{
MachineKey: machineKey.Public(), // Same machine key as the tagged node MachineKey: machineKey.Public(), // Same machine key as the tagged node
NodeKey: nodeKey2.Public(), NodeKey: nodeKey2.Public(),
Hostname: "tagged-orphan-node", Hostname: "tagged-orphan-node",
@@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
RequestTags: []string{}, // Empty - transition to user-owned RequestTags: []string{}, // Empty - transition to user-owned
}, },
}) })
app.state.SetRegistrationCacheEntry(registrationID, regEntry) app.state.SetAuthCacheEntry(registrationID, regEntry)
// This should NOT panic - before the fix, this would panic with: // This should NOT panic - before the fix, this would panic with:
// panic: runtime error: invalid memory address or nil pointer dereference // panic: runtime error: invalid memory address or nil pointer dereference

View File

@@ -47,7 +47,7 @@ const (
type HSDatabase struct { type HSDatabase struct {
DB *gorm.DB DB *gorm.DB
cfg *types.Config cfg *types.Config
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] regCache *zcache.Cache[types.AuthID, types.AuthRequest]
} }
// NewHeadscaleDatabase creates a new database connection and runs migrations. // NewHeadscaleDatabase creates a new database connection and runs migrations.
@@ -56,7 +56,7 @@ type HSDatabase struct {
//nolint:gocyclo // complex database initialization with many migrations //nolint:gocyclo // complex database initialization with many migrations
func NewHeadscaleDatabase( func NewHeadscaleDatabase(
cfg *types.Config, cfg *types.Config,
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], regCache *zcache.Cache[types.AuthID, types.AuthRequest],
) (*HSDatabase, error) { ) (*HSDatabase, error) {
dbConn, err := openDB(cfg.Database) dbConn, err := openDB(cfg.Database)
if err != nil { if err != nil {

View File

@@ -162,8 +162,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
} }
} }
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
} }
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {

View File

@@ -247,7 +247,7 @@ func (api headscaleV1APIServer) RegisterNode(
Str(zf.RegistrationKey, registrationKey). Str(zf.RegistrationKey, registrationKey).
Msg("registering node") Msg("registering node")
registrationId, err := types.RegistrationIDFromString(request.GetKey()) registrationId, err := types.AuthIDFromString(request.GetKey())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -808,33 +808,32 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: request.GetName(), Hostname: request.GetName(),
} }
registrationId, err := types.RegistrationIDFromString(request.GetKey()) registrationId, err := types.AuthIDFromString(request.GetKey())
if err != nil { if err != nil {
return nil, err return nil, err
} }
newNode := types.NewRegisterNode( newNode := types.Node{
types.Node{ NodeKey: key.NewNode().Public(),
NodeKey: key.NewNode().Public(), MachineKey: key.NewMachine().Public(),
MachineKey: key.NewMachine().Public(), Hostname: request.GetName(),
Hostname: request.GetName(), User: user,
User: user,
Expiry: &time.Time{}, Expiry: &time.Time{},
LastSeen: &time.Time{}, LastSeen: &time.Time{},
Hostinfo: &hostinfo, Hostinfo: &hostinfo,
}, }
)
log.Debug(). log.Debug().
Caller(). Caller().
Str("registration_id", registrationId.String()). Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache") Msg("adding debug machine via CLI, appending to registration cache")
api.h.state.SetRegistrationCacheEntry(registrationId, newNode) authRegReq := types.NewRegisterAuthRequest(newNode)
api.h.state.SetAuthCacheEntry(registrationId, authRegReq)
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
} }
func (api headscaleV1APIServer) Health( func (api headscaleV1APIServer) Health(

View File

@@ -11,7 +11,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/assets" "github.com/juanfont/headscale/hscontrol/assets"
"github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@@ -245,11 +244,41 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
} }
} }
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { func (a *AuthProviderWeb) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf( return fmt.Sprintf(
"%s/register/%s", "%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"), strings.TrimSuffix(a.serverURL, "/"),
registrationId.String()) authID.String())
}
func (a *AuthProviderWeb) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderWeb) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
}
func authIDFromRequest(req *http.Request) (types.AuthID, error) {
registrationId, err := urlParam[types.AuthID](req, "auth_id")
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
err = registrationId.Validate()
if err != nil {
return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err))
}
return registrationId, nil
} }
// RegisterHandler shows a simple message in the browser to point to the CLI // RegisterHandler shows a simple message in the browser to point to the CLI
@@ -261,15 +290,9 @@ func (a *AuthProviderWeb) RegisterHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
vars := mux.Vars(req) registrationId, err := authIDFromRequest(req)
registrationIdStr := vars["registration_id"]
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil { if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) httpError(writer, err)
return return
} }

View File

@@ -95,8 +95,8 @@ var allBatcherFunctions = []batcherTestCase{
} }
// emptyCache creates an empty registration cache for testing. // emptyCache creates an empty registration cache for testing.
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
} }
// Test configuration constants. // Test configuration constants.

View File

@@ -24,6 +24,12 @@ import (
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version. // ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
var ErrUnsupportedClientVersion = errors.New("unsupported client version") var ErrUnsupportedClientVersion = errors.New("unsupported client version")
// ErrMissingURLParameter is returned when a required URL parameter is not provided.
var ErrMissingURLParameter = errors.New("missing URL parameter")
// ErrUnsupportedURLParameterType is returned when a URL parameter has an unsupported type.
var ErrUnsupportedURLParameterType = errors.New("unsupported URL parameter type")
const ( const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021" ts2021UpgradePath = "/ts2021"
@@ -374,3 +380,28 @@ func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.
return nv, nil return nv, nil
} }
// urlParam extracts a typed URL parameter from a chi router request.
func urlParam[T any](req *http.Request, key string) (T, error) {
var zero T
param := chi.URLParam(req, key)
if param == "" {
return zero, fmt.Errorf("%w: %s", ErrMissingURLParameter, key)
}
var value T
switch any(value).(type) {
case string:
v, ok := any(param).(T)
if !ok {
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
value = v
default:
return zero, fmt.Errorf("%w: %T", ErrUnsupportedURLParameterType, value)
}
return value, nil
}

View File

@@ -12,7 +12,6 @@ import (
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@@ -26,8 +25,8 @@ import (
const ( const (
randomByteSize = 16 randomByteSize = 16
defaultOAuthOptionsCount = 3 defaultOAuthOptionsCount = 3
registerCacheExpiration = time.Minute * 15 authCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20 authCacheCleanup = time.Minute * 20
) )
var ( var (
@@ -44,17 +43,21 @@ var (
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email") errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
) )
// RegistrationInfo contains both machine key and verifier information for OIDC validation. // AuthInfo contains both auth ID and verifier information for OIDC validation.
type RegistrationInfo struct { type AuthInfo struct {
RegistrationID types.RegistrationID AuthID types.AuthID
Verifier *string Verifier *string
Registration bool
} }
type AuthProviderOIDC struct { type AuthProviderOIDC struct {
h *Headscale h *Headscale
serverURL string serverURL string
cfg *types.OIDCConfig cfg *types.OIDCConfig
registrationCache *zcache.Cache[string, RegistrationInfo]
// authCache holds auth information between
// the auth and the callback steps.
authCache *zcache.Cache[string, AuthInfo]
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
@@ -81,45 +84,63 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope, Scopes: cfg.Scope,
} }
registrationCache := zcache.New[string, RegistrationInfo]( authCache := zcache.New[string, AuthInfo](
registerCacheExpiration, authCacheExpiration,
registerCacheCleanup, authCacheCleanup,
) )
return &AuthProviderOIDC{ return &AuthProviderOIDC{
h: h, h: h,
serverURL: serverURL, serverURL: serverURL,
cfg: cfg, cfg: cfg,
registrationCache: registrationCache, authCache: authCache,
oidcProvider: oidcProvider, oidcProvider: oidcProvider,
oauth2Config: oauth2Config, oauth2Config: oauth2Config,
}, nil }, nil
} }
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { func (a *AuthProviderOIDC) AuthURL(authID types.AuthID) string {
return fmt.Sprintf(
"%s/auth/%s",
strings.TrimSuffix(a.serverURL, "/"),
authID.String())
}
func (a *AuthProviderOIDC) AuthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
a.authHandler(writer, req, false)
}
func (a *AuthProviderOIDC) RegisterURL(authID types.AuthID) string {
return fmt.Sprintf( return fmt.Sprintf(
"%s/register/%s", "%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"), strings.TrimSuffix(a.serverURL, "/"),
registrationID.String()) authID.String())
} }
// RegisterHandler registers the OIDC callback handler with the given router. // RegisterHandler registers the OIDC callback handler with the given router.
// It puts NodeKey in cache so the callback can retrieve it using the oidc state param. // It puts NodeKey in cache so the callback can retrieve it using the oidc state param.
// Listens in /register/:registration_id. // Listens in /register/:auth_id.
func (a *AuthProviderOIDC) RegisterHandler( func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
vars := mux.Vars(req) a.authHandler(writer, req, true)
registrationIdStr := vars["registration_id"] }
// We need to make sure we dont open for XSS style injections, if the parameter that // authHandler takes an incoming request that needs to be authenticated and
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render // validates and prepares it for the OIDC flow.
// the template and log an error. func (a *AuthProviderOIDC) authHandler(
registrationId, err := types.RegistrationIDFromString(registrationIdStr) writer http.ResponseWriter,
req *http.Request,
registration bool,
) {
authID, err := authIDFromRequest(req)
if err != nil { if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) httpError(writer, err)
return return
} }
@@ -137,9 +158,9 @@ func (a *AuthProviderOIDC) RegisterHandler(
return return
} }
// Initialize registration info with machine key registrationInfo := AuthInfo{
registrationInfo := RegistrationInfo{ AuthID: authID,
RegistrationID: registrationId, Registration: registration,
} }
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@@ -167,7 +188,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
extras = append(extras, oidc.Nonce(nonce)) extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info // Cache the registration info
a.registrationCache.Set(state, registrationInfo) a.authCache.Set(state, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(state, extras...) authURL := a.oauth2Config.AuthCodeURL(state, extras...)
log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL) log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)
@@ -302,16 +323,22 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// If the node exists, then the node should be reauthenticated, // If the node exists, then the node should be reauthenticated,
// if the node does not exist, and the machine key exists, then // if the node does not exist, and the machine key exists, then
// this is a new node that should be registered. // this is a new node that should be registered.
registrationId := a.getRegistrationIDFromState(state) authInfo := a.getAuthInfoFromState(state)
if authInfo == nil {
log.Debug().Caller().Str("state", state).Msg("state not found in cache, login session may have expired")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
// Register the node if it does not exist. return
if registrationId != nil { }
// If this is a registration flow, then we need to register the node.
if authInfo.Registration {
verb := "Reauthenticated" verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) newNode, err := a.handleRegistration(user, authInfo.AuthID, nodeExpiry)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") log.Debug().Caller().Str("registration_id", authInfo.AuthID.String()).Msg("registration session expired before authorization completed")
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err))
return return
@@ -339,9 +366,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return return
} }
// Neither node nor machine key was found in the state cache meaning // TODO(kradalby): handle login flow (without registration) if needed.
// that we could not reauth nor register the node. // We need to send an update here to whatever might be waiting for this auth flow.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
} }
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@@ -374,7 +400,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
var exchangeOpts []oauth2.AuthCodeOption var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled { if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state) regInfo, ok := a.authCache.Get(state)
if !ok { if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
} }
@@ -507,14 +533,14 @@ func doOIDCAuthorization(
return nil return nil
} }
// getRegistrationIDFromState retrieves the registration ID from the state. // getAuthInfoFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { func (a *AuthProviderOIDC) getAuthInfoFromState(state string) *AuthInfo {
regInfo, ok := a.registrationCache.Get(state) authInfo, ok := a.authCache.Get(state)
if !ok { if !ok {
return nil return nil
} }
return &regInfo.RegistrationID return &authInfo
} }
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
@@ -562,7 +588,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) handleRegistration( func (a *AuthProviderOIDC) handleRegistration(
user *types.User, user *types.User,
registrationID types.RegistrationID, registrationID types.AuthID,
expiry time.Time, expiry time.Time,
) (bool, error) { ) (bool, error) {
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(

View File

@@ -64,6 +64,9 @@ var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
// ErrNodeNameNotUnique is returned when a node name is not unique. // ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique") var ErrNodeNameNotUnique = errors.New("node name is not unique")
// ErrRegistrationExpired is returned when a registration has expired.
var ErrRegistrationExpired = errors.New("registration expired")
// State manages Headscale's core state, coordinating between database, policy management, // State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe. // IP allocation, and DERP routing. All methods are thread-safe.
type State struct { type State struct {
@@ -82,8 +85,10 @@ type State struct {
derpMap atomic.Pointer[tailcfg.DERPMap] derpMap atomic.Pointer[tailcfg.DERPMap]
// polMan handles policy evaluation and management // polMan handles policy evaluation and management
polMan policy.PolicyManager polMan policy.PolicyManager
// registrationCache caches node registration data to reduce database load
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] // authCache caches any pending authentication requests, from either auth type (Web and OIDC).
authCache *zcache.Cache[types.AuthID, types.AuthRequest]
// primaryRoutes tracks primary route assignments for nodes // primaryRoutes tracks primary route assignments for nodes
primaryRoutes *routes.PrimaryRoutes primaryRoutes *routes.PrimaryRoutes
} }
@@ -101,20 +106,20 @@ func NewState(cfg *types.Config) (*State, error) {
cacheCleanup = cfg.Tuning.RegisterCacheCleanup cacheCleanup = cfg.Tuning.RegisterCacheCleanup
} }
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( authCache := zcache.New[types.AuthID, types.AuthRequest](
cacheExpiration, cacheExpiration,
cacheCleanup, cacheCleanup,
) )
registrationCache.OnEvicted( authCache.OnEvicted(
func(id types.RegistrationID, rn types.RegisterNode) { func(id types.AuthID, rn types.AuthRequest) {
rn.SendAndClose(nil) rn.FinishAuth(types.AuthVerdict{Err: ErrRegistrationExpired})
}, },
) )
db, err := hsdb.NewHeadscaleDatabase( db, err := hsdb.NewHeadscaleDatabase(
cfg, cfg,
registrationCache, authCache,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("initializing database: %w", err) return nil, fmt.Errorf("initializing database: %w", err)
@@ -178,12 +183,12 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{ return &State{
cfg: cfg, cfg: cfg,
db: db, db: db,
ipAlloc: ipAlloc, ipAlloc: ipAlloc,
polMan: polMan, polMan: polMan,
registrationCache: registrationCache, authCache: authCache,
primaryRoutes: routes.New(), primaryRoutes: routes.New(),
nodeStore: nodeStore, nodeStore: nodeStore,
}, nil }, nil
} }
@@ -1057,9 +1062,9 @@ func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id) return s.db.DeletePreAuthKey(id)
} }
// GetRegistrationCacheEntry retrieves a node registration from cache. // GetAuthCacheEntry retrieves a node registration from cache.
func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) { func (s *State) GetAuthCacheEntry(id types.AuthID) (*types.AuthRequest, bool) {
entry, found := s.registrationCache.Get(id) entry, found := s.authCache.Get(id)
if !found { if !found {
return nil, false return nil, false
} }
@@ -1067,26 +1072,24 @@ func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.Regis
return &entry, true return &entry, true
} }
// SetRegistrationCacheEntry stores a node registration in cache. // SetAuthCacheEntry stores a node registration in cache.
func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) { func (s *State) SetAuthCacheEntry(id types.AuthID, entry types.AuthRequest) {
s.registrationCache.Set(id, entry) s.authCache.Set(id, entry)
} }
// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. // logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname.
func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { func logHostinfoValidation(nv types.NodeView, username, hostname string) {
if hostinfo == nil { if !nv.Hostinfo().Valid() {
log.Warn(). log.Warn().
Caller(). Caller().
Str(zf.MachineKey, machineKey). EmbedObject(nv).
Str(zf.NodeKey, nodeKey).
Str(zf.UserName, username). Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname). Str(zf.GeneratedHostname, hostname).
Msg("Registration had nil hostinfo, generated default hostname") Msg("Registration had nil hostinfo, generated default hostname")
} else if hostinfo.Hostname == "" { } else if nv.Hostinfo().Hostname() == "" {
log.Warn(). log.Warn().
Caller(). Caller().
Str(zf.MachineKey, machineKey). EmbedObject(nv).
Str(zf.NodeKey, nodeKey).
Str(zf.UserName, username). Str(zf.UserName, username).
Str(zf.GeneratedHostname, hostname). Str(zf.GeneratedHostname, hostname).
Msg("Registration had empty hostname, generated default") Msg("Registration had empty hostname, generated default")
@@ -1128,7 +1131,7 @@ type authNodeUpdateParams struct {
// Node to update; must be valid and in NodeStore. // Node to update; must be valid and in NodeStore.
ExistingNode types.NodeView ExistingNode types.NodeView
// Client data: keys, hostinfo, endpoints. // Client data: keys, hostinfo, endpoints.
RegEntry *types.RegisterNode RegEntry *types.AuthRequest
// Pre-validated hostinfo; NetInfo preserved from ExistingNode. // Pre-validated hostinfo; NetInfo preserved from ExistingNode.
ValidHostinfo *tailcfg.Hostinfo ValidHostinfo *tailcfg.Hostinfo
// Hostname from hostinfo, or generated from keys if client omits it. // Hostname from hostinfo, or generated from keys if client omits it.
@@ -1147,6 +1150,7 @@ type authNodeUpdateParams struct {
// an existing node. It updates the node in NodeStore, processes RequestTags, and // an existing node. It updates the node in NodeStore, processes RequestTags, and
// persists changes to the database. // persists changes to the database.
func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) { func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView, error) {
regNv := params.RegEntry.Node()
// Log the operation type // Log the operation type
if params.IsConvertFromTag { if params.IsConvertFromTag {
log.Info(). log.Info().
@@ -1155,16 +1159,16 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
Msg("Converting tagged node to user-owned node") Msg("Converting tagged node to user-owned node")
} else { } else {
log.Info(). log.Info().
EmbedObject(params.ExistingNode). Object("existing", params.ExistingNode).
Interface("hostinfo", params.RegEntry.Node.Hostinfo). Object("incoming", regNv).
Msg("Updating existing node registration via reauth") Msg("Updating existing node registration via reauth")
} }
// Process RequestTags during reauth (#2979) // Process RequestTags during reauth (#2979)
// Due to json:",omitempty", we treat empty/nil as "clear tags" // Due to json:",omitempty", we treat empty/nil as "clear tags"
var requestTags []string var requestTags []string
if params.RegEntry.Node.Hostinfo != nil { if regNv.Hostinfo().Valid() {
requestTags = params.RegEntry.Node.Hostinfo.RequestTags requestTags = regNv.Hostinfo().RequestTags().AsSlice()
} }
oldTags := params.ExistingNode.Tags().AsSlice() oldTags := params.ExistingNode.Tags().AsSlice()
@@ -1182,8 +1186,8 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
// Update existing node in NodeStore - validation passed, safe to mutate // Update existing node in NodeStore - validation passed, safe to mutate
updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) { updatedNodeView, ok := s.nodeStore.UpdateNode(params.ExistingNode.ID(), func(node *types.Node) {
node.NodeKey = params.RegEntry.Node.NodeKey node.NodeKey = regNv.NodeKey()
node.DiscoKey = params.RegEntry.Node.DiscoKey node.DiscoKey = regNv.DiscoKey()
node.Hostname = params.Hostname node.Hostname = params.Hostname
// Preserve NetInfo from existing node when re-registering // Preserve NetInfo from existing node when re-registering
@@ -1194,7 +1198,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
params.ValidHostinfo, params.ValidHostinfo,
) )
node.Endpoints = params.RegEntry.Node.Endpoints node.Endpoints = regNv.Endpoints().AsSlice()
node.IsOnline = new(false) node.IsOnline = new(false)
node.LastSeen = new(time.Now()) node.LastSeen = new(time.Now())
@@ -1203,7 +1207,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.IsConvertFromTag { if params.IsConvertFromTag {
node.RegisterMethod = params.RegisterMethod node.RegisterMethod = params.RegisterMethod
} else { } else {
node.RegisterMethod = params.RegEntry.Node.RegisterMethod node.RegisterMethod = regNv.RegisterMethod()
} }
// Track tagged status BEFORE processing tags // Track tagged status BEFORE processing tags
@@ -1223,7 +1227,7 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil { if params.Expiry != nil {
node.Expiry = params.Expiry node.Expiry = params.Expiry
} else { } else {
node.Expiry = params.RegEntry.Node.Expiry node.Expiry = regNv.Expiry().Clone()
} }
case !wasTagged && isTagged: case !wasTagged && isTagged:
// Personal → Tagged: clear expiry (tagged nodes don't expire) // Personal → Tagged: clear expiry (tagged nodes don't expire)
@@ -1233,14 +1237,14 @@ func (s *State) applyAuthNodeUpdate(params authNodeUpdateParams) (types.NodeView
if params.Expiry != nil { if params.Expiry != nil {
node.Expiry = params.Expiry node.Expiry = params.Expiry
} else { } else {
node.Expiry = params.RegEntry.Node.Expiry node.Expiry = regNv.Expiry().Clone()
} }
case !isTagged: case !isTagged:
// Personal → Personal: update expiry from client // Personal → Personal: update expiry from client
if params.Expiry != nil { if params.Expiry != nil {
node.Expiry = params.Expiry node.Expiry = params.Expiry
} else { } else {
node.Expiry = params.RegEntry.Node.Expiry node.Expiry = regNv.Expiry().Clone()
} }
} }
// Tagged → Tagged: keep existing expiry (nil) - no action needed // Tagged → Tagged: keep existing expiry (nil) - no action needed
@@ -1527,13 +1531,13 @@ func (s *State) processReauthTags(
// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). // HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC).
func (s *State) HandleNodeFromAuthPath( func (s *State) HandleNodeFromAuthPath(
registrationID types.RegistrationID, authID types.AuthID,
userID types.UserID, userID types.UserID,
expiry *time.Time, expiry *time.Time,
registrationMethod string, registrationMethod string,
) (types.NodeView, change.Change, error) { ) (types.NodeView, change.Change, error) {
// Get the registration entry from cache // Get the registration entry from cache
regEntry, ok := s.GetRegistrationCacheEntry(registrationID) regEntry, ok := s.GetAuthCacheEntry(authID)
if !ok { if !ok {
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
} }
@@ -1546,25 +1550,27 @@ func (s *State) HandleNodeFromAuthPath(
// Ensure we have a valid hostname from the registration cache entry // Ensure we have a valid hostname from the registration cache entry
hostname := util.EnsureHostname( hostname := util.EnsureHostname(
regEntry.Node.Hostinfo, regEntry.Node().Hostinfo(),
regEntry.Node.MachineKey.String(), regEntry.Node().MachineKey().String(),
regEntry.Node.NodeKey.String(), regEntry.Node().NodeKey().String(),
) )
// Ensure we have valid hostinfo // Ensure we have valid hostinfo
validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{}) hostinfo := &tailcfg.Hostinfo{}
validHostinfo.Hostname = hostname if regEntry.Node().Hostinfo().Valid() {
hostinfo = regEntry.Node().Hostinfo().AsStruct()
}
hostinfo.Hostname = hostname
logHostinfoValidation( logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(), regEntry.Node(),
regEntry.Node.NodeKey.String(),
user.Name, user.Name,
hostname, hostname,
regEntry.Node.Hostinfo,
) )
// Lookup existing nodes // Lookup existing nodes
machineKey := regEntry.Node.MachineKey machineKey := regEntry.Node().MachineKey()
existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID)) existingNodeSameUser, _ := s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(user.ID))
existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) existingNodeAnyUser, _ := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
@@ -1578,7 +1584,7 @@ func (s *State) HandleNodeFromAuthPath(
// Create logger with common fields for all auth operations // Create logger with common fields for all auth operations
logger := log.With(). logger := log.With().
Str(zf.RegistrationID, registrationID.String()). Str(zf.RegistrationID, authID.String()).
Str(zf.UserName, user.Name). Str(zf.UserName, user.Name).
Str(zf.MachineKey, machineKey.ShortString()). Str(zf.MachineKey, machineKey.ShortString()).
Str(zf.Method, registrationMethod). Str(zf.Method, registrationMethod).
@@ -1587,7 +1593,7 @@ func (s *State) HandleNodeFromAuthPath(
// Common params for update operations // Common params for update operations
updateParams := authNodeUpdateParams{ updateParams := authNodeUpdateParams{
RegEntry: regEntry, RegEntry: regEntry,
ValidHostinfo: validHostinfo, ValidHostinfo: hostinfo,
Hostname: hostname, Hostname: hostname,
User: user, User: user,
Expiry: expiry, Expiry: expiry,
@@ -1621,7 +1627,7 @@ func (s *State) HandleNodeFromAuthPath(
Msg("Creating new node for different user (same machine key exists for another user)") Msg("Creating new node for different user (same machine key exists for another user)")
finalNode, err = s.createNewNodeFromAuth( finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo, logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, existingNodeAnyUser, expiry, registrationMethod, existingNodeAnyUser,
) )
if err != nil { if err != nil {
@@ -1629,7 +1635,7 @@ func (s *State) HandleNodeFromAuthPath(
} }
} else { } else {
finalNode, err = s.createNewNodeFromAuth( finalNode, err = s.createNewNodeFromAuth(
logger, user, regEntry, hostname, validHostinfo, logger, user, regEntry, hostname, hostinfo,
expiry, registrationMethod, types.NodeView{}, expiry, registrationMethod, types.NodeView{},
) )
if err != nil { if err != nil {
@@ -1638,10 +1644,10 @@ func (s *State) HandleNodeFromAuthPath(
} }
// Signal to waiting clients // Signal to waiting clients
regEntry.SendAndClose(finalNode.AsStruct()) regEntry.FinishAuth(types.AuthVerdict{Node: finalNode})
// Delete from registration cache // Delete from registration cache
s.registrationCache.Delete(registrationID) s.authCache.Delete(authID)
// Update policy managers // Update policy managers
usersChange, err := s.updatePolicyManagerUsers() usersChange, err := s.updatePolicyManagerUsers()
@@ -1670,7 +1676,7 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) createNewNodeFromAuth( func (s *State) createNewNodeFromAuth(
logger zerolog.Logger, logger zerolog.Logger,
user *types.User, user *types.User,
regEntry *types.RegisterNode, regEntry *types.AuthRequest,
hostname string, hostname string,
validHostinfo *tailcfg.Hostinfo, validHostinfo *tailcfg.Hostinfo,
expiry *time.Time, expiry *time.Time,
@@ -1683,13 +1689,13 @@ func (s *State) createNewNodeFromAuth(
return s.createAndSaveNewNode(newNodeParams{ return s.createAndSaveNewNode(newNodeParams{
User: *user, User: *user,
MachineKey: regEntry.Node.MachineKey, MachineKey: regEntry.Node().MachineKey(),
NodeKey: regEntry.Node.NodeKey, NodeKey: regEntry.Node().NodeKey(),
DiscoKey: regEntry.Node.DiscoKey, DiscoKey: regEntry.Node().DiscoKey(),
Hostname: hostname, Hostname: hostname,
Hostinfo: validHostinfo, Hostinfo: validHostinfo,
Endpoints: regEntry.Node.Endpoints, Endpoints: regEntry.Node().Endpoints().AsSlice(),
Expiry: cmp.Or(expiry, regEntry.Node.Expiry), Expiry: cmp.Or(expiry, regEntry.Node().Expiry().Clone()),
RegisterMethod: registrationMethod, RegisterMethod: registrationMethod,
ExistingNodeForNetinfo: existingNodeForNetinfo, ExistingNodeForNetinfo: existingNodeForNetinfo,
}) })
@@ -1784,7 +1790,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Ensure we have a valid hostname - handle nil/empty cases // Ensure we have a valid hostname - handle nil/empty cases
hostname := util.EnsureHostname( hostname := util.EnsureHostname(
regReq.Hostinfo, regReq.Hostinfo.View(),
machineKey.String(), machineKey.String(),
regReq.NodeKey.String(), regReq.NodeKey.String(),
) )
@@ -1793,14 +1799,6 @@ func (s *State) HandleNodeFromPreAuthKey(
validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{}) validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{})
validHostinfo.Hostname = hostname validHostinfo.Hostname = hostname
logHostinfoValidation(
machineKey.ShortString(),
regReq.NodeKey.ShortString(),
pakUsername(),
hostname,
regReq.Hostinfo,
)
log.Debug(). log.Debug().
Caller(). Caller().
Str(zf.NodeName, hostname). Str(zf.NodeName, hostname).

View File

@@ -7,7 +7,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
) )
func RegisterWeb(registrationID types.RegistrationID) *elem.Element { func RegisterWeb(registrationID types.AuthID) *elem.Element {
return HtmlStructure( return HtmlStructure(
elem.Title(nil, elem.Text("Registration - Headscale")), elem.Title(nil, elem.Text("Registration - Headscale")),
mdTypesetBody( mdTypesetBody(

View File

@@ -21,7 +21,7 @@ func TestTemplateHTMLConsistency(t *testing.T) {
}, },
{ {
name: "Register Web", name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
}, },
{ {
name: "Windows Config", name: "Windows Config",
@@ -77,7 +77,7 @@ func TestTemplateModernHTMLFeatures(t *testing.T) {
}, },
{ {
name: "Register Web", name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
}, },
{ {
name: "Windows Config", name: "Windows Config",
@@ -125,7 +125,7 @@ func TestTemplateExternalLinkSecurity(t *testing.T) {
}, },
{ {
name: "Register Web", name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
externalURLs: []string{}, // No external links externalURLs: []string{}, // No external links
}, },
{ {
@@ -190,7 +190,7 @@ func TestTemplateAccessibilityAttributes(t *testing.T) {
}, },
{ {
name: "Register Web", name: "Register Web",
html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), html: templates.RegisterWeb(types.AuthID("test-key-123")).Render(),
}, },
{ {
name: "Windows Config", name: "Windows Config",

View File

@@ -22,8 +22,8 @@ const (
// Common errors. // Common errors.
var ( var (
ErrCannotParsePrefix = errors.New("cannot parse prefix") ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length") ErrInvalidAuthIDLength = errors.New("registration ID has invalid length")
) )
type StateUpdateType int type StateUpdateType int
@@ -159,21 +159,21 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
} }
} }
const RegistrationIDLength = 24 const AuthIDLength = 24
type RegistrationID string type AuthID string
func NewRegistrationID() (RegistrationID, error) { func NewAuthID() (AuthID, error) {
rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) rid, err := util.GenerateRandomStringURLSafe(AuthIDLength)
if err != nil { if err != nil {
return "", err return "", err
} }
return RegistrationID(rid), nil return AuthID(rid), nil
} }
func MustRegistrationID() RegistrationID { func MustAuthID() AuthID {
rid, err := NewRegistrationID() rid, err := NewAuthID()
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -181,43 +181,89 @@ func MustRegistrationID() RegistrationID {
return rid return rid
} }
func RegistrationIDFromString(str string) (RegistrationID, error) { func AuthIDFromString(str string) (AuthID, error) {
if len(str) != RegistrationIDLength { r := AuthID(str)
return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str))
err := r.Validate()
if err != nil {
return "", err
} }
return RegistrationID(str), nil return r, nil
} }
func (r RegistrationID) String() string { func (r AuthID) String() string {
return string(r) return string(r)
} }
type RegisterNode struct { func (r AuthID) Validate() error {
Node Node if len(r) != AuthIDLength {
Registered chan *Node return fmt.Errorf("%w: expected %d, got %d", ErrInvalidAuthIDLength, AuthIDLength, len(r))
closed *atomic.Bool }
return nil
} }
func NewRegisterNode(node Node) RegisterNode { // AuthRequest represent a pending authentication request from a user or a node.
return RegisterNode{ // If it is a registration request, the node field will be populate with the node that is trying to register.
Node: node, // When the authentication process is finished, the node that has been authenticated will be sent through the Finished channel.
Registered: make(chan *Node), // The closed field is used to ensure that the Finished channel is only closed once, and that no more nodes are sent through it after it has been closed.
closed: &atomic.Bool{}, type AuthRequest struct {
node *Node
finished chan AuthVerdict
closed *atomic.Bool
}
func NewRegisterAuthRequest(node Node) AuthRequest {
return AuthRequest{
node: &node,
finished: make(chan AuthVerdict),
closed: &atomic.Bool{},
} }
} }
func (rn *RegisterNode) SendAndClose(node *Node) { // Node returns the node that is trying to register.
// It will panic if the AuthRequest is not a registration request.
// Can _only_ be used in the registration path.
func (rn *AuthRequest) Node() NodeView {
if rn.node == nil {
panic("Node can only be used in registration requests")
}
return rn.node.View()
}
func (rn *AuthRequest) FinishAuth(verdict AuthVerdict) {
if rn.closed.Swap(true) { if rn.closed.Swap(true) {
return return
} }
select { select {
case rn.Registered <- node: case rn.finished <- verdict:
default: default:
} }
close(rn.Registered) close(rn.finished)
}
func (rn *AuthRequest) WaitForAuth() <-chan AuthVerdict {
return rn.finished
}
type AuthVerdict struct {
// Err is the error that occurred during the authentication process, if any.
// If Err is nil, the authentication process has succeeded.
// If Err is not nil, the authentication process has failed and the node should not be authenticated.
Err error
// Node is the node that has been authenticated.
// Node is only valid if the auth request was a registration request
// and the authentication process has succeeded.
Node NodeView
}
func (v AuthVerdict) Accept() bool {
return v.Err == nil
} }
// DefaultBatcherWorkers returns the default number of batcher workers. // DefaultBatcherWorkers returns the default number of batcher workers.

View File

@@ -295,8 +295,8 @@ func IsCI() bool {
// 3. If normalisation fails → generate invalid-<random> replacement // 3. If normalisation fails → generate invalid-<random> replacement
// //
// Returns the guaranteed-valid hostname to use. // Returns the guaranteed-valid hostname to use.
func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { func EnsureHostname(hostinfo tailcfg.HostinfoView, machineKey, nodeKey string) string {
if hostinfo == nil || hostinfo.Hostname == "" { if !hostinfo.Valid() || hostinfo.Hostname() == "" {
key := cmp.Or(machineKey, nodeKey) key := cmp.Or(machineKey, nodeKey)
if key == "" { if key == "" {
return "unknown-node" return "unknown-node"
@@ -310,7 +310,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
return "node-" + keyPrefix return "node-" + keyPrefix
} }
lowercased := strings.ToLower(hostinfo.Hostname) lowercased := strings.ToLower(hostinfo.Hostname())
err := ValidateHostname(lowercased) err := ValidateHostname(lowercased)
if err == nil { if err == nil {

View File

@@ -1070,7 +1070,7 @@ func TestEnsureHostname(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) got := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies // For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.want, "invalid-") { if strings.HasPrefix(tt.want, "invalid-") {
if !strings.HasPrefix(got, "invalid-") { if !strings.HasPrefix(got, "invalid-") {
@@ -1255,7 +1255,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) gotHostname := EnsureHostname(tt.hostinfo.View(), tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies // For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.wantHostname, "invalid-") { if strings.HasPrefix(tt.wantHostname, "invalid-") {
if !strings.HasPrefix(gotHostname, "invalid-") { if !strings.HasPrefix(gotHostname, "invalid-") {
@@ -1284,7 +1284,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname} hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := EnsureHostname(hostinfo, "mkey", "nkey") result := EnsureHostname(hostinfo.View(), "mkey", "nkey")
if len(result) > 63 { if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
} }
@@ -1300,8 +1300,8 @@ func TestEnsureHostname_Idempotent(t *testing.T) {
OS: "linux", OS: "linux",
} }
hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey") hostname1 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey") hostname2 := EnsureHostname(originalHostinfo.View(), "mkey", "nkey")
if hostname1 != hostname2 { if hostname1 != hostname2 {
t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2)

View File

@@ -1065,11 +1065,11 @@ func TestNodeCommand(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
regIDs := []string{ regIDs := []string{
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
} }
nodes := make([]*v1.Node, len(regIDs)) nodes := make([]*v1.Node, len(regIDs))
@@ -1153,8 +1153,8 @@ func TestNodeCommand(t *testing.T) {
assert.Equal(t, "node-5", listAll[4].GetName()) assert.Equal(t, "node-5", listAll[4].GetName())
otherUserRegIDs := []string{ otherUserRegIDs := []string{
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
} }
otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) otherUserMachines := make([]*v1.Node, len(otherUserRegIDs))
@@ -1326,11 +1326,11 @@ func TestNodeExpireCommand(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
regIDs := []string{ regIDs := []string{
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
} }
nodes := make([]*v1.Node, len(regIDs)) nodes := make([]*v1.Node, len(regIDs))
@@ -1461,11 +1461,11 @@ func TestNodeRenameCommand(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
regIDs := []string{ regIDs := []string{
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
types.MustRegistrationID().String(), types.MustAuthID().String(),
} }
nodes := make([]*v1.Node, len(regIDs)) nodes := make([]*v1.Node, len(regIDs))