db: use PolicyManager for RequestTags migration

Refactor the RequestTags migration (202601121700-migrate-hostinfo-request-tags)
to use PolicyManager.NodeCanHaveTag() instead of reimplementing tag validation.

Changes:
- NewHeadscaleDatabase now accepts *types.Config to allow migrations
  access to policy configuration
- Add loadPolicyBytes helper to load policy from file or DB based on config
- Add standalone GetPolicy(tx *gorm.DB) for use during migrations
- Replace custom tag validation logic with PolicyManager

Benefits:
- Full HuJSON parsing support (not just JSON)
- Proper group expansion via PolicyManager
- Support for nested tags and autogroups
- Works with both file and database policy modes
- Single source of truth for tag validation


Co-Authored-By: Shourya Gautam <shouryamgautam@gmail.com>
This commit is contained in:
Shourya Gautam
2026-01-21 19:40:29 +05:30
committed by GitHub
parent 22afb2c61b
commit 4e1834adaf
9 changed files with 413 additions and 103 deletions

View File

@@ -67,6 +67,83 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
}
},
},
// Test for RequestTags migration (202601121700-migrate-hostinfo-request-tags)
// and forced_tags->tags rename migration (202511131445-node-forced-tags-to-tags)
//
// This test validates that:
// 1. The forced_tags column is renamed to tags
// 2. RequestTags from host_info are validated against policy tagOwners
// 3. Authorized tags are migrated to the tags column
// 4. Unauthorized tags are rejected
// 5. Existing tags are preserved
// 6. Group membership is evaluated for tag authorization
{
dbPath: "testdata/sqlite/request_tags_migration_test.sql",
wantFunc: func(t *testing.T, hsdb *HSDatabase) {
t.Helper()
nodes, err := Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx)
})
require.NoError(t, err)
require.Len(t, nodes, 7, "should have all 7 nodes")
// Helper to find node by hostname
findNode := func(hostname string) *types.Node {
for _, n := range nodes {
if n.Hostname == hostname {
return n
}
}
return nil
}
// Node 1: user1 has RequestTags for tag:server (authorized)
// Expected: tags = ["tag:server"]
node1 := findNode("node1")
require.NotNil(t, node1, "node1 should exist")
assert.Contains(t, node1.Tags, "tag:server", "node1 should have tag:server migrated from RequestTags")
// Node 2: user1 has RequestTags for tag:unauthorized (NOT authorized)
// Expected: tags = [] (unchanged)
node2 := findNode("node2")
require.NotNil(t, node2, "node2 should exist")
assert.Empty(t, node2.Tags, "node2 should have empty tags (unauthorized tag rejected)")
// Node 3: user2 has RequestTags for tag:client (authorized) + existing tag:existing
// Expected: tags = ["tag:client", "tag:existing"]
node3 := findNode("node3")
require.NotNil(t, node3, "node3 should exist")
assert.Contains(t, node3.Tags, "tag:client", "node3 should have tag:client migrated from RequestTags")
assert.Contains(t, node3.Tags, "tag:existing", "node3 should preserve existing tag")
// Node 4: user1 has RequestTags for tag:server which already exists
// Expected: tags = ["tag:server"] (no duplicates)
node4 := findNode("node4")
require.NotNil(t, node4, "node4 should exist")
assert.Equal(t, []string{"tag:server"}, node4.Tags, "node4 should have tag:server without duplicates")
// Node 5: user2 has no RequestTags
// Expected: tags = [] (unchanged)
node5 := findNode("node5")
require.NotNil(t, node5, "node5 should exist")
assert.Empty(t, node5.Tags, "node5 should have empty tags (no RequestTags)")
// Node 6: admin1 has RequestTags for tag:admin (authorized via group:admins)
// Expected: tags = ["tag:admin"]
node6 := findNode("node6")
require.NotNil(t, node6, "node6 should exist")
assert.Contains(t, node6.Tags, "tag:admin", "node6 should have tag:admin migrated via group membership")
// Node 7: user1 has RequestTags for tag:server (authorized) and tag:forbidden (unauthorized)
// Expected: tags = ["tag:server"] (only authorized tag)
node7 := findNode("node7")
require.NotNil(t, node7, "node7 should exist")
assert.Contains(t, node7.Tags, "tag:server", "node7 should have tag:server migrated")
assert.NotContains(t, node7.Tags, "tag:forbidden", "node7 should NOT have tag:forbidden (unauthorized)")
},
},
}
for _, tt := range tests {
@@ -288,13 +365,17 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
}
db, err := NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
&types.Config{
Database: types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
},
},
Policy: types.PolicyConfig{
Mode: types.PolicyModeDB,
},
},
"",
emptyCache(),
)
if err != nil {
@@ -343,13 +424,17 @@ func TestSQLiteAllTestdataMigrations(t *testing.T) {
require.NoError(t, err)
_, err = NewHeadscaleDatabase(
types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
&types.Config{
Database: types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
},
},
Policy: types.PolicyConfig{
Mode: types.PolicyModeDB,
},
},
"",
emptyCache(),
)
require.NoError(t, err)