mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-21 08:11:43 +02:00
auth: generalise auth flow and introduce AuthVerdict
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
This commit is contained in:
@@ -676,28 +676,23 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
{
|
||||
name: "followup_registration_success",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registered := make(chan *types.Node, 1)
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
Hostname: "followup-success-node",
|
||||
},
|
||||
Registered: registered,
|
||||
}
|
||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
Hostname: "followup-success-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
|
||||
// Simulate successful registration - send to buffered channel
|
||||
// The channel is buffered (size 1), so this can complete immediately
|
||||
// and handleRegister will receive the value when it starts waiting
|
||||
// Simulate successful registration
|
||||
// handleRegister will receive the value when it starts waiting
|
||||
go func() {
|
||||
user := app.state.CreateUserForTest("followup-user")
|
||||
|
||||
node := app.state.CreateNodeForTest(user, "followup-success-node")
|
||||
registered <- node
|
||||
nodeToRegister.FinishAuth(types.AuthVerdict{Node: node.View()})
|
||||
}()
|
||||
|
||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||
@@ -723,20 +718,16 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
{
|
||||
name: "followup_registration_timeout",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registered := make(chan *types.Node, 1)
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
Hostname: "followup-timeout-node",
|
||||
},
|
||||
Registered: registered,
|
||||
}
|
||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
||||
// Don't send anything on channel - will timeout
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
Hostname: "followup-timeout-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
// Don't call FinishRegistration - will timeout
|
||||
|
||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||
},
|
||||
@@ -1345,24 +1336,19 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
{
|
||||
name: "followup_registration_node_nil_response",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
registered := make(chan *types.Node, 1)
|
||||
nodeToRegister := types.RegisterNode{
|
||||
Node: types.Node{
|
||||
Hostname: "nil-response-node",
|
||||
},
|
||||
Registered: registered,
|
||||
}
|
||||
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
|
||||
nodeToRegister := types.NewRegisterAuthRequest(types.Node{
|
||||
Hostname: "nil-response-node",
|
||||
})
|
||||
app.state.SetAuthCacheEntry(regID, nodeToRegister)
|
||||
|
||||
// Simulate registration that returns nil (cache expired during auth)
|
||||
// The channel is buffered (size 1), so this can complete immediately
|
||||
// Simulate registration that returns empty NodeView (cache expired during auth)
|
||||
go func() {
|
||||
registered <- nil // Nil indicates cache expiry
|
||||
nodeToRegister.FinishAuth(types.AuthVerdict{Node: types.NodeView{}}) // Empty view indicates cache expiry
|
||||
}()
|
||||
|
||||
return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil
|
||||
@@ -1815,7 +1801,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper
|
||||
// Generate a registration ID that doesn't exist in cache
|
||||
// This simulates an expired/missing cache entry
|
||||
regID, err := types.NewRegistrationID()
|
||||
regID, err := types.NewAuthID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1847,11 +1833,11 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
|
||||
// Extract and validate the new registration ID exists in cache
|
||||
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
|
||||
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
|
||||
newRegID, err := types.AuthIDFromString(newRegIDStr)
|
||||
assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure
|
||||
|
||||
// Verify new registration entry exists in cache
|
||||
_, found := app.state.GetRegistrationCacheEntry(newRegID)
|
||||
_, found := app.state.GetAuthCacheEntry(newRegID)
|
||||
assert.True(t, found, "new registration should exist in cache")
|
||||
},
|
||||
},
|
||||
@@ -2300,7 +2286,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify cache entry exists
|
||||
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
assert.True(t, found, "registration cache entry should exist initially")
|
||||
assert.NotNil(t, cacheEntry)
|
||||
|
||||
@@ -2315,7 +2301,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern
|
||||
|
||||
// Cache entry should still exist after auth error (for retry scenarios)
|
||||
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
_, stillFound := app.state.GetAuthCacheEntry(registrationID)
|
||||
assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry")
|
||||
},
|
||||
},
|
||||
@@ -2375,8 +2361,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
|
||||
|
||||
// Both cache entries should exist simultaneously
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
_, found1 := app.state.GetAuthCacheEntry(regID1)
|
||||
_, found2 := app.state.GetAuthCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first registration cache entry should exist")
|
||||
assert.True(t, found2, "second registration cache entry should exist")
|
||||
@@ -2427,8 +2413,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both exist
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
_, found1 := app.state.GetAuthCacheEntry(regID1)
|
||||
_, found2 := app.state.GetAuthCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first cache entry should exist")
|
||||
assert.True(t, found2, "second cache entry should exist")
|
||||
@@ -2490,7 +2476,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
||||
}
|
||||
|
||||
// First registration should still be in cache (not completed)
|
||||
_, stillFound := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, stillFound := app.state.GetAuthCacheEntry(regID1)
|
||||
assert.True(t, stillFound, "first registration should still be pending")
|
||||
},
|
||||
},
|
||||
@@ -2601,7 +2587,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
var (
|
||||
initialResp *tailcfg.RegisterResponse
|
||||
authURL string
|
||||
registrationID types.RegistrationID
|
||||
registrationID types.AuthID
|
||||
finalResp *tailcfg.RegisterResponse
|
||||
err error
|
||||
)
|
||||
@@ -2629,10 +2615,10 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
|
||||
if step.expectCacheEntry {
|
||||
// Verify registration cache entry was created
|
||||
cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
cacheEntry, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
require.True(t, found, "registration cache entry should exist")
|
||||
require.NotNil(t, cacheEntry, "cache entry should not be nil")
|
||||
require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key")
|
||||
require.Equal(t, req.NodeKey, cacheEntry.Node().NodeKey(), "cache entry should have correct node key")
|
||||
}
|
||||
|
||||
case stepTypeAuthCompletion:
|
||||
@@ -2692,7 +2678,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
// Check cache cleanup expectation for this step
|
||||
if step.expectCacheEntry == false && registrationID != "" {
|
||||
// Verify cache entry was cleaned up
|
||||
_, found := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
_, found := app.state.GetAuthCacheEntry(registrationID)
|
||||
require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType)
|
||||
}
|
||||
}
|
||||
@@ -2714,7 +2700,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
}
|
||||
|
||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
|
||||
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
|
||||
func extractRegistrationIDFromAuthURL(authURL string) (types.AuthID, error) {
|
||||
// AuthURL format: "http://localhost/register/abc123"
|
||||
const registerPrefix = "/register/"
|
||||
|
||||
@@ -2725,7 +2711,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err
|
||||
|
||||
idStr := authURL[idx+len(registerPrefix):]
|
||||
|
||||
return types.RegistrationIDFromString(idStr)
|
||||
return types.AuthIDFromString(idStr)
|
||||
}
|
||||
|
||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
|
||||
@@ -3583,8 +3569,8 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
// Simulate a registration cache entry (as would be created during web auth)
|
||||
registrationID := types.MustRegistrationID()
|
||||
regEntry := types.NewRegisterNode(types.Node{
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "webauth-tags-node",
|
||||
@@ -3593,7 +3579,7 @@ func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
|
||||
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
||||
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||
|
||||
// Complete the web auth - should fail because tag is unauthorized
|
||||
_, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@@ -3646,8 +3632,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
nodeKey1 := key.NewNode()
|
||||
|
||||
// Step 1: Initial registration with tags
|
||||
registrationID1 := types.MustRegistrationID()
|
||||
regEntry1 := types.NewRegisterNode(types.Node{
|
||||
registrationID1 := types.MustAuthID()
|
||||
regEntry1 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey1.Public(),
|
||||
Hostname: "reauth-untag-node",
|
||||
@@ -3656,7 +3642,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
RequestTags: []string{"tag:valid-owned", "tag:second"},
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID1, regEntry1)
|
||||
app.state.SetAuthCacheEntry(registrationID1, regEntry1)
|
||||
|
||||
// Complete initial registration with tags
|
||||
node, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@@ -3673,8 +3659,8 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
|
||||
// Step 2: Reauth with EMPTY tags to untag
|
||||
nodeKey2 := key.NewNode() // New node key for reauth
|
||||
registrationID2 := types.MustRegistrationID()
|
||||
regEntry2 := types.NewRegisterNode(types.Node{
|
||||
registrationID2 := types.MustAuthID()
|
||||
regEntry2 := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(), // Same machine key
|
||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||
Hostname: "reauth-untag-node",
|
||||
@@ -3683,7 +3669,7 @@ func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) {
|
||||
RequestTags: []string{}, // EMPTY - should untag
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID2, regEntry2)
|
||||
app.state.SetAuthCacheEntry(registrationID2, regEntry2)
|
||||
|
||||
// Complete reauth with empty tags
|
||||
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@@ -3759,8 +3745,8 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
||||
|
||||
// Step 2: Reauth via web auth with EMPTY tags to transition to user-owned
|
||||
nodeKey2 := key.NewNode() // New node key for reauth
|
||||
registrationID := types.MustRegistrationID()
|
||||
regEntry := types.NewRegisterNode(types.Node{
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(), // Same machine key
|
||||
NodeKey: nodeKey2.Public(), // Different node key (rotation)
|
||||
Hostname: "authkey-tagged-node",
|
||||
@@ -3769,7 +3755,7 @@ func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) {
|
||||
RequestTags: []string{}, // EMPTY - should untag
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
||||
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||
|
||||
// Complete reauth with empty tags
|
||||
nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath(
|
||||
@@ -3958,8 +3944,8 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
||||
// Step 4: Re-register the node to alice via HandleNodeFromAuthPath
|
||||
// This is what happens when running: headscale nodes register --user alice --key ...
|
||||
nodeKey2 := key.NewNode()
|
||||
registrationID := types.MustRegistrationID()
|
||||
regEntry := types.NewRegisterNode(types.Node{
|
||||
registrationID := types.MustAuthID()
|
||||
regEntry := types.NewRegisterAuthRequest(types.Node{
|
||||
MachineKey: machineKey.Public(), // Same machine key as the tagged node
|
||||
NodeKey: nodeKey2.Public(),
|
||||
Hostname: "tagged-orphan-node",
|
||||
@@ -3968,7 +3954,7 @@ func TestTaggedNodeWithoutUserToDifferentUser(t *testing.T) {
|
||||
RequestTags: []string{}, // Empty - transition to user-owned
|
||||
},
|
||||
})
|
||||
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
|
||||
app.state.SetAuthCacheEntry(registrationID, regEntry)
|
||||
|
||||
// This should NOT panic - before the fix, this would panic with:
|
||||
// panic: runtime error: invalid memory address or nil pointer dereference
|
||||
|
||||
Reference in New Issue
Block a user