mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-19 16:21:23 +01:00
Generalise the registration pipeline to a more general auth pipeline supporting both node registrations and SSH check auth requests. Rename RegistrationID to AuthID, unexport AuthRequest fields, and introduce AuthVerdict to unify the auth finish API. Add the urlParam generic helper for extracting typed URL parameters from chi routes, used by the new auth request handler. Updates #1850
448 lines
14 KiB
Go
448 lines
14 KiB
Go
package db
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/juanfont/headscale/hscontrol/types"
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"gorm.io/gorm"
|
||
"zgo.at/zcache/v2"
|
||
)
|
||
|
||
// TestSQLiteMigrationAndDataValidation tests specific SQLite migration scenarios
|
||
// and validates data integrity after migration. All migrations that require data validation
|
||
// should be added here.
|
||
func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||
tests := []struct {
|
||
dbPath string
|
||
wantFunc func(*testing.T, *HSDatabase)
|
||
}{
|
||
// at 14:15:06 ❯ go run ./cmd/headscale preauthkeys list
|
||
// ID | Key | Reusable | Ephemeral | Used | Expiration | Created | Tags
|
||
// 1 | 09b28f.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:derp
|
||
// 2 | 3112b9.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:derp
|
||
{
|
||
dbPath: "testdata/sqlite/failing-node-preauth-constraint_dump.sql",
|
||
wantFunc: func(t *testing.T, hsdb *HSDatabase) {
|
||
t.Helper()
|
||
// Comprehensive data preservation validation for node-preauth constraint issue
|
||
// Expected data from dump: 1 user, 2 api_keys, 6 nodes
|
||
|
||
// Verify users data preservation
|
||
users, err := Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||
return ListUsers(rx)
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, users, 1, "should preserve all 1 user from original schema")
|
||
|
||
// Verify api_keys data preservation
|
||
var apiKeyCount int
|
||
|
||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||
require.NoError(t, err)
|
||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||
|
||
// Verify nodes data preservation and field validation
|
||
nodes, err := Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||
return ListNodes(rx)
|
||
})
|
||
require.NoError(t, err)
|
||
assert.Len(t, nodes, 6, "should preserve all 6 nodes from original schema")
|
||
|
||
for _, node := range nodes {
|
||
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
|
||
assert.Contains(t, node.MachineKey.String(), "mkey:")
|
||
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
|
||
assert.Contains(t, node.NodeKey.String(), "nodekey:")
|
||
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
|
||
assert.Contains(t, node.DiscoKey.String(), "discokey:")
|
||
assert.Nil(t, node.AuthKey)
|
||
assert.Nil(t, node.AuthKeyID)
|
||
}
|
||
},
|
||
},
|
||
// Test for RequestTags migration (202601121700-migrate-hostinfo-request-tags)
|
||
// and forced_tags->tags rename migration (202511131445-node-forced-tags-to-tags)
|
||
//
|
||
// This test validates that:
|
||
// 1. The forced_tags column is renamed to tags
|
||
// 2. RequestTags from host_info are validated against policy tagOwners
|
||
// 3. Authorized tags are migrated to the tags column
|
||
// 4. Unauthorized tags are rejected
|
||
// 5. Existing tags are preserved
|
||
// 6. Group membership is evaluated for tag authorization
|
||
{
|
||
dbPath: "testdata/sqlite/request_tags_migration_test.sql",
|
||
wantFunc: func(t *testing.T, hsdb *HSDatabase) {
|
||
t.Helper()
|
||
|
||
nodes, err := Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||
return ListNodes(rx)
|
||
})
|
||
require.NoError(t, err)
|
||
require.Len(t, nodes, 7, "should have all 7 nodes")
|
||
|
||
// Helper to find node by hostname
|
||
findNode := func(hostname string) *types.Node {
|
||
for _, n := range nodes {
|
||
if n.Hostname == hostname {
|
||
return n
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Node 1: user1 has RequestTags for tag:server (authorized)
|
||
// Expected: tags = ["tag:server"]
|
||
node1 := findNode("node1")
|
||
require.NotNil(t, node1, "node1 should exist")
|
||
assert.Contains(t, node1.Tags, "tag:server", "node1 should have tag:server migrated from RequestTags")
|
||
|
||
// Node 2: user1 has RequestTags for tag:unauthorized (NOT authorized)
|
||
// Expected: tags = [] (unchanged)
|
||
node2 := findNode("node2")
|
||
require.NotNil(t, node2, "node2 should exist")
|
||
assert.Empty(t, node2.Tags, "node2 should have empty tags (unauthorized tag rejected)")
|
||
|
||
// Node 3: user2 has RequestTags for tag:client (authorized) + existing tag:existing
|
||
// Expected: tags = ["tag:client", "tag:existing"]
|
||
node3 := findNode("node3")
|
||
require.NotNil(t, node3, "node3 should exist")
|
||
assert.Contains(t, node3.Tags, "tag:client", "node3 should have tag:client migrated from RequestTags")
|
||
assert.Contains(t, node3.Tags, "tag:existing", "node3 should preserve existing tag")
|
||
|
||
// Node 4: user1 has RequestTags for tag:server which already exists
|
||
// Expected: tags = ["tag:server"] (no duplicates)
|
||
node4 := findNode("node4")
|
||
require.NotNil(t, node4, "node4 should exist")
|
||
assert.Equal(t, []string{"tag:server"}, node4.Tags, "node4 should have tag:server without duplicates")
|
||
|
||
// Node 5: user2 has no RequestTags
|
||
// Expected: tags = [] (unchanged)
|
||
node5 := findNode("node5")
|
||
require.NotNil(t, node5, "node5 should exist")
|
||
assert.Empty(t, node5.Tags, "node5 should have empty tags (no RequestTags)")
|
||
|
||
// Node 6: admin1 has RequestTags for tag:admin (authorized via group:admins)
|
||
// Expected: tags = ["tag:admin"]
|
||
node6 := findNode("node6")
|
||
require.NotNil(t, node6, "node6 should exist")
|
||
assert.Contains(t, node6.Tags, "tag:admin", "node6 should have tag:admin migrated via group membership")
|
||
|
||
// Node 7: user1 has RequestTags for tag:server (authorized) and tag:forbidden (unauthorized)
|
||
// Expected: tags = ["tag:server"] (only authorized tag)
|
||
node7 := findNode("node7")
|
||
require.NotNil(t, node7, "node7 should exist")
|
||
assert.Contains(t, node7.Tags, "tag:server", "node7 should have tag:server migrated")
|
||
assert.NotContains(t, node7.Tags, "tag:forbidden", "node7 should NOT have tag:forbidden (unauthorized)")
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.dbPath, func(t *testing.T) {
|
||
if !strings.HasSuffix(tt.dbPath, ".sql") {
|
||
t.Fatalf("TestSQLiteMigrationAndDataValidation only supports .sql files, got: %s", tt.dbPath)
|
||
}
|
||
|
||
hsdb := dbForTestWithPath(t, tt.dbPath)
|
||
if tt.wantFunc != nil {
|
||
tt.wantFunc(t, hsdb)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func emptyCache() *zcache.Cache[types.AuthID, types.AuthRequest] {
|
||
return zcache.New[types.AuthID, types.AuthRequest](time.Minute, time.Hour)
|
||
}
|
||
|
||
func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||
db, err := sql.Open("sqlite", dbPath)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer db.Close()
|
||
|
||
schemaContent, err := os.ReadFile(sqlFilePath)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||
|
||
return err
|
||
}
|
||
|
||
// requireConstraintFailed checks if the error is a constraint failure with
|
||
// either SQLite and PostgreSQL error messages.
|
||
func requireConstraintFailed(t *testing.T, err error) {
|
||
t.Helper()
|
||
require.Error(t, err)
|
||
|
||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||
}
|
||
}
|
||
|
||
func TestConstraints(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
run func(*testing.T, *gorm.DB)
|
||
}{
|
||
{
|
||
name: "no-duplicate-username-if-no-oidc",
|
||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||
_, err := CreateUser(db, types.User{Name: "user1"})
|
||
require.NoError(t, err)
|
||
_, err = CreateUser(db, types.User{Name: "user1"})
|
||
requireConstraintFailed(t, err)
|
||
},
|
||
},
|
||
{
|
||
name: "no-oidc-duplicate-username-and-id",
|
||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||
user := types.User{
|
||
Model: gorm.Model{ID: 1},
|
||
Name: "user1",
|
||
}
|
||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||
|
||
err := db.Save(&user).Error
|
||
require.NoError(t, err)
|
||
|
||
user = types.User{
|
||
Model: gorm.Model{ID: 2},
|
||
Name: "user1",
|
||
}
|
||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||
|
||
err = db.Save(&user).Error
|
||
requireConstraintFailed(t, err)
|
||
},
|
||
},
|
||
{
|
||
name: "no-oidc-duplicate-id",
|
||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||
user := types.User{
|
||
Model: gorm.Model{ID: 1},
|
||
Name: "user1",
|
||
}
|
||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||
|
||
err := db.Save(&user).Error
|
||
require.NoError(t, err)
|
||
|
||
user = types.User{
|
||
Model: gorm.Model{ID: 2},
|
||
Name: "user1.1",
|
||
}
|
||
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||
|
||
err = db.Save(&user).Error
|
||
requireConstraintFailed(t, err)
|
||
},
|
||
},
|
||
{
|
||
name: "allow-duplicate-username-cli-then-oidc",
|
||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
||
require.NoError(t, err)
|
||
|
||
user := types.User{
|
||
Name: "user1",
|
||
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||
}
|
||
|
||
err = db.Save(&user).Error
|
||
require.NoError(t, err)
|
||
},
|
||
},
|
||
{
|
||
name: "allow-duplicate-username-oidc-then-cli",
|
||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||
user := types.User{
|
||
Name: "user1",
|
||
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||
}
|
||
|
||
err := db.Save(&user).Error
|
||
require.NoError(t, err)
|
||
|
||
_, err = CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
||
require.NoError(t, err)
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name+"-postgres", func(t *testing.T) {
|
||
db := newPostgresTestDB(t)
|
||
tt.run(t, db.DB.Debug())
|
||
})
|
||
t.Run(tt.name+"-sqlite", func(t *testing.T) {
|
||
db, err := newSQLiteTestDB()
|
||
if err != nil {
|
||
t.Fatalf("creating database: %s", err)
|
||
}
|
||
|
||
tt.run(t, db.DB.Debug())
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestPostgresMigrationAndDataValidation tests specific PostgreSQL migration scenarios
|
||
// and validates data integrity after migration. All migrations that require data validation
|
||
// should be added here.
|
||
//
|
||
// TODO(kradalby): Convert to use plain text SQL dumps instead of binary .pssql dumps for consistency
|
||
// with SQLite tests and easier version control.
|
||
func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
dbPath string
|
||
wantFunc func(*testing.T, *HSDatabase)
|
||
}{}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
u := newPostgresDBForTest(t)
|
||
|
||
pgRestorePath, err := exec.LookPath("pg_restore")
|
||
if err != nil {
|
||
t.Fatal("pg_restore not found in PATH. Please install it and ensure it is accessible.")
|
||
}
|
||
|
||
// Construct the pg_restore command
|
||
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||
|
||
// Set the output streams
|
||
cmd.Stdout = os.Stdout
|
||
cmd.Stderr = os.Stderr
|
||
|
||
// Execute the command
|
||
err = cmd.Run()
|
||
if err != nil {
|
||
t.Fatalf("failed to restore postgres database: %s", err)
|
||
}
|
||
|
||
db := newHeadscaleDBFromPostgresURL(t, u)
|
||
|
||
if tt.wantFunc != nil {
|
||
tt.wantFunc(t, db)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func dbForTest(t *testing.T) *HSDatabase {
|
||
t.Helper()
|
||
return dbForTestWithPath(t, "")
|
||
}
|
||
|
||
func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
||
t.Helper()
|
||
|
||
dbPath := t.TempDir() + "/headscale_test.db"
|
||
|
||
// If SQL file path provided, validate and create database from it
|
||
if sqlFilePath != "" {
|
||
// Validate that the file is a SQL text file
|
||
if !strings.HasSuffix(sqlFilePath, ".sql") {
|
||
t.Fatalf("dbForTestWithPath only accepts .sql files, got: %s", sqlFilePath)
|
||
}
|
||
|
||
err := createSQLiteFromSQLFile(sqlFilePath, dbPath)
|
||
if err != nil {
|
||
t.Fatalf("setting up database from SQL file %s: %s", sqlFilePath, err)
|
||
}
|
||
}
|
||
|
||
db, err := NewHeadscaleDatabase(
|
||
&types.Config{
|
||
Database: types.DatabaseConfig{
|
||
Type: "sqlite3",
|
||
Sqlite: types.SqliteConfig{
|
||
Path: dbPath,
|
||
},
|
||
},
|
||
Policy: types.PolicyConfig{
|
||
Mode: types.PolicyModeDB,
|
||
},
|
||
},
|
||
emptyCache(),
|
||
)
|
||
if err != nil {
|
||
t.Fatalf("setting up database: %s", err)
|
||
}
|
||
|
||
if sqlFilePath != "" {
|
||
t.Logf("database set up from %s at: %s", sqlFilePath, dbPath)
|
||
} else {
|
||
t.Logf("database set up at: %s", dbPath)
|
||
}
|
||
|
||
return db
|
||
}
|
||
|
||
// TestSQLiteAllTestdataMigrations tests migration compatibility across all SQLite schemas
|
||
// in the testdata directory. It verifies they can be successfully migrated to the current
|
||
// schema version. This test only validates migration success, not data integrity.
|
||
//
|
||
// All test database files are SQL dumps (created with `sqlite3 headscale.db .dump`) generated
|
||
// with old Headscale binaries on empty databases (no user/node data). These dumps include the
|
||
// migration history in the `migrations` table, which allows the migration system to correctly
|
||
// skip already-applied migrations and only run new ones.
|
||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
schemas, err := os.ReadDir("testdata/sqlite")
|
||
require.NoError(t, err)
|
||
|
||
t.Logf("loaded %d schemas", len(schemas))
|
||
|
||
for _, schema := range schemas {
|
||
if schema.IsDir() {
|
||
continue
|
||
}
|
||
|
||
t.Logf("validating: %s", schema.Name())
|
||
|
||
t.Run(schema.Name(), func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
dbPath := t.TempDir() + "/headscale_test.db"
|
||
|
||
// Setup a database with the old schema
|
||
schemaPath := filepath.Join("testdata/sqlite", schema.Name())
|
||
err := createSQLiteFromSQLFile(schemaPath, dbPath)
|
||
require.NoError(t, err)
|
||
|
||
_, err = NewHeadscaleDatabase(
|
||
&types.Config{
|
||
Database: types.DatabaseConfig{
|
||
Type: "sqlite3",
|
||
Sqlite: types.SqliteConfig{
|
||
Path: dbPath,
|
||
},
|
||
},
|
||
Policy: types.PolicyConfig{
|
||
Mode: types.PolicyModeDB,
|
||
},
|
||
},
|
||
emptyCache(),
|
||
)
|
||
require.NoError(t, err)
|
||
})
|
||
}
|
||
}
|