make tags first class node owner (#2885)

This PR changes tags to be something that exists on nodes in addition to users, to being its own thing. It is part of moving our tags support towards the correct tailscale compatible implementation.

There are probably rough edges in this PR, but the intention is to get it in, and then start fixing bugs from 0.28.0 milestone (long standing tags issue) to discover what works and what doesnt.

Updates #2417
Closes #2619
This commit is contained in:
Kristoffer Dalby
2025-12-02 12:01:25 +01:00
committed by GitHub
parent 705b239677
commit eb788cd007
49 changed files with 3102 additions and 757 deletions

View File

@@ -233,11 +233,7 @@ func isAuthKey(req tailcfg.RegisterRequest) bool {
}
func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{
// TODO(kradalby): Only send for user-owned nodes
// and not tagged nodes when tags is working.
User: node.UserView().TailscaleUser(),
Login: node.UserView().TailscaleLogin(),
resp := &tailcfg.RegisterResponse{
NodeKeyExpired: node.IsExpired(),
// Headscale does not implement the concept of machine authorization
@@ -245,6 +241,18 @@ func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse {
// Revisit this if #2176 gets implemented.
MachineAuthorized: true,
}
// For tagged nodes, use the TaggedDevices special user
// For user-owned nodes, include User and Login information from the actual user
if node.IsTagged() {
resp.User = types.TaggedDevices.View().TailscaleUser()
resp.Login = types.TaggedDevices.View().TailscaleLogin()
} else if node.UserView().Valid() {
resp.User = node.UserView().TailscaleUser()
resp.Login = node.UserView().TailscaleLogin()
}
return resp
}
func (h *Headscale) waitForFollowup(

535
hscontrol/auth_tags_test.go Normal file
View File

@@ -0,0 +1,535 @@
package hscontrol
import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// TestTaggedPreAuthKeyCreatesTaggedNode tests that a PreAuthKey with tags creates
// a tagged node with:
// - Tags from the PreAuthKey
// - UserID tracking who created the key (informational "created by")
// - IsTagged() returns true.
func TestTaggedPreAuthKeyCreatesTaggedNode(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server", "tag:prod"}
// Create a tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
require.NotEmpty(t, pak.Tags, "PreAuthKey should have tags")
require.ElementsMatch(t, tags, pak.Tags, "PreAuthKey should have specified tags")
// Register a node using the tagged key
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify the node was created with tags
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
// Critical assertions for tags-as-identity model
assert.True(t, node.IsTagged(), "Node should be tagged")
assert.ElementsMatch(t, tags, node.Tags().AsSlice(), "Node should have tags from PreAuthKey")
assert.True(t, node.UserID().Valid(), "Node should have UserID tracking creator")
assert.Equal(t, user.ID, node.UserID().Get(), "UserID should track PreAuthKey creator")
// Verify node is identified correctly
assert.True(t, node.IsTagged(), "Tagged node is not user-owned")
assert.True(t, node.HasTag("tag:server"), "Node should have tag:server")
assert.True(t, node.HasTag("tag:prod"), "Node should have tag:prod")
assert.False(t, node.HasTag("tag:other"), "Node should not have tag:other")
}
// TestReAuthDoesNotReapplyTags tests that when a node re-authenticates using the
// same PreAuthKey, the tags are NOT re-applied. Tags are only set during initial
// authentication. This is critical for the container restart scenario (#2830).
//
// NOTE: This test verifies that re-authentication preserves the node's current tags
// without testing tag modification via SetNodeTags (which requires ACL policy setup).
func TestReAuthDoesNotReapplyTags(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
initialTags := []string{"tag:server", "tag:dev"}
// Create a tagged PreAuthKey with reusable=true for re-auth
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, initialTags)
require.NoError(t, err)
// Initial registration
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "reauth-test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify initial tags
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
require.True(t, node.IsTagged())
require.ElementsMatch(t, initialTags, node.Tags().AsSlice())
// Re-authenticate with the SAME PreAuthKey (container restart scenario)
// Key behavior: Tags should NOT be re-applied during re-auth
reAuthReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same key
},
NodeKey: nodeKey.Public(), // Same node key
Hostinfo: &tailcfg.Hostinfo{
Hostname: "reauth-test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
reAuthResp, err := app.handleRegisterWithAuthKey(reAuthReq, machineKey.Public())
require.NoError(t, err)
require.True(t, reAuthResp.MachineAuthorized)
// CRITICAL: Tags should remain unchanged after re-auth
// They should match the original tags, proving they weren't re-applied
nodeAfterReauth, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, nodeAfterReauth.IsTagged(), "Node should still be tagged")
assert.ElementsMatch(t, initialTags, nodeAfterReauth.Tags().AsSlice(), "Tags should remain unchanged on re-auth")
// Verify only one node was created (no duplicates)
nodes := app.state.ListNodesByUser(types.UserID(user.ID))
assert.Equal(t, 1, nodes.Len(), "Should have exactly one node")
}
// NOTE: TestSetTagsOnUserOwnedNode functionality is covered by gRPC tests in grpcv1_test.go
// which properly handle ACL policy setup. The test verifies that SetTags can convert
// user-owned nodes to tagged nodes while preserving UserID.
// TestCannotRemoveAllTags tests that attempting to remove all tags from a
// tagged node fails with ErrCannotRemoveAllTags. Once a node is tagged,
// it must always have at least one tag (Tailscale requirement).
func TestCannotRemoveAllTags(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a tagged node
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify node is tagged
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
require.True(t, node.IsTagged())
// Attempt to remove all tags by setting empty array
_, _, err = app.state.SetNodeTags(node.ID(), []string{})
require.Error(t, err, "Should not be able to remove all tags")
require.ErrorIs(t, err, types.ErrCannotRemoveAllTags, "Error should be ErrCannotRemoveAllTags")
// Verify node still has original tags
nodeAfter, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, nodeAfter.IsTagged(), "Node should still be tagged")
assert.ElementsMatch(t, tags, nodeAfter.Tags().AsSlice(), "Tags should be unchanged")
}
// TestUserOwnedNodeCreatedWithUntaggedPreAuthKey tests that using a PreAuthKey
// without tags creates a user-owned node (no tags, UserID is the owner).
func TestUserOwnedNodeCreatedWithUntaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("node-owner")
// Create an untagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
require.Empty(t, pak.Tags, "PreAuthKey should not be tagged")
require.Empty(t, pak.Tags, "PreAuthKey should have no tags")
// Register a node
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "user-owned-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify node is user-owned
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
// Critical assertions for user-owned node
assert.False(t, node.IsTagged(), "Node should not be tagged")
assert.False(t, node.IsTagged(), "Node should be user-owned (not tagged)")
assert.Empty(t, node.Tags().AsSlice(), "Node should have no tags")
assert.True(t, node.UserID().Valid(), "Node should have UserID")
assert.Equal(t, user.ID, node.UserID().Get(), "UserID should be the PreAuthKey owner")
}
// TestMultipleNodesWithSameReusableTaggedPreAuthKey tests that a reusable
// PreAuthKey with tags can be used to register multiple nodes, and all nodes
// receive the same tags from the key.
func TestMultipleNodesWithSameReusableTaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server", "tag:prod"}
// Create a REUSABLE tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
require.ElementsMatch(t, tags, pak.Tags)
// Register first node
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-1",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
// Register second node with SAME PreAuthKey
machineKey2 := key.NewMachine()
nodeKey2 := key.NewNode()
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same key
},
NodeKey: nodeKey2.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-2",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.NoError(t, err)
require.True(t, resp2.MachineAuthorized)
// Verify both nodes exist and have the same tags
node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public())
require.True(t, found)
// Both nodes should be tagged with the same tags
assert.True(t, node1.IsTagged(), "First node should be tagged")
assert.True(t, node2.IsTagged(), "Second node should be tagged")
assert.ElementsMatch(t, tags, node1.Tags().AsSlice(), "First node should have PreAuthKey tags")
assert.ElementsMatch(t, tags, node2.Tags().AsSlice(), "Second node should have PreAuthKey tags")
// Both nodes should track the same creator
assert.Equal(t, user.ID, node1.UserID().Get(), "First node should track creator")
assert.Equal(t, user.ID, node2.UserID().Get(), "Second node should track creator")
// Verify we have exactly 2 nodes
nodes := app.state.ListNodesByUser(types.UserID(user.ID))
assert.Equal(t, 2, nodes.Len(), "Should have exactly two nodes")
}
// TestNonReusableTaggedPreAuthKey tests that a non-reusable PreAuthKey with tags
// can only be used once. The second attempt should fail.
func TestNonReusableTaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a NON-REUSABLE tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, tags)
require.NoError(t, err)
require.ElementsMatch(t, tags, pak.Tags)
// Register first node - should succeed
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-1",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
// Verify first node was created with tags
node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
assert.True(t, node1.IsTagged())
assert.ElementsMatch(t, tags, node1.Tags().AsSlice())
// Attempt to register second node with SAME non-reusable key - should fail
machineKey2 := key.NewMachine()
nodeKey2 := key.NewNode()
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same non-reusable key
},
NodeKey: nodeKey2.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-2",
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.Error(t, err, "Should not be able to reuse non-reusable PreAuthKey")
// Verify only one node was created
nodes := app.state.ListNodesByUser(types.UserID(user.ID))
assert.Equal(t, 1, nodes.Len(), "Should have exactly one node")
}
// TestExpiredTaggedPreAuthKey tests that an expired PreAuthKey with tags
// cannot be used to register a node.
func TestExpiredTaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a PreAuthKey that expires immediately
expiration := time.Now().Add(-1 * time.Hour) // Already expired
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiration, tags)
require.NoError(t, err)
require.ElementsMatch(t, tags, pak.Tags)
// Attempt to register with expired key
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.Error(t, err, "Should not be able to use expired PreAuthKey")
// Verify no node was created
_, found := app.state.GetNodeByNodeKey(nodeKey.Public())
assert.False(t, found, "No node should be created with expired key")
}
// TestSingleVsMultipleTags tests that PreAuthKeys work correctly with both
// a single tag and multiple tags.
func TestSingleVsMultipleTags(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
// Test with single tag
singleTag := []string{"tag:server"}
pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, singleTag)
require.NoError(t, err)
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak1.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "single-tag-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
assert.True(t, node1.IsTagged())
assert.ElementsMatch(t, singleTag, node1.Tags().AsSlice())
// Test with multiple tags
multipleTags := []string{"tag:server", "tag:prod", "tag:database"}
pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, multipleTags)
require.NoError(t, err)
machineKey2 := key.NewMachine()
nodeKey2 := key.NewNode()
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak2.Key,
},
NodeKey: nodeKey2.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "multi-tag-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.NoError(t, err)
require.True(t, resp2.MachineAuthorized)
node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public())
require.True(t, found)
assert.True(t, node2.IsTagged())
assert.ElementsMatch(t, multipleTags, node2.Tags().AsSlice())
// Verify HasTag works for all tags
assert.True(t, node2.HasTag("tag:server"))
assert.True(t, node2.HasTag("tag:prod"))
assert.True(t, node2.HasTag("tag:database"))
assert.False(t, node2.HasTag("tag:other"))
}
// TestReAuthWithDifferentMachineKey tests the edge case where a node attempts
// to re-authenticate with the same NodeKey but a DIFFERENT MachineKey.
// This scenario should be handled gracefully (currently creates a new node).
func TestReAuthWithDifferentMachineKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a reusable tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
// Initial registration
machineKey1 := key.NewMachine()
nodeKey := key.NewNode() // Same NodeKey for both attempts
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
// Verify initial node
node1, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, node1.IsTagged())
// Re-authenticate with DIFFERENT MachineKey but SAME NodeKey
machineKey2 := key.NewMachine() // Different machine key
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(), // Same NodeKey
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.NoError(t, err)
require.True(t, resp2.MachineAuthorized)
// Verify the node still exists and has tags
// Note: Depending on implementation, this might be the same node or a new node
node2, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, node2.IsTagged())
assert.ElementsMatch(t, tags, node2.Tags().AsSlice())
}

View File

@@ -70,7 +70,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_valid_new_node",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("preauth-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -111,7 +112,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_reusable_multiple_nodes",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("reusable-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -177,7 +179,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_single_use_exhausted",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("single-use-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
if err != nil {
return "", err
}
@@ -264,7 +267,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_ephemeral_node",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("ephemeral-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
if err != nil {
return "", err
}
@@ -370,7 +374,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_logout",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("logout-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -429,7 +434,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_machine_key_mismatch",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("mismatch-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -477,7 +483,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_key_extension_not_allowed",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("extend-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -525,7 +532,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_expired_forces_reauth",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("reauth-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -585,7 +593,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "ephemeral_node_logout_deletion",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("ephemeral-logout-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
if err != nil {
return "", err
}
@@ -767,7 +776,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "empty_hostname",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("empty-hostname-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -805,7 +815,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "nil_hostinfo",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("nil-hostinfo-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -848,7 +859,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("expired-pak-user")
expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil)
if err != nil {
return "", err
}
@@ -880,7 +892,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("tagged-pak-user")
tags := []string{"tag:server", "tag:database"}
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, tags)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
if err != nil {
return "", err
}
@@ -926,7 +939,7 @@ func TestAuthenticationFlows(t *testing.T) {
user := app.state.CreateUserForTest("reauth-user")
// First, register with initial auth key
pak1, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -953,7 +966,7 @@ func TestAuthenticationFlows(t *testing.T) {
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
// Create new auth key for re-authentication
pak2, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -992,7 +1005,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_reauth_interactive_flow",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("interactive-reauth-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1053,7 +1067,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "node_key_rotation_same_machine",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("rotation-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1081,7 +1096,7 @@ func TestAuthenticationFlows(t *testing.T) {
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
// Create new auth key for rotation
pakRotation, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pakRotation, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1129,7 +1144,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "malformed_expiry_zero_time",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("zero-expiry-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1167,7 +1183,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "malformed_hostinfo_invalid_data",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("invalid-hostinfo-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1353,7 +1370,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_usage_count_tracking",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("usage-count-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // Single use
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // Single use
if err != nil {
return "", err
}
@@ -1432,7 +1450,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "concurrent_registration_same_node_key",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("concurrent-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1473,7 +1492,8 @@ func TestAuthenticationFlows(t *testing.T) {
user := app.state.CreateUserForTest("future-expiry-user")
// Auth key expires in the future
expiry := time.Now().Add(48 * time.Hour)
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil)
if err != nil {
return "", err
}
@@ -1517,7 +1537,7 @@ func TestAuthenticationFlows(t *testing.T) {
user2 := app.state.CreateUserForTest("user2-context")
// Register node with user1's auth key
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1544,7 +1564,7 @@ func TestAuthenticationFlows(t *testing.T) {
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
// Return user2's auth key for re-authentication
pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil)
pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1571,15 +1591,15 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify NEW node was created for user2
node2, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2))
require.True(t, found, "new node should exist for user2")
assert.Equal(t, uint(2), node2.UserID(), "new node should belong to user2")
assert.Equal(t, uint(2), node2.UserID().Get(), "new node should belong to user2")
user := node2.User()
assert.Equal(t, "user2-context", user.Username(), "new node should show user2 username")
assert.Equal(t, "user2-context", user.Name(), "new node should show user2 username")
// Verify original node still exists for user1
node1, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1))
require.True(t, found, "original node should still exist for user1")
assert.Equal(t, uint(1), node1.UserID(), "original node should still belong to user1")
assert.Equal(t, uint(1), node1.UserID().Get(), "original node should still belong to user1")
// Verify they are different nodes (different IDs)
assert.NotEqual(t, node1.ID(), node2.ID(), "should be different node IDs")
@@ -1595,7 +1615,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
// Create user1 and register a node with auth key
user1 := app.state.CreateUserForTest("interactive-user-1")
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1645,16 +1666,16 @@ func TestAuthenticationFlows(t *testing.T) {
// User1's original node should STILL exist (not transferred)
node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1))
require.True(t, found1, "user1's original node should still exist")
assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1")
assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1")
assert.Equal(t, nodeKey1.Public(), node1.NodeKey(), "user1's node should have original node key")
// User2 should have a NEW node created
node2, found2 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2))
require.True(t, found2, "user2 should have new node created")
assert.Equal(t, uint(2), node2.UserID(), "user2's node should belong to user2")
assert.Equal(t, uint(2), node2.UserID().Get(), "user2's node should belong to user2")
user := node2.User()
assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should show correct username")
assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should show correct username")
// Both nodes should have the same machine key but different IDs
assert.NotEqual(t, node1.ID(), node2.ID(), "should be different nodes (different IDs)")
@@ -1720,7 +1741,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "logout_with_exactly_now_expiry",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("exact-now-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1813,7 +1835,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
// First create a node under user1
user1 := app.state.CreateUserForTest("existing-user-1")
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -1863,7 +1886,7 @@ func TestAuthenticationFlows(t *testing.T) {
// User1's original node with nodeKey1 should STILL exist
node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found1, "user1's original node with nodeKey1 should still exist")
assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1")
assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1")
assert.Equal(t, uint64(1), node1.ID().Uint64(), "user1's node should be ID=1")
// User2 should have a NEW node with nodeKey2
@@ -1872,7 +1895,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Equal(t, "existing-node-user2", node2.Hostname(), "hostname should be from new registration")
user := node2.User()
assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2")
assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2")
assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should be the same")
// Verify it's a NEW node, not transferred
@@ -2022,7 +2045,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
// Register initial node
user := app.state.CreateUserForTest("rotation-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@@ -2072,7 +2096,7 @@ func TestAuthenticationFlows(t *testing.T) {
// User1's original node with nodeKey1 should STILL exist
oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, foundOld, "user1's original node with nodeKey1 should still exist")
assert.Equal(t, uint(1), oldNode.UserID(), "user1's node should still belong to user1")
assert.Equal(t, uint(1), oldNode.UserID().Get(), "user1's node should still belong to user1")
assert.Equal(t, uint64(1), oldNode.ID().Uint64(), "user1's node should be ID=1")
// User2 should have a NEW node with nodeKey2
@@ -2082,7 +2106,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Equal(t, machineKey1.Public(), newNode.MachineKey())
user := newNode.User()
assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2")
assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2")
// Verify it's a NEW node, not transferred
assert.NotEqual(t, uint64(1), newNode.ID().Uint64(), "should be a NEW node (different ID)")
@@ -2333,7 +2357,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.True(t, found, "node should be registered")
if found {
assert.Equal(t, "pending-node-2", node.Hostname())
assert.Equal(t, "second-registration-user", node.User().Name)
assert.Equal(t, "second-registration-user", node.User().Name())
}
// First registration should still be in cache (not completed)
@@ -2593,7 +2617,7 @@ func TestNodeStoreLookup(t *testing.T) {
nodeKey := key.NewNode()
user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
// Register a node
@@ -2642,9 +2666,9 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
user2 := app.state.CreateUserForTest("user2")
// Create pre-auth keys for both users
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
require.NoError(t, err)
pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil)
pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil)
require.NoError(t, err)
// Create machine and node keys for 4 nodes (2 per user)
@@ -2720,7 +2744,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
t.Logf("All nodes logged out")
// Create a new pre-auth key for user1 (reusable for all nodes)
newPak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
newPak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
require.NoError(t, err)
// Re-login all nodes using user1's new pre-auth key
@@ -2765,7 +2789,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// User1's original nodes should still be owned by user1
registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
require.True(t, found, "User1's original node %s should still exist", node.hostname)
require.Equal(t, user1.ID, registeredNode.UserID(), "Node %s should still belong to user1", node.hostname)
require.Equal(t, user1.ID, registeredNode.UserID().Get(), "Node %s should still belong to user1", node.hostname)
t.Logf("✓ User1's original node %s (ID=%d) still owned by user1", node.hostname, registeredNode.ID().Uint64())
}
@@ -2774,7 +2798,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// User2's original nodes should still be owned by user2
registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user2.ID))
require.True(t, found, "User2's original node %s should still exist", node.hostname)
require.Equal(t, user2.ID, registeredNode.UserID(), "Node %s should still belong to user2", node.hostname)
require.Equal(t, user2.ID, registeredNode.UserID().Get(), "Node %s should still belong to user2", node.hostname)
t.Logf("✓ User2's original node %s (ID=%d) still owned by user2", node.hostname, registeredNode.ID().Uint64())
}
@@ -2785,7 +2809,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Should be able to find a node with user1 and this machine key (the new one)
newNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
require.True(t, found, "Should have created new node for user1 with machine key from %s", node.hostname)
require.Equal(t, user1.ID, newNode.UserID(), "New node should belong to user1")
require.Equal(t, user1.ID, newNode.UserID().Get(), "New node should belong to user1")
t.Logf("✓ New node created for user1 with machine key from %s (ID=%d)", node.hostname, newNode.ID().Uint64())
}
}
@@ -2813,7 +2837,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
// Step 1: Register node for user1 via pre-auth key (simulating initial web flow registration)
user1 := app.state.CreateUserForTest("user1")
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
require.NoError(t, err)
regReq1 := tailcfg.RegisterRequest{
@@ -2834,7 +2858,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
// Verify node exists for user1
user1Node, found := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID))
require.True(t, found, "Node should exist for user1")
require.Equal(t, user1.ID, user1Node.UserID(), "Node should belong to user1")
require.Equal(t, user1.ID, user1Node.UserID().Get(), "Node should belong to user1")
user1NodeID := user1Node.ID()
t.Logf("✓ User1 node created with ID: %d", user1NodeID)
@@ -2896,7 +2920,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
t.Fatal("User1's node was transferred or deleted - this breaks the integration test!")
}
assert.Equal(t, user1.ID, user1NodeAfter.UserID(), "User1's node should still belong to user1")
assert.Equal(t, user1.ID, user1NodeAfter.UserID().Get(), "User1's node should still belong to user1")
assert.Equal(t, user1NodeID, user1NodeAfter.ID(), "Should be the same node (same ID)")
assert.True(t, user1NodeAfter.IsExpired(), "User1's node should still be expired")
t.Logf("✓ User1's original node still exists (ID: %d, expired: %v)", user1NodeAfter.ID(), user1NodeAfter.IsExpired())
@@ -2911,7 +2935,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
t.Fatal("User2 doesn't have a node - registration failed!")
}
assert.Equal(t, user2.ID, user2Node.UserID(), "User2's node should belong to user2")
assert.Equal(t, user2.ID, user2Node.UserID().Get(), "User2's node should belong to user2")
assert.NotEqual(t, user1NodeID, user2Node.ID(), "Should be a NEW node (different ID), not transfer!")
assert.Equal(t, machineKey.Public(), user2Node.MachineKey(), "Should have same machine key")
assert.Equal(t, nodeKey2.Public(), user2Node.NodeKey(), "Should have new node key")
@@ -2921,7 +2945,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
t.Run("returned_node_is_user2_new_node", func(t *testing.T) {
// The node returned from HandleNodeFromAuthPath should be user2's NEW node
assert.Equal(t, user2.ID, node.UserID(), "Returned node should belong to user2")
assert.Equal(t, user2.ID, node.UserID().Get(), "Returned node should belong to user2")
assert.NotEqual(t, user1NodeID, node.ID(), "Returned node should be NEW, not transferred from user1")
t.Logf("✓ HandleNodeFromAuthPath returned user2's new node (ID: %d)", node.ID())
})
@@ -2949,10 +2973,11 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
user2Nodes := 0
for i := 0; i < allNodesSlice.Len(); i++ {
n := allNodesSlice.At(i)
if n.UserID() == user1.ID {
if n.UserID().Get() == user1.ID {
user1Nodes++
}
if n.UserID() == user2.ID {
if n.UserID().Get() == user2.ID {
user2Nodes++
}
}
@@ -3026,7 +3051,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// Create user and single-use pre-auth key
user := app.state.CreateUserForTest("test-user")
pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false
pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // reusable=false
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable field
@@ -3117,7 +3142,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("test-user")
pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true
pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) // reusable=true
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable field
@@ -3173,7 +3198,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
user := app.state.CreateUserForTest("test-user")
expiry := time.Now().Add(-1 * time.Hour) // Already expired
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil)
require.NoError(t, err)
machineKey := key.NewMachine()
@@ -3306,7 +3331,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
// Create a SINGLE-USE pre-auth key (reusable=false)
// This is the type of key that triggers the bug in issue #2830
preAuthKeyNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
preAuthKeyNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable and Used fields

View File

@@ -577,6 +577,21 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
// Rename forced_tags column to tags in nodes table.
// This must run after migration 202505141324 which creates tables with forced_tags.
ID: "202511131445-node-forced-tags-to-tags",
Migrate: func(tx *gorm.DB) error {
// Rename the column from forced_tags to tags
err := tx.Migrator().RenameColumn(&types.Node{}, "forced_tags", "tags")
if err != nil {
return fmt.Errorf("renaming forced_tags to tags: %w", err)
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View File

@@ -231,8 +231,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
name string
dbPath string
wantFunc func(*testing.T, *HSDatabase)
}{
}
}{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -95,7 +95,7 @@ func TestIPAllocatorSequential(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
@@ -123,7 +123,7 @@ func TestIPAllocatorSequential(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.2"),
IPv6: nap("fd7a:115c:a1e0::2"),
})
@@ -309,7 +309,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
})
@@ -334,7 +334,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv6: nap("fd7a:115c:a1e0::1"),
})
@@ -359,7 +359,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
@@ -383,7 +383,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
@@ -407,19 +407,19 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
})
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.2"),
})
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.3"),
})
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.4"),
})

View File

@@ -196,8 +196,9 @@ func SetTags(
tags []string,
) error {
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
// if no tags are provided, we remove all tags
err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error
if err != nil {
return fmt.Errorf("removing tags: %w", err)
}
@@ -211,7 +212,8 @@ func SetTags(
return err
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", string(b)).Error
if err != nil {
return fmt.Errorf("updating tags: %w", err)
}
@@ -349,12 +351,20 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
panic("RegisterNodeForTest can only be called during tests")
}
log.Debug().
logEvent := log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Username()).
Msg("Registering test node")
Str("node_key", node.NodeKey.ShortString())
if node.User != nil {
logEvent = logEvent.Str("user", node.User.Username())
} else if node.UserID != nil {
logEvent = logEvent.Uint("user_id", *node.UserID)
} else {
logEvent = logEvent.Str("user", "none")
}
logEvent.Msg("Registering test node")
// If the a new node is registered with the same machine key, to the same user,
// update the existing node.
@@ -642,7 +652,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
}
// Create a preauth key for the node
pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := hsdb.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
if err != nil {
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
}
@@ -656,7 +666,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: nodeName,
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}

View File

@@ -83,7 +83,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode")
@@ -97,7 +97,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
Expiry: &time.Time{},
@@ -124,7 +124,7 @@ func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode")
@@ -138,7 +138,7 @@ func (s *Suite) TestSetTags(c *check.C) {
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
@@ -152,7 +152,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
c.Assert(node.Tags, check.DeepEquals, sTags)
// assign duplicate tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
@@ -161,17 +161,10 @@ func (s *Suite) TestSetTags(c *check.C) {
node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(
node.ForcedTags,
node.Tags,
check.DeepEquals,
[]string{"tag:bar", "tag:test", "tag:unknown"},
)
// test removing tags
err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
}
func TestHeadscale_generateGivenName(t *testing.T) {
@@ -430,7 +423,7 @@ func TestAutoApproveRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.routes,
@@ -446,12 +439,12 @@ func TestAutoApproveRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "taggednode",
UserID: taggedUser.ID,
UserID: &taggedUser.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.routes,
},
ForcedTags: []string{"tag:exit"},
Tags: []string{"tag:exit"},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
@@ -593,10 +586,10 @@ func TestListEphemeralNodes(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
require.NoError(t, err)
node := types.Node{
@@ -604,7 +597,7 @@ func TestListEphemeralNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
@@ -614,7 +607,7 @@ func TestListEphemeralNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "ephemeral",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pakEph.ID),
}
@@ -657,7 +650,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -667,7 +660,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -680,7 +673,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "我的电脑",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
@@ -688,7 +681,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "a",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
@@ -808,7 +801,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -931,7 +924,7 @@ func TestListPeers(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test1",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -941,7 +934,7 @@ func TestListPeers(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test2",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -1016,7 +1009,7 @@ func TestListNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test1",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -1026,7 +1019,7 @@ func TestListNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test2",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}

View File

@@ -15,15 +15,15 @@ import (
)
var (
ErrPreAuthKeyNotFound = errors.New("AuthKey not found")
ErrPreAuthKeyExpired = errors.New("AuthKey expired")
ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")
ErrPreAuthKeyNotFound = errors.New("auth-key not found")
ErrPreAuthKeyExpired = errors.New("auth-key expired")
ErrSingleUseAuthKeyHasBeenUsed = errors.New("auth-key has already been used")
ErrUserMismatch = errors.New("user mismatch")
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
ErrPreAuthKeyACLTagInvalid = errors.New("auth-key tag is invalid")
)
func (hsdb *HSDatabase) CreatePreAuthKey(
uid types.UserID,
uid *types.UserID,
reusable bool,
ephemeral bool,
expiration *time.Time,
@@ -41,17 +41,40 @@ const (
)
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
// The uid parameter can be nil for system-created tagged keys.
// For tagged keys, uid tracks "created by" (who created the key).
// For user-owned keys, uid tracks the node owner.
func CreatePreAuthKey(
tx *gorm.DB,
uid types.UserID,
uid *types.UserID,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKeyNew, error) {
user, err := GetUserByID(tx, uid)
if err != nil {
return nil, err
// Validate: must be tagged OR user-owned, not neither
if uid == nil && len(aclTags) == 0 {
return nil, ErrPreAuthKeyNotTaggedOrOwned
}
// If uid != nil && len(aclTags) > 0:
// Both are allowed: UserID tracks "created by", tags define node ownership
// This is valid per the new model
var (
user *types.User
userID *uint
)
if uid != nil {
var err error
user, err = GetUserByID(tx, *uid)
if err != nil {
return nil, err
}
userID = &user.ID
}
// Remove duplicates and sort for consistency
@@ -108,15 +131,15 @@ func CreatePreAuthKey(
}
key := types.PreAuthKey{
UserID: user.ID,
User: *user,
UserID: userID, // nil for system-created keys, or "created by" for tagged keys
User: user, // nil for system-created keys
Reusable: reusable,
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
Tags: aclTags,
Prefix: prefix, // Store prefix
Hash: hash, // Store hash
Tags: aclTags, // empty for user-owned keys
Prefix: prefix, // Store prefix
Hash: hash, // Store hash
}
if err := tx.Save(&key).Error; err != nil {
@@ -149,14 +172,19 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
}
keys := []types.PreAuthKey{}
if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error
if err != nil {
return nil, err
}
return keys, nil
}
var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey")
var (
ErrPreAuthKeyFailedToParse = errors.New("failed to parse auth-key")
ErrPreAuthKeyNotTaggedOrOwned = errors.New("auth-key must be either tagged or owned by user")
)
func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) {
var pak types.PreAuthKey

View File

@@ -24,7 +24,7 @@ func TestCreatePreAuthKey(t *testing.T) {
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
assert.Error(t, err)
},
},
@@ -36,7 +36,7 @@ func TestCreatePreAuthKey(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
key, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key.Key)
@@ -83,7 +83,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test-tags-1"})
require.NoError(t, err)
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"badtag"})
assert.Error(t, err)
},
},
@@ -98,7 +98,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
expectedTags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate)
require.NoError(t, err)
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
@@ -128,13 +128,13 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test8"})
require.NoError(t, err)
key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"})
key, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:good"})
require.NoError(t, err)
node := types.Node{
ID: 0,
Hostname: "testest",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(key.ID),
}
@@ -180,7 +180,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
validateResult: func(t *testing.T, pak *types.PreAuthKey) {
t.Helper()
assert.Equal(t, user.ID, pak.UserID)
assert.Equal(t, user.ID, *pak.UserID)
assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated
assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix
assert.Nil(t, pak.Hash) // Legacy keys have nil Hash
@@ -191,7 +191,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
setupKey: func() string {
// Create new key via API
keyStr, err := db.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
true, false, nil, []string{"tag:test"},
)
require.NoError(t, err)
@@ -203,7 +203,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
validateResult: func(t *testing.T, pak *types.PreAuthKey) {
t.Helper()
assert.Equal(t, user.ID, pak.UserID)
assert.Equal(t, user.ID, *pak.UserID)
assert.Empty(t, pak.Key) // New keys have empty Key
assert.NotEmpty(t, pak.Prefix) // New keys have Prefix
assert.NotNil(t, pak.Hash) // New keys have Hash
@@ -214,7 +214,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
name: "new_key_format_validation",
setupKey: func() string {
keyStr, err := db.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
true, false, nil, nil,
)
require.NoError(t, err)
@@ -244,7 +244,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
setupKey: func() string {
// Create valid key
key, err := db.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
true, false, nil, nil,
)
require.NoError(t, err)
@@ -415,11 +415,11 @@ func TestMultipleLegacyKeysAllowed(t *testing.T) {
assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys")
// Now create new bcrypt-based keys - these should have unique prefixes
key1, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
key1, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key1.Key)
key2, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
key2, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key2.Key)

View File

@@ -81,7 +81,7 @@ CREATE TABLE nodes(
given_name varchar(63),
user_id integer,
register_method text,
forced_tags text,
tags text,
auth_key_id integer,
last_seen datetime,
expiry datetime,

View File

@@ -189,7 +189,11 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
// ListNodesByUser gets all the nodes in a given user.
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil {
uidPtr := uint(uid)
err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: &uidPtr}).Find(&nodes).Error
if err != nil {
return nil, err
}

View File

@@ -50,7 +50,7 @@ func TestDestroyUserErrors(t *testing.T) {
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
err = db.DestroyUser(types.UserID(user.ID))
@@ -71,13 +71,13 @@ func TestDestroyUserErrors(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}

View File

@@ -172,7 +172,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
}
preAuthKey, err := api.h.state.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
request.GetReusable(),
request.GetEphemeral(),
&expiration,
@@ -341,6 +341,17 @@ func (api headscaleV1APIServer) SetTags(
ctx context.Context,
request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) {
// Validate tags not empty - tagged nodes must have at least one tag
if len(request.GetTags()) == 0 {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(
codes.InvalidArgument,
"cannot remove all tags from a node - tagged nodes must have at least one tag",
)
}
// Validate tag format
for _, tag := range request.GetTags() {
err := validateTag(tag)
if err != nil {
@@ -348,6 +359,16 @@ func (api headscaleV1APIServer) SetTags(
}
}
// User XOR Tags: nodes are either tagged or user-owned, never both.
// Setting tags on a user-owned node converts it to a tagged node.
// Once tagged, a node cannot be converted back to user-owned.
_, found := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if !found {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.NotFound, "node not found")
}
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
@@ -529,13 +550,19 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N
for index, node := range nodes.All() {
resp := node.Proto()
// Tags-as-identity: tagged nodes show as TaggedDevices user in API responses
// (UserID may be set internally for "created by" tracking)
if node.IsTagged() {
resp.User = types.TaggedDevices.Proto()
}
var tags []string
for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
resp.ValidTags = lo.Uniq(append(tags, node.Tags().AsSlice()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp
@@ -780,7 +807,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: *user,
User: user,
Expiry: &time.Time{},
LastSeen: &time.Time{},

View File

@@ -1,6 +1,17 @@
package hscontrol
import "testing"
import (
"context"
"testing"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func Test_validateTag(t *testing.T) {
type args struct {
@@ -40,3 +51,212 @@ func Test_validateTag(t *testing.T) {
})
}
}
// TestSetTags_Conversion tests the conversion of user-owned nodes to tagged nodes.
// The tags-as-identity model allows one-way conversion from user-owned to tagged.
// Tag authorization is checked via the policy manager - unauthorized tags are rejected.
func TestSetTags_Conversion(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create test user and nodes
user := app.state.CreateUserForTest("test-user")
// Create a pre-auth key WITHOUT tags for user-owned node
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
// Register a user-owned node (via untagged PreAuthKey)
userOwnedReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "user-owned-node",
},
}
_, err = app.handleRegisterWithAuthKey(userOwnedReq, machineKey1.Public())
require.NoError(t, err)
// Get the created node
userOwnedNode, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
// Create API server instance
apiServer := newHeadscaleV1APIServer(app)
tests := []struct {
name string
nodeID uint64
tags []string
wantErr bool
wantCode codes.Code
wantErrMessage string
}{
{
// Conversion is allowed, but tag authorization fails without tagOwners
name: "reject unauthorized tags on user-owned node",
nodeID: uint64(userOwnedNode.ID()),
tags: []string{"tag:server"},
wantErr: true,
wantCode: codes.InvalidArgument,
wantErrMessage: "invalid or unauthorized tags",
},
{
// Conversion is allowed, but tag authorization fails without tagOwners
name: "reject multiple unauthorized tags",
nodeID: uint64(userOwnedNode.ID()),
tags: []string{"tag:server", "tag:database"},
wantErr: true,
wantCode: codes.InvalidArgument,
wantErrMessage: "invalid or unauthorized tags",
},
{
name: "reject non-existent node",
nodeID: 99999,
tags: []string{"tag:server"},
wantErr: true,
wantCode: codes.NotFound,
wantErrMessage: "node not found",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{
NodeId: tt.nodeID,
Tags: tt.tags,
})
if tt.wantErr {
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok, "error should be a gRPC status error")
assert.Equal(t, tt.wantCode, st.Code())
assert.Contains(t, st.Message(), tt.wantErrMessage)
assert.Nil(t, resp.GetNode())
} else {
require.NoError(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.GetNode())
}
})
}
}
// TestSetTags_TaggedNode tests that SetTags correctly identifies tagged nodes
// and doesn't reject them with the "user-owned nodes" error.
// Note: This test doesn't validate ACL tag authorization - that's tested elsewhere.
func TestSetTags_TaggedNode(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create test user and tagged pre-auth key
user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:initial"})
require.NoError(t, err)
machineKey := key.NewMachine()
nodeKey := key.NewNode()
// Register a tagged node (via tagged PreAuthKey)
taggedReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
}
_, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public())
require.NoError(t, err)
// Get the created node
taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, taggedNode.IsTagged(), "Node should be tagged")
assert.True(t, taggedNode.UserID().Valid(), "Tagged node should have UserID for tracking")
// Create API server instance
apiServer := newHeadscaleV1APIServer(app)
// Test: SetTags should NOT reject tagged nodes with "user-owned" error
// (Even though they have UserID set, IsTagged() identifies them correctly)
resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{
NodeId: uint64(taggedNode.ID()),
Tags: []string{"tag:initial"}, // Keep existing tag to avoid ACL validation issues
})
// The call should NOT fail with "cannot set tags on user-owned nodes"
if err != nil {
st, ok := status.FromError(err)
require.True(t, ok)
// If error is about unauthorized tags, that's fine - ACL validation is working
// If error is about user-owned nodes, that's the bug we're testing for
assert.NotContains(t, st.Message(), "user-owned nodes", "Should not reject tagged nodes as user-owned")
} else {
// Success is also fine
assert.NotNil(t, resp)
}
}
// TestSetTags_CannotRemoveAllTags tests that SetTags rejects attempts to remove
// all tags from a tagged node, enforcing Tailscale's requirement that tagged
// nodes must have at least one tag.
func TestSetTags_CannotRemoveAllTags(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create test user and tagged pre-auth key
user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:server"})
require.NoError(t, err)
machineKey := key.NewMachine()
nodeKey := key.NewNode()
// Register a tagged node
taggedReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
}
_, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public())
require.NoError(t, err)
// Get the created node
taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, taggedNode.IsTagged())
// Create API server instance
apiServer := newHeadscaleV1APIServer(app)
// Attempt to remove all tags (empty array)
resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{
NodeId: uint64(taggedNode.ID()),
Tags: []string{}, // Empty - attempting to remove all tags
})
// Should fail with InvalidArgument error
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok, "error should be a gRPC status error")
assert.Equal(t, codes.InvalidArgument, st.Code())
assert.Contains(t, st.Message(), "cannot remove all tags")
assert.Nil(t, resp.GetNode())
}

View File

@@ -73,15 +73,17 @@ func generateUserProfiles(
node types.NodeView,
peers views.Slice[types.NodeView],
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User)
userMap := make(map[uint]*types.UserView)
ids := make([]uint, 0, len(userMap))
user := node.User()
userMap[user.ID] = &user
ids = append(ids, user.ID)
userID := user.Model().ID
userMap[userID] = &user
ids = append(ids, userID)
for _, peer := range peers.All() {
peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
peerUserID := peerUser.Model().ID
userMap[peerUserID] = &peerUser
ids = append(ids, peerUserID)
}
slices.Sort(ids)

View File

@@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/ptr"
)
var iap = func(ipStr string) *netip.Addr {
@@ -50,8 +51,8 @@ func TestDNSConfigMapResponse(t *testing.T) {
mach := func(hostname, username string, userid uint) *types.Node {
return &types.Node{
Hostname: hostname,
UserID: userid,
User: types.User{
UserID: ptr.To(userid),
User: &types.User{
Name: username,
},
}

View File

@@ -83,7 +83,8 @@ func tailNode(
tags = append(tags, tag)
}
}
for _, tag := range node.ForcedTags().All() {
for _, tag := range node.Tags().All() {
tags = append(tags, tag)
}
tags = lo.Uniq(tags)
@@ -99,7 +100,7 @@ func tailNode(
Name: hostname,
Cap: capVer,
User: tailcfg.UserID(node.UserID()),
User: node.TailscaleUserID(),
Key: node.NodeKey(),
KeyExpiry: keyExpiry.UTC(),

View File

@@ -15,6 +15,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestTailNode(t *testing.T) {
@@ -97,14 +98,14 @@ func TestTailNode(t *testing.T) {
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: 0,
User: types.User{
UserID: ptr.To(uint(0)),
User: &types.User{
Name: "mini",
},
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Tags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),

View File

@@ -1,13 +1,10 @@
package hscontrol
import (
"os"
"path/filepath"
"testing"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCCallbackTemplate(t *testing.T) {
@@ -49,15 +46,6 @@ func TestOIDCCallbackTemplate(t *testing.T) {
assert.Contains(t, html, "<svg")
assert.Contains(t, html, "class=\"headscale-logo\"")
assert.Contains(t, html, "id=\"checkbox\"")
// Save the output for manual inspection
testDataDir := filepath.Join("testdata", "oidc_templates")
err := os.MkdirAll(testDataDir, 0o755)
require.NoError(t, err)
outputFile := filepath.Join(testDataDir, tt.name+".html")
err = os.WriteFile(outputFile, []byte(html), 0o600)
require.NoError(t, err)
})
}
}

View File

@@ -32,11 +32,11 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test-node",
UserID: user1.ID,
User: user1,
UserID: ptr.To(user1.ID),
User: ptr.To(user1),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ForcedTags: []string{"tag:test"},
Tags: []string{"tag:test"},
}
node2 := &types.Node{
@@ -44,8 +44,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "other-node",
UserID: user2.ID,
User: user2,
UserID: ptr.To(user2.ID),
User: ptr.To(user2),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
@@ -304,8 +304,8 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,

View File

@@ -168,15 +168,15 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: tt.nodeHostname,
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
ForcedTags: tt.nodeTags,
Tags: tt.nodeTags,
}
nodes := types.Nodes{&node}
@@ -294,8 +294,8 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
@@ -343,8 +343,8 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: announcedRoutes,

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var ap = func(ipStr string) *netip.Addr {
@@ -44,17 +45,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{
@@ -68,19 +69,19 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
},
want: types.Nodes{
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@@ -91,17 +92,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -115,14 +116,14 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
},
want: types.Nodes{
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
},
@@ -133,17 +134,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -157,14 +158,14 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: types.Nodes{
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@@ -175,17 +176,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -199,14 +200,14 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
},
want: types.Nodes{
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
},
@@ -217,17 +218,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -241,19 +242,19 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: types.Nodes{
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@@ -264,17 +265,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -288,19 +289,19 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: types.Nodes{
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@@ -311,17 +312,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -329,7 +330,7 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: nil,
@@ -347,28 +348,28 @@ func TestReduceNodes(t *testing.T) {
Hostname: "ts-head-upcrmb",
IPv4: ap("100.64.0.3"),
IPv6: ap("fd7a:115c:a1e0::3"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
&types.Node{
ID: 2,
Hostname: "ts-unstable-rlwpvr",
IPv4: ap("100.64.0.4"),
IPv6: ap("fd7a:115c:a1e0::4"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
&types.Node{
ID: 3,
Hostname: "ts-head-8w6paa",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
&types.Node{
ID: 4,
Hostname: "ts-unstable-lys2ib",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@@ -390,7 +391,7 @@ func TestReduceNodes(t *testing.T) {
Hostname: "ts-head-8w6paa",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
want: types.Nodes{
@@ -399,14 +400,14 @@ func TestReduceNodes(t *testing.T) {
Hostname: "ts-head-upcrmb",
IPv4: ap("100.64.0.3"),
IPv6: ap("fd7a:115c:a1e0::3"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
&types.Node{
ID: 2,
Hostname: "ts-unstable-rlwpvr",
IPv4: ap("100.64.0.4"),
IPv6: ap("fd7a:115c:a1e0::4"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
},
},
@@ -418,13 +419,13 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "peer1",
User: types.User{Name: "mini"},
User: &types.User{Name: "mini"},
},
{
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "peer2",
User: types.User{Name: "peer2"},
User: &types.User{Name: "peer2"},
},
},
rules: []tailcfg.FilterRule{
@@ -440,7 +441,7 @@ func TestReduceNodes(t *testing.T) {
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "mini",
User: types.User{Name: "mini"},
User: &types.User{Name: "mini"},
},
},
want: []*types.Node{
@@ -448,7 +449,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "peer2",
User: types.User{Name: "peer2"},
User: &types.User{Name: "peer2"},
},
},
},
@@ -460,19 +461,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "user1-1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
rules: []tailcfg.FilterRule{
@@ -509,7 +510,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "user-2-1",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
want: []*types.Node{
@@ -517,19 +518,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "user1-1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
},
@@ -541,19 +542,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "user-2-1",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
rules: []tailcfg.FilterRule{
@@ -590,7 +591,7 @@ func TestReduceNodes(t *testing.T) {
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "user1-1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
},
want: []*types.Node{
@@ -598,19 +599,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "user-2-1",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
},
@@ -622,13 +623,13 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "user1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
@@ -649,7 +650,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "user1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
},
want: []*types.Node{
@@ -657,7 +658,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
@@ -673,7 +674,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@@ -683,7 +684,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
rules: []tailcfg.FilterRule{
@@ -700,7 +701,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@@ -712,7 +713,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
},
@@ -724,7 +725,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@@ -734,7 +735,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
rules: []tailcfg.FilterRule{
@@ -751,7 +752,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
want: []*types.Node{
@@ -759,7 +760,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@@ -804,7 +805,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
ID: id,
IPv4: ap(ip),
Hostname: hostname,
User: types.User{Name: username},
User: &types.User{Name: username},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: routes,
},
@@ -812,8 +813,6 @@ func TestReduceNodesFromPolicy(t *testing.T) {
}
}
type args struct {
}
tests := []struct {
name string
nodes types.Nodes
@@ -1075,22 +1074,22 @@ func TestSSHPolicyRules(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: ap("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(uint(1)),
User: ptr.To(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: ap("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
}
taggedClient := types.Node{
Hostname: "tagged-client",
IPv4: ap("100.64.0.4"),
UserID: 2,
User: users[1],
ForcedTags: []string{"tag:client"},
Hostname: "tagged-client",
IPv4: ap("100.64.0.4"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
Tags: []string{"tag:client"},
}
tests := []struct {
@@ -1447,7 +1446,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@@ -1475,7 +1474,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@@ -1501,7 +1500,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@@ -1529,7 +1528,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@@ -1556,7 +1555,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@@ -1581,7 +1580,7 @@ func TestReduceRoutes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@@ -1614,7 +1613,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // Node IP
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1646,7 +1645,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1673,7 +1672,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1701,7 +1700,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1739,7 +1738,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // node with IP 100.64.0.2
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1774,7 +1773,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"), // router with IP 100.64.0.1
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1816,7 +1815,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // node
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1850,7 +1849,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // node
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@@ -1887,7 +1886,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.123.45.89"), // Node B - regular node
User: types.User{Name: "node-b"},
User: &types.User{Name: "node-b"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet connected to Node A
@@ -1917,7 +1916,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.123.45.67"), // Node A - router node
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet connected to this router
@@ -1946,7 +1945,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.123.45.89"), // Node B - regular node that should be reachable
User: types.User{Name: "node-b"},
User: &types.User{Name: "node-b"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet behind router
@@ -1984,7 +1983,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 3,
IPv4: ap("100.123.45.99"), // Node C - isolated node
User: types.User{Name: "isolated-node"},
User: &types.User{Name: "isolated-node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet behind router
@@ -2027,7 +2026,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.123.45.89"), // Node B - regular node
User: types.User{Name: "node-b"},
User: &types.User{Name: "node-b"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/14"), // Network 192.168.1.0/14 as mentioned in original issue

View File

@@ -16,6 +16,7 @@ import (
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
"tailscale.com/util/must"
)
@@ -143,13 +144,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: users[0],
User: ptr.To(users[0]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: users[0],
User: ptr.To(users[0]),
},
},
want: []tailcfg.FilterRule{},
@@ -190,7 +191,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
@@ -201,7 +202,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@@ -282,19 +283,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@@ -343,7 +344,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@@ -352,12 +353,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@@ -452,7 +453,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@@ -461,12 +462,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@@ -564,7 +565,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
@@ -573,12 +574,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@@ -654,7 +655,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
@@ -663,12 +664,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@@ -736,17 +737,17 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
Tags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@@ -803,13 +804,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[3],
User: ptr.To(users[3]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
func TestNodeCanApproveRoute(t *testing.T) {
@@ -24,34 +25,34 @@ func TestNodeCanApproveRoute(t *testing.T) {
ID: 1,
Hostname: "user1-device",
IPv4: ap("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(uint(1)),
User: ptr.To(users[0]),
}
exitNode := types.Node{
ID: 2,
Hostname: "user2-device",
IPv4: ap("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
}
taggedNode := types.Node{
ID: 3,
Hostname: "tagged-server",
IPv4: ap("100.64.0.3"),
UserID: 3,
User: users[2],
ForcedTags: []string{"tag:router"},
ID: 3,
Hostname: "tagged-server",
IPv4: ap("100.64.0.3"),
UserID: ptr.To(uint(3)),
User: ptr.To(users[2]),
Tags: []string{"tag:router"},
}
multiTagNode := types.Node{
ID: 4,
Hostname: "multi-tag-node",
IPv4: ap("100.64.0.4"),
UserID: 2,
User: users[1],
ForcedTags: []string{"tag:router", "tag:server"},
ID: 4,
Hostname: "multi-tag-node",
IPv4: ap("100.64.0.4"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
Tags: []string{"tag:router", "tag:server"},
}
tests := []struct {

View File

@@ -168,7 +168,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID == node.User().ID && !n.IsTagged() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
}
}
@@ -349,7 +349,7 @@ func (pol *Policy) compileSSHPolicy(
// Build destination set for autogroup:self (same-user untagged devices only)
var dest netipx.IPSetBuilder
for _, n := range nodes.All() {
if n.User().ID == node.User().ID && !n.IsTagged() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
n.AppendToIPSet(&dest)
}
}
@@ -365,7 +365,7 @@ func (pol *Policy) compileSSHPolicy(
// Pre-filter to same-user untagged devices for efficiency
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID == node.User().ID && !n.IsTagged() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
}
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
@@ -381,7 +382,7 @@ func TestParsing(t *testing.T) {
},
&types.Node{
IPv4: ap("200.200.200.200"),
User: users[0],
User: &users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
}.ViewSlice())
@@ -409,14 +410,14 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
}
nodes := types.Nodes{&nodeUser1, &nodeUser2}
@@ -621,14 +622,14 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
}
nodes := types.Nodes{&nodeUser1, &nodeUser2}
@@ -682,15 +683,15 @@ func TestSSHIntegrationReproduction(t *testing.T) {
node1 := &types.Node{
Hostname: "user1-node",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
}
node2 := &types.Node{
Hostname: "user2-node",
IPv4: createAddr("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
}
nodes := types.Nodes{node1, node2}
@@ -741,11 +742,12 @@ func TestSSHJSONSerialization(t *testing.T) {
{Name: "user1", Model: gorm.Model{ID: 1}},
}
uid := uint(1)
node := &types.Node{
Hostname: "test-node",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: &uid,
User: &users[0],
}
nodes := types.Nodes{node}
@@ -804,32 +806,32 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
nodes := types.Nodes{
{
User: users[0],
User: ptr.To(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: users[0],
User: ptr.To(users[0]),
IPv4: ap("100.64.0.2"),
},
{
User: users[1],
User: ptr.To(users[1]),
IPv4: ap("100.64.0.3"),
},
{
User: users[1],
User: ptr.To(users[1]),
IPv4: ap("100.64.0.4"),
},
// Tagged device for user1
{
User: users[0],
IPv4: ap("100.64.0.5"),
ForcedTags: []string{"tag:test"},
User: &users[0],
IPv4: ap("100.64.0.5"),
Tags: []string{"tag:test"},
},
// Tagged device for user2
{
User: users[1],
IPv4: ap("100.64.0.6"),
ForcedTags: []string{"tag:test"},
User: &users[1],
IPv4: ap("100.64.0.6"),
Tags: []string{"tag:test"},
},
}
@@ -925,6 +927,251 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
}
}
// TestTagUserMutualExclusivity tests that user-owned nodes and tagged nodes
// are treated as separate identity classes and cannot inadvertently access each other.
func TestTagUserMutualExclusivity(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
}
nodes := types.Nodes{
// User-owned nodes
{
User: ptr.To(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[1]),
IPv4: ap("100.64.0.2"),
},
// Tagged nodes
{
User: &users[0], // "created by" tracking
IPv4: ap("100.64.0.10"),
Tags: []string{"tag:server"},
},
{
User: &users[1], // "created by" tracking
IPv4: ap("100.64.0.11"),
Tags: []string{"tag:database"},
},
}
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
},
ACLs: []ACL{
// Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes
{
Action: "accept",
Sources: []Alias{up("user1@")},
Destinations: []AliasWithPorts{
aliasWithPorts(tp("tag:server"), tailcfg.PortRangeAny),
},
},
// Rule 2: tag:server should be able to reach tag:database
{
Action: "accept",
Sources: []Alias{tp("tag:server")},
Destinations: []AliasWithPorts{
aliasWithPorts(tp("tag:database"), tailcfg.PortRangeAny),
},
},
},
}
err := policy.validate()
if err != nil {
t.Fatalf("policy validation failed: %v", err)
}
// Test user1's user-owned node (100.64.0.1)
userNode := nodes[0].View()
userRules, err := policy.compileFilterRulesForNode(users, userNode, nodes.ViewSlice())
if err != nil {
t.Fatalf("unexpected error for user node: %v", err)
}
// User1's user-owned node should NOT reach tag:server (100.64.0.10)
// because user1@ as a source only matches user1's user-owned devices, NOT tagged devices
for _, rule := range userRules {
for _, dst := range rule.DstPorts {
if dst.IP == "100.64.0.10" {
t.Errorf("SECURITY: user-owned node should NOT reach tagged node (got dest %s in rule)", dst.IP)
}
}
}
// Test tag:server node (100.64.0.10)
// compileFilterRulesForNode returns rules for what the node can ACCESS (as source)
taggedNode := nodes[2].View()
taggedRules, err := policy.compileFilterRulesForNode(users, taggedNode, nodes.ViewSlice())
if err != nil {
t.Fatalf("unexpected error for tagged node: %v", err)
}
// Tag:server (as source) should be able to reach tag:database (100.64.0.11)
// Check destinations in the rules for this node
foundDatabaseDest := false
for _, rule := range taggedRules {
// Check if this rule applies to tag:server as source
if !slices.Contains(rule.SrcIPs, "100.64.0.10/32") {
continue
}
// Check if tag:database is in destinations
for _, dst := range rule.DstPorts {
if dst.IP == "100.64.0.11/32" {
foundDatabaseDest = true
break
}
}
if foundDatabaseDest {
break
}
}
if !foundDatabaseDest {
t.Errorf("tag:server should reach tag:database but didn't find 100.64.0.11 in destinations")
}
}
// TestAutogroupTagged tests that autogroup:tagged correctly selects all devices
// with tag-based identity (IsTagged() == true or has requested tags in tagOwners).
func TestAutogroupTagged(t *testing.T) {
t.Parallel()
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
}
nodes := types.Nodes{
// User-owned nodes (not tagged)
{
User: ptr.To(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[1]),
IPv4: ap("100.64.0.2"),
},
// Tagged nodes
{
User: &users[0], // "created by" tracking
IPv4: ap("100.64.0.10"),
Tags: []string{"tag:server"},
},
{
User: &users[1], // "created by" tracking
IPv4: ap("100.64.0.11"),
Tags: []string{"tag:database"},
},
{
User: &users[0],
IPv4: ap("100.64.0.12"),
Tags: []string{"tag:web", "tag:prod"},
},
}
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
},
ACLs: []ACL{
// Rule: autogroup:tagged can reach user-owned nodes
{
Action: "accept",
Sources: []Alias{agp("autogroup:tagged")},
Destinations: []AliasWithPorts{
aliasWithPorts(up("user1@"), tailcfg.PortRangeAny),
aliasWithPorts(up("user2@"), tailcfg.PortRangeAny),
},
},
},
}
err := policy.validate()
require.NoError(t, err)
// Verify autogroup:tagged includes all tagged nodes
taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, taggedIPs)
// Should contain all tagged nodes
assert.True(t, taggedIPs.Contains(*ap("100.64.0.10")), "should include tag:server")
assert.True(t, taggedIPs.Contains(*ap("100.64.0.11")), "should include tag:database")
assert.True(t, taggedIPs.Contains(*ap("100.64.0.12")), "should include tag:web,tag:prod")
// Should NOT contain user-owned nodes
assert.False(t, taggedIPs.Contains(*ap("100.64.0.1")), "should not include user1 node")
assert.False(t, taggedIPs.Contains(*ap("100.64.0.2")), "should not include user2 node")
// Test ACL filtering: all tagged nodes should be able to reach user nodes
tests := []struct {
name string
sourceNode types.NodeView
shouldReach []string // IP strings for comparison
}{
{
name: "tag:server can reach user-owned nodes",
sourceNode: nodes[2].View(),
shouldReach: []string{"100.64.0.1", "100.64.0.2"},
},
{
name: "tag:database can reach user-owned nodes",
sourceNode: nodes[3].View(),
shouldReach: []string{"100.64.0.1", "100.64.0.2"},
},
{
name: "tag:web,tag:prod can reach user-owned nodes",
sourceNode: nodes[4].View(),
shouldReach: []string{"100.64.0.1", "100.64.0.2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
rules, err := policy.compileFilterRulesForNode(users, tt.sourceNode, nodes.ViewSlice())
require.NoError(t, err)
// Verify all expected destinations are reachable
for _, expectedDest := range tt.shouldReach {
found := false
for _, rule := range rules {
for _, dstPort := range rule.DstPorts {
// DstPort.IP is CIDR notation like "100.64.0.1/32"
if strings.HasPrefix(dstPort.IP, expectedDest+"/") || dstPort.IP == expectedDest {
found = true
break
}
}
if found {
break
}
}
assert.True(t, found, "Expected to find destination %s in rules", expectedDest)
}
})
}
}
func TestAutogroupSelfInSourceIsRejected(t *testing.T) {
// Test that autogroup:self cannot be used in sources (per Tailscale spec)
policy := &Policy{
@@ -959,10 +1206,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
}
policy := &Policy{
@@ -1026,11 +1273,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: users[2], IPv4: ap("100.64.0.5")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
}
policy := &Policy{
@@ -1095,13 +1342,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
nodes := types.Nodes{
// User1's nodes
{User: users[0], IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
{User: users[0], IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
// User2's nodes
{User: users[1], IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
{User: users[1], IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
// Tagged node for user1 (should be excluded)
{User: users[0], IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", ForcedTags: []string{"tag:server"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
}
policy := &Policy{
@@ -1173,10 +1420,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
}
policy := &Policy{
@@ -1227,11 +1474,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: users[2], IPv4: ap("100.64.0.5")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
}
policy := &Policy{
@@ -1284,10 +1531,10 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
{User: users[0], IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
{User: users[0], IPv4: ap("100.64.0.3"), Hostname: "tagged1", ForcedTags: []string{"tag:server"}},
{User: users[0], IPv4: ap("100.64.0.4"), Hostname: "tagged2", ForcedTags: []string{"tag:web"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
}
policy := &Policy{
@@ -1344,10 +1591,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
{User: users[0], IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
{User: users[1], IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
{User: users[1], IPv4: ap("100.64.0.4"), Hostname: "user2-router", ForcedTags: []string{"tag:router"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
}
policy := &Policy{

View File

@@ -697,14 +697,14 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// Check for removed nodes
for nodeID, oldNode := range oldNodeMap {
if _, exists := newNodeMap[nodeID]; !exists {
affectedUsers[oldNode.User().ID] = struct{}{}
affectedUsers[oldNode.User().ID()] = struct{}{}
}
}
// Check for added nodes
for nodeID, newNode := range newNodeMap {
if _, exists := oldNodeMap[nodeID]; !exists {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
}
}
@@ -712,26 +712,26 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
for nodeID, newNode := range newNodeMap {
if oldNode, exists := oldNodeMap[nodeID]; exists {
// Check if user changed
if oldNode.User().ID != newNode.User().ID {
affectedUsers[oldNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID] = struct{}{}
if oldNode.User().ID() != newNode.User().ID() {
affectedUsers[oldNode.User().ID()] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
}
// Check if tag status changed
if oldNode.IsTagged() != newNode.IsTagged() {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
}
// Check if IPs changed (simple check - could be more sophisticated)
oldIPs := oldNode.IPs()
newIPs := newNode.IPs()
if len(oldIPs) != len(newIPs) {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
} else {
// Check if any IPs are different
for i, oldIP := range oldIPs {
if i >= len(newIPs) || oldIP != newIPs[i] {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
break
}
}
@@ -750,7 +750,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// Check in new nodes first
for _, node := range newNodes.All() {
if node.ID() == nodeID {
nodeUserID = node.User().ID
nodeUserID = node.User().ID()
found = true
break
}
@@ -760,7 +760,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
if !found {
for _, node := range oldNodes.All() {
if node.ID() == nodeID {
nodeUserID = node.User().ID
nodeUserID = node.User().ID()
found = true
break
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
@@ -19,8 +20,8 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo)
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
User: ptr.To(user),
UserID: ptr.To(user.ID),
Hostinfo: hostinfo,
}
}
@@ -456,8 +457,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[0],
UserID: users[0].ID,
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@@ -467,9 +468,9 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
UserID: users[1].ID,
ForcedTags: []string{"tag:node-router"},
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
}

View File

@@ -206,7 +206,12 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
continue
}
if node.User().ID == user.ID {
// Skip nodes without a user (defensive check for tests)
if !node.User().Valid() {
continue
}
if node.User().ID() == user.ID {
node.AppendToIPSet(&ips)
}
}
@@ -311,8 +316,8 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV
}
for _, node := range nodes.All() {
// Check if node has this tag in all tags (ForcedTags + AuthKey.Tags)
if slices.Contains(node.Tags(), string(t)) {
// Check if node has this tag
if node.HasTag(string(t)) {
node.AppendToIPSet(&ips)
}

View File

@@ -1549,7 +1549,17 @@ func TestResolvePolicy(t *testing.T) {
"groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"},
"groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"},
"notme": {Model: gorm.Model{ID: 5}, Name: "notme"},
"testuser2": {Model: gorm.Model{ID: 6}, Name: "testuser2"},
}
// Extract users to variables so we can take their addresses
testuser := users["testuser"]
groupuser := users["groupuser"]
groupuser1 := users["groupuser1"]
groupuser2 := users["groupuser2"]
notme := users["notme"]
testuser2 := users["testuser2"]
tests := []struct {
name string
nodes types.Nodes
@@ -1579,29 +1589,27 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: users["notme"],
User: ptr.To(notme),
IPv4: ap("100.100.101.1"),
},
// Not matching forced tags
{
User: users["testuser"],
ForcedTags: []string{"tag:anything"},
User: ptr.To(testuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
},
// not matching pak tag
// not matching because it's tagged (tags copied from AuthKey)
{
User: users["testuser"],
AuthKey: &types.PreAuthKey{
Tags: []string{"alsotagged"},
},
User: ptr.To(testuser),
Tags: []string{"alsotagged"},
IPv4: ap("100.100.101.3"),
},
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.103"),
},
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.104"),
},
},
@@ -1613,29 +1621,27 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: users["notme"],
User: ptr.To(notme),
IPv4: ap("100.100.101.4"),
},
// Not matching forced tags
{
User: users["groupuser"],
ForcedTags: []string{"tag:anything"},
User: ptr.To(groupuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
},
// not matching pak tag
// not matching because it's tagged (tags copied from AuthKey)
{
User: users["groupuser"],
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
User: ptr.To(groupuser),
Tags: []string{"tag:alsotagged"},
IPv4: ap("100.100.101.6"),
},
{
User: users["groupuser"],
User: ptr.To(groupuser),
IPv4: ap("100.100.101.203"),
},
{
User: users["groupuser"],
User: ptr.To(groupuser),
IPv4: ap("100.100.101.204"),
},
},
@@ -1653,12 +1659,12 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: users["notme"],
User: ptr.To(notme),
IPv4: ap("100.100.101.9"),
},
// Not matching forced tags
{
ForcedTags: []string{"tag:anything"},
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.10"),
},
// not matching pak tag
@@ -1670,14 +1676,12 @@ func TestResolvePolicy(t *testing.T) {
},
// Not matching forced tags
{
ForcedTags: []string{"tag:test"},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
},
// not matching pak tag
// matching tag (tags copied from AuthKey during registration)
{
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:test"},
},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.239"),
},
},
@@ -1706,11 +1710,11 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(Group("group:testgroup")),
nodes: types.Nodes{
{
User: users["groupuser1"],
User: ptr.To(groupuser1),
IPv4: ap("100.100.101.203"),
},
{
User: users["groupuser2"],
User: ptr.To(groupuser2),
IPv4: ap("100.100.101.204"),
},
},
@@ -1731,7 +1735,7 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(Username("invaliduser@")),
nodes: types.Nodes{
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.103"),
},
},
@@ -1742,7 +1746,7 @@ func TestResolvePolicy(t *testing.T) {
toResolve: tp("tag:invalid"),
nodes: types.Nodes{
{
ForcedTags: []string{"tag:test"},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
},
},
@@ -1763,18 +1767,18 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Node with no tags (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
// Node with forced tags (should be excluded)
{
User: users["testuser"],
ForcedTags: []string{"tag:test"},
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@@ -1782,7 +1786,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with non-allowed requested tag (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
@@ -1790,7 +1794,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, one allowed (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
@@ -1798,7 +1802,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, none allowed (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
@@ -1822,18 +1826,18 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Node with no tags (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
// Node with forced tag (should be included)
{
User: users["testuser"],
ForcedTags: []string{"tag:test"},
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@@ -1841,7 +1845,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with non-allowed requested tag (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
@@ -1849,7 +1853,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, one allowed (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
@@ -1857,7 +1861,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, none allowed (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
@@ -1865,8 +1869,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple forced tags (should be included)
{
User: users["testuser"],
ForcedTags: []string{"tag:test", "tag:other"},
User: ptr.To(testuser),
Tags: []string{"tag:test", "tag:other"},
IPv4: ap("100.100.101.7"),
},
},
@@ -1886,20 +1890,20 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(AutoGroupSelf),
nodes: types.Nodes{
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
{
User: users["testuser2"],
User: ptr.To(testuser2),
IPv4: ap("100.100.101.2"),
},
{
User: users["testuser"],
ForcedTags: []string{"tag:test"},
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.3"),
},
{
User: users["testuser2"],
User: ptr.To(testuser2),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@@ -1961,23 +1965,23 @@ func TestResolveAutoApprovers(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
{
IPv4: ap("100.64.0.4"),
ForcedTags: []string{"tag:testtag"},
Tags: []string{"tag:testtag"},
},
{
IPv4: ap("100.64.0.5"),
ForcedTags: []string{"tag:exittest"},
Tags: []string{"tag:exittest"},
},
}
@@ -2280,15 +2284,15 @@ func TestNodeCanApproveRoute(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
}
@@ -2413,15 +2417,15 @@ func TestResolveTagOwners(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
}
@@ -2498,15 +2502,15 @@ func TestNodeCanHaveTag(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
}
@@ -2580,6 +2584,49 @@ func TestNodeCanHaveTag(t *testing.T) {
tag: "tag:test",
want: false,
},
{
name: "node-with-unauthorized-tag-different-user",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
},
},
node: nodes[2], // user3's node
tag: "tag:prod",
want: false,
},
{
name: "node-with-multiple-tags-one-unauthorized",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
},
},
node: nodes[0], // user1's node
tag: "tag:database",
want: false, // user1 cannot have tag:database (owned by user2)
},
{
name: "empty-tagowners-map",
policy: &Policy{
TagOwners: TagOwners{},
},
node: nodes[0],
tag: "tag:test",
want: false, // No one can have tags if tagOwners is empty
},
{
name: "tag-not-in-tagowners",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
},
},
node: nodes[0],
tag: "tag:dev", // This tag is not defined in tagOwners
want: false,
},
}
for _, tt := range tests {

View File

@@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"tailscale.com/tailcfg"
@@ -42,11 +43,6 @@ type mapSession struct {
node *types.Node
w http.ResponseWriter
warnf func(string, ...any)
infof func(string, ...any)
tracef func(string, ...any)
errf func(error, string, ...any)
}
func (h *Headscale) newMapSession(
@@ -55,8 +51,6 @@ func (h *Headscale) newMapSession(
w http.ResponseWriter,
node *types.Node,
) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node)
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
return &mapSession{
@@ -73,12 +67,6 @@ func (h *Headscale) newMapSession(
keepAlive: ka,
keepAliveTicker: nil,
// Loggers
warnf: warnf,
infof: infof,
tracef: tracef,
errf: errf,
}
}
@@ -295,6 +283,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
}
data := make([]byte, reservedResponseHeaderSize)
//nolint:gosec // G115: JSON response size will not exceed uint32 max
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
data = append(data, jsonBody...)
@@ -365,45 +354,22 @@ func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcf
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
}
func logPollFunc(
mapRequest tailcfg.MapRequest,
node *types.Node,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Info().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Err(err).
Msgf(msg, a...)
}
// logf adds common mapSession context to a zerolog event.
func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event {
return event.
Bool("omitPeers", m.req.OmitPeers).
Bool("stream", m.req.Stream).
Uint64("node.id", m.node.ID.Uint64()).
Str("node.name", m.node.Hostname)
}
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) }
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) }
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) errf(err error, msg string, a ...any) {
m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...)
}

View File

@@ -78,7 +78,7 @@ func (s *State) DebugOverview() string {
now := time.Now()
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.User().Name
userName := node.User().Name()
userNodeCounts[userName]++
if node.IsOnline().Valid() && node.IsOnline().Get() {
@@ -281,7 +281,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.User().Name
userName := node.User().Name()
info.Users[userName]++
if node.IsOnline().Valid() && node.IsOnline().Get() {

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestNetInfoFromMapRequest(t *testing.T) {
@@ -148,8 +149,8 @@ func createTestNodeSimple(id types.NodeID) *types.Node {
node := &types.Node{
ID: id,
Hostname: "test-node",
UserID: uint(id),
User: user,
UserID: ptr.To(uint(id)),
User: &user,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPv4: &netip.Addr{},

View File

@@ -408,7 +408,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
// Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps
for _, n := range nodes {
nodeView := n.View()
userID := types.UserID(n.UserID)
userID := n.TypedUserID()
newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView)
newSnap.nodesByNodeKey[n.NodeKey] = nodeView
@@ -515,7 +515,7 @@ func (s *NodeStore) DebugString() string {
if len(nodes) > 0 {
userName := "unknown"
if len(nodes) > 0 && nodes[0].Valid() {
userName = nodes[0].User().Name
userName = nodes[0].User().Name()
}
sb.WriteString(fmt.Sprintf(" - User %d (%s): %d nodes\n", userID, userName, len(nodes)))
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestSnapshotFromNodes(t *testing.T) {
@@ -173,8 +174,8 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string)
DiscoKey: discoKey.Public(),
Hostname: hostname,
GivenName: hostname,
UserID: userID,
User: types.User{
UserID: ptr.To(userID),
User: &types.User{
Name: username,
DisplayName: username,
},
@@ -627,7 +628,7 @@ func TestNodeStoreOperations(t *testing.T) {
go func() {
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"tag1", "tag2"}
n.Tags = []string{"tag1", "tag2"}
})
close(done3)
}()
@@ -648,24 +649,24 @@ func TestNodeStoreOperations(t *testing.T) {
// resultNode1 (from hostname update) should also have the givenname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode1.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode1.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.Tags().AsSlice())
// resultNode2 (from givenname update) should also have the hostname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode2.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode2.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.Tags().AsSlice())
// resultNode3 (from tags update) should also have the hostname and givenname changes
assert.Equal(t, "multi-update-hostname", resultNode3.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode3.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.Tags().AsSlice())
// Verify the snapshot also has all changes
snapshot := store.data.Load()
finalNode := snapshot.nodesByID[1]
assert.Equal(t, "multi-update-hostname", finalNode.Hostname)
assert.Equal(t, "multi-update-givenname", finalNode.GivenName)
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags)
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.Tags)
},
},
},
@@ -687,7 +688,7 @@ func TestNodeStoreOperations(t *testing.T) {
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "db-save-hostname"
n.GivenName = "db-save-given"
n.ForcedTags = []string{"db-tag1", "db-tag2"}
n.Tags = []string{"db-tag1", "db-tag2"}
})
assert.True(t, ok, "UpdateNode should succeed")
@@ -696,21 +697,21 @@ func TestNodeStoreOperations(t *testing.T) {
// Verify the returned node has all expected values
assert.Equal(t, "db-save-hostname", resultNode.Hostname())
assert.Equal(t, "db-save-given", resultNode.GivenName())
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice())
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.Tags().AsSlice())
// Convert to struct as would be done for database save
nodePtr := resultNode.AsStruct()
assert.NotNil(t, nodePtr)
assert.Equal(t, "db-save-hostname", nodePtr.Hostname)
assert.Equal(t, "db-save-given", nodePtr.GivenName)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.Tags)
// Verify the snapshot also reflects the same state
snapshot := store.data.Load()
storedNode := snapshot.nodesByID[1]
assert.Equal(t, "db-save-hostname", storedNode.Hostname)
assert.Equal(t, "db-save-given", storedNode.GivenName)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.Tags)
},
},
{
@@ -742,7 +743,7 @@ func TestNodeStoreOperations(t *testing.T) {
go func() {
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"concurrent-tag"}
n.Tags = []string{"concurrent-tag"}
})
close(done3)
}()
@@ -767,22 +768,22 @@ func TestNodeStoreOperations(t *testing.T) {
// All should have the complete final state
assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.Tags)
assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.Tags)
assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.Tags)
// Verify consistency with stored state
snapshot := store.data.Load()
storedNode := snapshot.nodesByID[1]
assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname)
assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName)
assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags)
assert.Equal(t, nodePtr1.Tags, storedNode.Tags)
},
},
{
@@ -855,8 +856,8 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
Hostname: hostname,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
UserID: 1,
User: types.User{
UserID: ptr.To(uint(1)),
User: &types.User{
Name: "concurrent-test-user",
},
}

View File

@@ -53,6 +53,9 @@ const (
// ErrUnsupportedPolicyMode is returned for invalid policy modes. Valid modes are "file" and "db".
var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode")
// ErrNodeNotFound is returned when a node cannot be found by its ID.
var ErrNodeNotFound = errors.New("node not found")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@@ -651,13 +654,36 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
return s.persistNodeToDB(n)
}
// SetNodeTags assigns tags to a node for use in access control policies.
// SetNodeTags assigns tags to a node, making it a "tagged node".
// Once a node is tagged, it cannot be un-tagged (only tags can be changed).
// The UserID is preserved as "created by" information.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
// CANNOT REMOVE ALL TAGS
if len(tags) == 0 {
return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags
}
// Get node for validation
existingNode, exists := s.nodeStore.GetNode(nodeID)
if !exists {
return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID)
}
// Validate tags against policy
validatedTags, err := s.validateAndNormalizeTags(existingNode.AsStruct(), tags)
if err != nil {
return types.NodeView{}, change.EmptySet, err
}
// Log the operation
logTagOperation(existingNode, validatedTags)
// Update NodeStore before database to ensure consistency. The NodeStore update is
// blocking and will be the source of truth for the batcher. The database update must
// make the exact same change.
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
node.ForcedTags = tags
node.Tags = validatedTags
// UserID is preserved as "created by" - do NOT set to nil
})
if !ok {
@@ -927,7 +953,8 @@ func (s *State) DestroyAPIKey(key types.APIKey) error {
}
// CreatePreAuthKey generates a new pre-authentication key for a user.
func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) {
// The userID parameter is now optional (can be nil) for system-created tagged keys.
func (s *State) CreatePreAuthKey(userID *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) {
return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags)
}
@@ -1063,8 +1090,6 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
// Prepare the node for registration
nodeToRegister := types.Node{
Hostname: params.Hostname,
UserID: params.User.ID,
User: params.User,
MachineKey: params.MachineKey,
NodeKey: params.NodeKey,
DiscoKey: params.DiscoKey,
@@ -1075,11 +1100,38 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
Expiry: params.Expiry,
}
// Pre-auth key specific fields
// Assign ownership based on PreAuthKey
if params.PreAuthKey != nil {
nodeToRegister.ForcedTags = params.PreAuthKey.Proto().GetAclTags()
if params.PreAuthKey.IsTagged() {
// TAGGED NODE
// Tags from PreAuthKey are assigned ONLY during initial authentication
nodeToRegister.Tags = params.PreAuthKey.Proto().GetAclTags()
// Set UserID to track "created by" (who created the PreAuthKey)
if params.PreAuthKey.UserID != nil {
nodeToRegister.UserID = params.PreAuthKey.UserID
nodeToRegister.User = params.PreAuthKey.User
}
// If PreAuthKey.UserID is nil, the node is "orphaned" (system-created)
} else {
// USER-OWNED NODE
nodeToRegister.UserID = &params.PreAuthKey.User.ID
nodeToRegister.User = params.PreAuthKey.User
nodeToRegister.Tags = nil
}
nodeToRegister.AuthKey = params.PreAuthKey
nodeToRegister.AuthKeyID = &params.PreAuthKey.ID
} else {
// Non-PreAuthKey registration (OIDC, CLI) - always user-owned
nodeToRegister.UserID = &params.User.ID
nodeToRegister.User = &params.User
nodeToRegister.Tags = nil
}
// Validate before saving
err := validateNodeOwnership(&nodeToRegister)
if err != nil {
return types.NodeView{}, err
}
// Allocate new IPs
@@ -1156,7 +1208,7 @@ func (s *State) HandleNodeFromAuthPath(
logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
user.Username(),
user.Name,
hostname,
regEntry.Node.Hostinfo,
)
@@ -1171,7 +1223,7 @@ func (s *State) HandleNodeFromAuthPath(
log.Debug().
Caller().
Str("registration_id", registrationID.String()).
Str("user.name", user.Username()).
Str("user.name", user.Name).
Str("registrationMethod", registrationMethod).
Str("node.name", existingNodeSameUser.Hostname()).
Uint64("node.id", existingNodeSameUser.ID().Uint64()).
@@ -1233,7 +1285,7 @@ func (s *State) HandleNodeFromAuthPath(
// Check if node exists with this machine key for a different user (for netinfo preservation)
existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey)
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != user.ID {
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != user.ID {
// Node exists but belongs to a different user
// Create a NEW node for the new user (do not transfer)
// This allows the same machine to have separate node identities per user
@@ -1243,8 +1295,8 @@ func (s *State) HandleNodeFromAuthPath(
Str("existing.node.name", existingNodeAnyUser.Hostname()).
Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()).
Str("machine.key", regEntry.Node.MachineKey.ShortString()).
Str("old.user", oldUser.Username()).
Str("new.user", user.Username()).
Str("old.user", oldUser.Name()).
Str("new.user", user.Name).
Str("method", registrationMethod).
Msg("Creating new node for different user (same machine key exists for another user)")
}
@@ -1253,7 +1305,7 @@ func (s *State) HandleNodeFromAuthPath(
log.Debug().
Caller().
Str("registration_id", registrationID.String()).
Str("user.name", user.Username()).
Str("user.name", user.Name).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("Registering new node from auth callback")
@@ -1416,8 +1468,11 @@ func (s *State) HandleNodeFromPreAuthKey(
node.RegisterMethod = util.RegisterMethodAuthKey
// TODO(kradalby): This might need a rework as part of #2417
node.ForcedTags = pak.Proto().GetAclTags()
// CRITICAL: Tags from PreAuthKey are ONLY applied during initial authentication
// On re-registration, we MUST NOT change tags or node ownership
// The node keeps whatever tags/user ownership it already has
//
// Only update AuthKey reference
node.AuthKey = pak
node.AuthKeyID = &pak.ID
node.IsOnline = ptr.To(false)
@@ -1467,7 +1522,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Check if node exists with this machine key for a different user
existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != pak.User.ID {
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != pak.User.ID {
// Node exists but belongs to a different user
// Create a NEW node for the new user (do not transfer)
// This allows the same machine to have separate node identities per user
@@ -1477,7 +1532,7 @@ func (s *State) HandleNodeFromPreAuthKey(
Str("existing.node.name", existingNodeAnyUser.Hostname()).
Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()).
Str("machine.key", machineKey.ShortString()).
Str("old.user", oldUser.Username()).
Str("old.user", oldUser.Name()).
Str("new.user", pak.User.Username()).
Msg("Creating new node for different user (same machine key exists for another user)")
}
@@ -1488,7 +1543,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Create and save new node
var err error
finalNode, err = s.createAndSaveNewNode(newNodeParams{
User: pak.User,
User: *pak.User,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
DiscoKey: key.DiscoPublic{}, // DiscoKey not available in RegisterRequest

107
hscontrol/state/tags.go Normal file
View File

@@ -0,0 +1,107 @@
package state
import (
"errors"
"fmt"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
)
var (
// ErrNodeMarkedTaggedButHasNoTags is returned when a node is marked as tagged but has no tags.
ErrNodeMarkedTaggedButHasNoTags = errors.New("node marked as tagged but has no tags")
// ErrNodeHasNeitherUserNorTags is returned when a node has neither a user nor tags.
ErrNodeHasNeitherUserNorTags = errors.New("node has neither user nor tags - must be owned by user or tagged")
// ErrInvalidOrUnauthorizedTags is returned when tags are invalid or unauthorized.
ErrInvalidOrUnauthorizedTags = errors.New("invalid or unauthorized tags")
)
// validateNodeOwnership ensures proper node ownership model.
// A node must be EITHER user-owned OR tagged (mutually exclusive by behavior).
// Tagged nodes CAN have a UserID for "created by" tracking, but the tag is the owner.
func validateNodeOwnership(node *types.Node) error {
isTagged := node.IsTagged()
// Tagged nodes: Must have tags, UserID is optional (just "created by")
if isTagged {
if len(node.Tags) == 0 {
return fmt.Errorf("%w: %q", ErrNodeMarkedTaggedButHasNoTags, node.Hostname)
}
// UserID can be set (created by) or nil (orphaned), both valid for tagged nodes
return nil
}
// User-owned nodes: Must have UserID, must NOT have tags
if node.UserID == nil {
return fmt.Errorf("%w: %q", ErrNodeHasNeitherUserNorTags, node.Hostname)
}
return nil
}
// validateAndNormalizeTags validates tags against policy and normalizes them.
// Returns validated and normalized tags, or an error if validation fails.
func (s *State) validateAndNormalizeTags(node *types.Node, requestedTags []string) ([]string, error) {
if len(requestedTags) == 0 {
return nil, nil
}
var (
validTags []string
invalidTags []string
)
for _, tag := range requestedTags {
// Validate format
if !strings.HasPrefix(tag, "tag:") {
invalidTags = append(invalidTags, tag)
continue
}
// Validate against policy
nodeView := node.View()
if s.polMan.NodeCanHaveTag(nodeView, tag) {
validTags = append(validTags, tag)
} else {
invalidTags = append(invalidTags, tag)
}
}
if len(invalidTags) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidOrUnauthorizedTags, invalidTags)
}
// Normalize: sort and deduplicate
slices.Sort(validTags)
return slices.Compact(validTags), nil
}
// logTagOperation logs tag assignment operations for audit purposes.
func logTagOperation(existingNode types.NodeView, newTags []string) {
if existingNode.IsTagged() {
log.Info().
Uint64("node.id", existingNode.ID().Uint64()).
Str("node.name", existingNode.Hostname()).
Strs("old.tags", existingNode.Tags().AsSlice()).
Strs("new.tags", newTags).
Msg("Updating tags on already-tagged node")
} else {
var userID uint
if existingNode.UserID().Valid() {
userID = existingNode.UserID().Get()
}
log.Info().
Uint64("node.id", existingNode.ID().Uint64()).
Str("node.name", existingNode.Hostname()).
Uint("created.by.user", userID).
Strs("new.tags", newTags).
Msg("Converting user-owned node to tagged node (irreversible)")
}
}

View File

@@ -6,7 +6,6 @@ import (
"net/netip"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"time"
@@ -28,6 +27,7 @@ var (
ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars")
ErrNodeHasNoGivenName = errors.New("node has no given name")
ErrNodeUserHasNoName = errors.New("node user has no name")
ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node")
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
@@ -97,16 +97,21 @@ type Node struct {
// GivenName is the name used in all DNS related
// parts of headscale.
GivenName string `gorm:"type:varchar(63);unique_index"`
UserID uint
User User `gorm:"constraint:OnDelete:CASCADE;"`
// UserID is set for ALL nodes (tagged and user-owned) to track "created by".
// For tagged nodes, this is informational only - the tag is the owner.
// For user-owned nodes, this identifies the owner.
// Only nil for orphaned nodes (should not happen in normal operation).
UserID *uint
User *User `gorm:"constraint:OnDelete:CASCADE;"`
RegisterMethod string
// ForcedTags are tags set by CLI/API. It is not considered
// the source of truth, but is one of the sources from
// which a tag might originate.
// ForcedTags are _always_ applied to the node.
ForcedTags []string `gorm:"column:forced_tags;serializer:json"`
// Tags is the definitive owner for tagged nodes.
// When non-empty, the node is "tagged" and tags define its identity.
// Empty for user-owned nodes.
// Tags cannot be removed once set (one-way transition).
Tags []string `gorm:"column:tags;serializer:json"`
// When a node has been created with a PreAuthKey, we need to
// prevent the preauthkey from being deleted before the node.
@@ -196,55 +201,32 @@ func (node *Node) HasIP(i netip.Addr) bool {
return false
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys).
// IsTagged reports if a device is tagged and therefore should not be treated
// as a user-owned device.
// When a node has tags, the tags define its identity (not the user).
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
}
return len(node.Tags) > 0
}
if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 {
return true
}
if node.Hostinfo == nil {
return false
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
return false
// IsUserOwned returns true if node is owned by a user (not tagged).
// Tagged nodes may have a UserID for "created by" tracking, but the tag is the owner.
func (node *Node) IsUserOwned() bool {
return !node.IsTagged()
}
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag)
return slices.Contains(node.Tags, tag)
}
func (node *Node) Tags() []string {
var tags []string
if node.AuthKey != nil {
tags = append(tags, node.AuthKey.Tags...)
// TypedUserID returns the UserID as a typed UserID type.
// Returns 0 if UserID is nil.
func (node *Node) TypedUserID() UserID {
if node.UserID == nil {
return 0
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
// #2417
tags = append(tags, node.ForcedTags...)
sort.Strings(tags)
tags = slices.Compact(tags)
return tags
return UserID(*node.UserID)
}
func (node *Node) RequestTags() []string {
@@ -389,8 +371,8 @@ func (node *Node) Proto() *v1.Node {
IpAddresses: node.IPsAsString(),
Name: node.Hostname,
GivenName: node.GivenName,
User: node.User.Proto(),
ForcedTags: node.ForcedTags,
User: nil, // Will be set below based on node type
ForcedTags: node.Tags,
Online: node.IsOnline != nil && *node.IsOnline,
// Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has
@@ -404,6 +386,13 @@ func (node *Node) Proto() *v1.Node {
CreatedAt: timestamppb.New(node.CreatedAt),
}
// Set User field based on node ownership
// Note: User will be set to TaggedDevices in the gRPC layer (grpcv1.go)
// for proper MapResponse formatting
if node.User != nil {
nodeProto.User = node.User.Proto()
}
if node.AuthKey != nil {
nodeProto.PreAuthKey = node.AuthKey.Proto()
}
@@ -701,8 +690,20 @@ func (nodes Nodes) DebugString() string {
func (node Node) DebugString() string {
var sb strings.Builder
fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID)
fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags())
// Show ownership status
if node.IsTagged() {
fmt.Fprintf(&sb, "\tTagged: %v\n", node.Tags)
if node.User != nil {
fmt.Fprintf(&sb, "\tCreated by: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
}
} else if node.User != nil {
fmt.Fprintf(&sb, "\tUser-owned: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
} else {
fmt.Fprintf(&sb, "\tOrphaned: no user or tags\n")
}
fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs())
fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes)
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
@@ -714,8 +715,7 @@ func (node Node) DebugString() string {
}
func (v NodeView) UserView() UserView {
u := v.User()
return u.View()
return v.User()
}
func (v NodeView) IPs() []netip.Addr {
@@ -790,13 +790,6 @@ func (v NodeView) RequestTagsSlice() views.Slice[string] {
return v.Hostinfo().RequestTags()
}
func (v NodeView) Tags() []string {
if !v.Valid() {
return nil
}
return v.ж.Tags()
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
@@ -893,6 +886,32 @@ func (v NodeView) HasTag(tag string) bool {
return v.ж.HasTag(tag)
}
// TypedUserID returns the UserID as a typed UserID type.
// Returns 0 if UserID is nil or node is invalid.
func (v NodeView) TypedUserID() UserID {
if !v.Valid() {
return 0
}
return v.ж.TypedUserID()
}
// TailscaleUserID returns the user ID to use in Tailscale protocol.
// Tagged nodes always return TaggedDevices.ID, user-owned nodes return their actual UserID.
func (v NodeView) TailscaleUserID() tailcfg.UserID {
if !v.Valid() {
return 0
}
if v.IsTagged() {
//nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64
return tailcfg.UserID(int64(TaggedDevices.ID))
}
//nolint:gosec // G115: UserID values are within int64 range
return tailcfg.UserID(int64(v.UserID().Get()))
}
// Prefixes returns the node IPs as netip.Prefix.
func (v NodeView) Prefixes() []netip.Prefix {
if !v.Valid() {

View File

@@ -0,0 +1,295 @@
package types
import (
"testing"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged.
func TestNodeIsTagged(t *testing.T) {
tests := []struct {
name string
node Node
want bool
}{
{
name: "node with tags - is tagged",
node: Node{
Tags: []string{"tag:server", "tag:prod"},
},
want: true,
},
{
name: "node with single tag - is tagged",
node: Node{
Tags: []string{"tag:web"},
},
want: true,
},
{
name: "node with no tags - not tagged",
node: Node{
Tags: []string{},
},
want: false,
},
{
name: "node with nil tags - not tagged",
node: Node{
Tags: nil,
},
want: false,
},
{
// Tags should be copied from AuthKey during registration, so a node
// with only AuthKey.Tags and no Tags would be invalid in practice.
// IsTagged() only checks node.Tags, not AuthKey.Tags.
name: "node registered with tagged authkey only - not tagged (tags should be copied)",
node: Node{
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
want: false,
},
{
name: "node with both tags and authkey tags - is tagged",
node: Node{
Tags: []string{"tag:server"},
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
want: true,
},
{
name: "node with user and no tags - not tagged",
node: Node{
UserID: ptr.To(uint(42)),
Tags: []string{},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.node.IsTagged()
assert.Equal(t, tt.want, got, "IsTagged() returned unexpected value")
})
}
}
// TestNodeViewIsTagged tests the IsTagged() method on NodeView.
func TestNodeViewIsTagged(t *testing.T) {
tests := []struct {
name string
node Node
want bool
}{
{
name: "tagged node via Tags field",
node: Node{
Tags: []string{"tag:server"},
},
want: true,
},
{
// Tags should be copied from AuthKey during registration, so a node
// with only AuthKey.Tags and no Tags would be invalid in practice.
name: "node with only AuthKey tags - not tagged (tags should be copied)",
node: Node{
AuthKey: &PreAuthKey{
Tags: []string{"tag:web"},
},
},
want: false, // IsTagged() only checks node.Tags
},
{
name: "user-owned node",
node: Node{
UserID: ptr.To(uint(1)),
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
view := tt.node.View()
got := view.IsTagged()
assert.Equal(t, tt.want, got, "NodeView.IsTagged() returned unexpected value")
})
}
}
// TestNodeHasTag tests the HasTag() method for checking specific tag membership.
func TestNodeHasTag(t *testing.T) {
tests := []struct {
name string
node Node
tag string
want bool
}{
{
name: "node has the tag",
node: Node{
Tags: []string{"tag:server", "tag:prod"},
},
tag: "tag:server",
want: true,
},
{
name: "node does not have the tag",
node: Node{
Tags: []string{"tag:server", "tag:prod"},
},
tag: "tag:web",
want: false,
},
{
// Tags should be copied from AuthKey during registration
// HasTag() only checks node.Tags, not AuthKey.Tags
name: "node has tag only in authkey - returns false",
node: Node{
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
tag: "tag:database",
want: false,
},
{
// node.Tags is what matters, not AuthKey.Tags
name: "node has tag in Tags but not in AuthKey",
node: Node{
Tags: []string{"tag:server"},
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
tag: "tag:server",
want: true,
},
{
name: "invalid tag format still returns false",
node: Node{
Tags: []string{"tag:server"},
},
tag: "invalid-tag",
want: false,
},
{
name: "empty tag returns false",
node: Node{
Tags: []string{"tag:server"},
},
tag: "",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.node.HasTag(tt.tag)
assert.Equal(t, tt.want, got, "HasTag() returned unexpected value")
})
}
}
// TestNodeTagsImmutableAfterRegistration tests that tags can only be set during registration.
func TestNodeTagsImmutableAfterRegistration(t *testing.T) {
// Test that a node registered with tags keeps them
taggedNode := Node{
ID: 1,
Tags: []string{"tag:server"},
AuthKey: &PreAuthKey{
Tags: []string{"tag:server"},
},
RegisterMethod: util.RegisterMethodAuthKey,
}
// Node should be tagged
assert.True(t, taggedNode.IsTagged(), "Node registered with tags should be tagged")
// Node should have the tag
has := taggedNode.HasTag("tag:server")
assert.True(t, has, "Node should have the tag it was registered with")
// Test that a user-owned node is not tagged
userNode := Node{
ID: 2,
UserID: ptr.To(uint(42)),
Tags: []string{},
RegisterMethod: util.RegisterMethodOIDC,
}
assert.False(t, userNode.IsTagged(), "User-owned node should not be tagged")
}
// TestNodeOwnershipModel tests the tags-as-identity model.
func TestNodeOwnershipModel(t *testing.T) {
tests := []struct {
name string
node Node
wantIsTagged bool
description string
}{
{
name: "tagged node has tags, UserID is informational",
node: Node{
ID: 1,
UserID: ptr.To(uint(5)), // "created by" user 5
Tags: []string{"tag:server"},
},
wantIsTagged: true,
description: "Tagged nodes may have UserID set for tracking, but ownership is defined by tags",
},
{
name: "user-owned node has no tags",
node: Node{
ID: 2,
UserID: ptr.To(uint(5)),
Tags: []string{},
},
wantIsTagged: false,
description: "User-owned nodes are owned by the user, not by tags",
},
{
// Tags should be copied from AuthKey to Node during registration
// IsTagged() only checks node.Tags, not AuthKey.Tags
name: "node with only authkey tags - not tagged (tags should be copied)",
node: Node{
ID: 3,
UserID: ptr.To(uint(5)), // "created by" user 5
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
wantIsTagged: false,
description: "IsTagged() only checks node.Tags; AuthKey.Tags should be copied during registration",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.node.IsTagged()
assert.Equal(t, tt.wantIsTagged, got, tt.description)
})
}
}
// TestUserTypedID tests the TypedID() helper method.
func TestUserTypedID(t *testing.T) {
user := User{
Model: gorm.Model{ID: 42},
}
typedID := user.TypedID()
assert.NotNil(t, typedID, "TypedID() should return non-nil pointer")
assert.Equal(t, UserID(42), *typedID, "TypedID() should return correct UserID value")
}

View File

@@ -139,7 +139,7 @@ func TestNodeFQDN(t *testing.T) {
name: "no-dnsconfig-with-username",
node: Node{
GivenName: "test",
User: User{
User: &User{
Name: "user",
},
},
@@ -150,7 +150,7 @@ func TestNodeFQDN(t *testing.T) {
name: "all-set",
node: Node{
GivenName: "test",
User: User{
User: &User{
Name: "user",
},
},
@@ -160,7 +160,7 @@ func TestNodeFQDN(t *testing.T) {
{
name: "no-given-name",
node: Node{
User: User{
User: &User{
Name: "user",
},
},
@@ -179,7 +179,7 @@ func TestNodeFQDN(t *testing.T) {
name: "no-dnsconfig",
node: Node{
GivenName: "test",
User: User{
User: &User{
Name: "user",
},
},

View File

@@ -23,16 +23,19 @@ type PreAuthKey struct {
Prefix string
Hash []byte // bcrypt
UserID uint
User User `gorm:"constraint:OnDelete:SET NULL;"`
// For tagged keys: UserID tracks who created the key (informational)
// For user-owned keys: UserID tracks the node owner
// Can be nil for system-created tagged keys
UserID *uint
User *User `gorm:"constraint:OnDelete:SET NULL;"`
Reusable bool
Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"`
// Tags are always applied to the node and is one of
// the sources of tags a node might have. They are copied
// from the PreAuthKey when the node logs in the first time,
// and ignored after.
// Tags to assign to nodes registered with this key.
// Tags are copied to the node during registration.
// If non-empty, this creates tagged nodes (not user-owned).
Tags []string `gorm:"serializer:json"`
CreatedAt *time.Time
@@ -48,19 +51,23 @@ type PreAuthKeyNew struct {
Tags []string
Expiration *time.Time
CreatedAt *time.Time
User User
User *User // Can be nil for system-created tagged keys
}
func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
Id: key.ID,
Key: key.Key,
User: key.User.Proto(),
User: nil, // Will be set below if not nil
Reusable: key.Reusable,
Ephemeral: key.Ephemeral,
AclTags: key.Tags,
}
if key.User != nil {
protoKey.User = key.User.Proto()
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
@@ -74,7 +81,7 @@ func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey {
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
User: key.User.Proto(),
User: nil, // Will be set below if not nil
Id: key.ID,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,
@@ -82,6 +89,10 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey {
AclTags: key.Tags,
}
if key.User != nil {
protoKey.User = key.User.Proto()
}
// For new keys (with prefix/hash), show the prefix so users can identify the key
// For legacy keys (with plaintext key), show the full key for backwards compatibility
if key.Prefix != "" {
@@ -139,3 +150,9 @@ func (pak *PreAuthKey) Validate() error {
return nil
}
// IsTagged returns true if this PreAuthKey creates tagged nodes.
// When a PreAuthKey has tags, nodes registered with it will be tagged nodes.
func (pak *PreAuthKey) IsTagged() bool {
return len(pak.Tags) > 0
}

View File

@@ -54,7 +54,13 @@ func (src *Node) Clone() *Node {
if dst.IPv6 != nil {
dst.IPv6 = ptr.To(*src.IPv6)
}
dst.ForcedTags = append(src.ForcedTags[:0:0], src.ForcedTags...)
if dst.UserID != nil {
dst.UserID = ptr.To(*src.UserID)
}
if dst.User != nil {
dst.User = ptr.To(*src.User)
}
dst.Tags = append(src.Tags[:0:0], src.Tags...)
if dst.AuthKeyID != nil {
dst.AuthKeyID = ptr.To(*src.AuthKeyID)
}
@@ -87,10 +93,10 @@ var _NodeCloneNeedsRegeneration = Node(struct {
IPv6 *netip.Addr
Hostname string
GivenName string
UserID uint
User User
UserID *uint
User *User
RegisterMethod string
ForcedTags []string
Tags []string
AuthKeyID *uint64
AuthKey *PreAuthKey
Expiry *time.Time
@@ -111,6 +117,12 @@ func (src *PreAuthKey) Clone() *PreAuthKey {
dst := new(PreAuthKey)
*dst = *src
dst.Hash = append(src.Hash[:0:0], src.Hash...)
if dst.UserID != nil {
dst.UserID = ptr.To(*src.UserID)
}
if dst.User != nil {
dst.User = ptr.To(*src.User)
}
dst.Tags = append(src.Tags[:0:0], src.Tags...)
if dst.CreatedAt != nil {
dst.CreatedAt = ptr.To(*src.CreatedAt)
@@ -127,8 +139,8 @@ var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct {
Key string
Prefix string
Hash []byte
UserID uint
User User
UserID *uint
User *User
Reusable bool
Ephemeral bool
Used bool

View File

@@ -139,12 +139,13 @@ func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePo
func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) }
func (v NodeView) Hostname() string { return v.ж.Hostname }
func (v NodeView) GivenName() string { return v.ж.GivenName }
func (v NodeView) UserID() uint { return v.ж.UserID }
func (v NodeView) User() User { return v.ж.User }
func (v NodeView) Hostname() string { return v.ж.Hostname }
func (v NodeView) GivenName() string { return v.ж.GivenName }
func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) }
func (v NodeView) User() UserView { return v.ж.User.View() }
func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod }
func (v NodeView) ForcedTags() views.Slice[string] { return views.SliceOf(v.ж.ForcedTags) }
func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) }
func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() }
@@ -179,10 +180,10 @@ var _NodeViewNeedsRegeneration = Node(struct {
IPv6 *netip.Addr
Hostname string
GivenName string
UserID uint
User User
UserID *uint
User *User
RegisterMethod string
ForcedTags []string
Tags []string
AuthKeyID *uint64
AuthKey *PreAuthKey
Expiry *time.Time
@@ -239,16 +240,17 @@ func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error {
return nil
}
func (v PreAuthKeyView) ID() uint64 { return v.ж.ID }
func (v PreAuthKeyView) Key() string { return v.ж.Key }
func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix }
func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) }
func (v PreAuthKeyView) UserID() uint { return v.ж.UserID }
func (v PreAuthKeyView) User() User { return v.ж.User }
func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable }
func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral }
func (v PreAuthKeyView) Used() bool { return v.ж.Used }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v PreAuthKeyView) ID() uint64 { return v.ж.ID }
func (v PreAuthKeyView) Key() string { return v.ж.Key }
func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix }
func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) }
func (v PreAuthKeyView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) }
func (v PreAuthKeyView) User() UserView { return v.ж.User.View() }
func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable }
func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral }
func (v PreAuthKeyView) Used() bool { return v.ж.Used }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.CreatedAt)
}
@@ -263,8 +265,8 @@ var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct {
Key string
Prefix string
Hash []byte
UserID uint
User User
UserID *uint
User *User
Reusable bool
Ephemeral bool
Used bool

View File

@@ -22,6 +22,21 @@ type UserID uint64
type Users []User
const (
// TaggedDevicesUserID is the special user ID for tagged devices.
// This ID is used when rendering tagged nodes in the Tailscale protocol.
TaggedDevicesUserID = 2147455555
)
// TaggedDevices is a special user used in MapResponse for tagged nodes.
// Tagged nodes don't belong to a real user - the tag is their identity.
// This special user ID is used when rendering tagged nodes in the Tailscale protocol.
var TaggedDevices = User{
Model: gorm.Model{ID: TaggedDevicesUserID},
Name: "tagged-devices",
DisplayName: "Tagged Devices",
}
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
@@ -77,6 +92,13 @@ func (u *User) StringID() string {
return strconv.FormatUint(uint64(u.ID), 10)
}
// TypedID returns a pointer to the user's ID as a UserID type.
// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)).
func (u *User) TypedID() *UserID {
uid := UserID(u.ID)
return &uid
}
// Username is the main way to get the username of a user,
// it will return the email if it exists, the name if it exists,
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
@@ -117,6 +139,13 @@ func (u UserView) TailscaleUser() tailcfg.User {
return u.ж.TailscaleUser()
}
// ID returns the user's ID.
// This is a custom accessor because gorm.Model.ID is embedded
// and the viewer generator doesn't always produce it.
func (u UserView) ID() uint {
return u.ж.ID
}
func (u *User) TailscaleLogin() tailcfg.Login {
return tailcfg.Login{
ID: tailcfg.LoginID(u.ID),