all: fix golangci-lint issues (#3064)

This commit is contained in:
Kristoffer Dalby
2026-02-06 21:45:32 +01:00
committed by GitHub
parent bfb6fd80df
commit ce580f8245
131 changed files with 3131 additions and 1560 deletions

View File

@@ -77,7 +77,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
Expiration: expiration,
}
if err := hsdb.DB.Save(&key).Error; err != nil {
if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr
return "", nil, fmt.Errorf("saving API key to database: %w", err)
}
@@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey(
// ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
keys := []types.APIKey{}
if err := hsdb.DB.Find(&keys).Error; err != nil {
err := hsdb.DB.Find(&keys).Error
if err != nil {
return nil, err
}
@@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
// ExpireAPIKey marks a ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error
if err != nil {
return err
}

View File

@@ -53,6 +53,8 @@ type HSDatabase struct {
// NewHeadscaleDatabase creates a new database connection and runs migrations.
// It accepts the full configuration to allow migrations access to policy settings.
//
//nolint:gocyclo // complex database initialization with many migrations
func NewHeadscaleDatabase(
cfg *types.Config,
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
@@ -76,7 +78,7 @@ func NewHeadscaleDatabase(
ID: "202501221827",
Migrate: func(tx *gorm.DB) error {
// Remove any invalid routes associated with a node that does not exist.
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
if err != nil {
return err
@@ -84,14 +86,14 @@ func NewHeadscaleDatabase(
}
// Remove any invalid routes without a node_id.
if tx.Migrator().HasTable(&types.Route{}) {
if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations
err := tx.Exec("delete from routes where node_id is null").Error
if err != nil {
return err
}
}
err := tx.AutoMigrate(&types.Route{})
err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
if err != nil {
return fmt.Errorf("automigrating types.Route: %w", err)
}
@@ -109,6 +111,7 @@ func NewHeadscaleDatabase(
if err != nil {
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
}
err = tx.AutoMigrate(&types.Node{})
if err != nil {
return fmt.Errorf("automigrating types.Node: %w", err)
@@ -155,7 +158,8 @@ AND auth_key_id NOT IN (
nodeRoutes := map[uint64][]netip.Prefix{}
var routes []types.Route
var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations
err = tx.Find(&routes).Error
if err != nil {
return fmt.Errorf("fetching routes: %w", err)
@@ -171,7 +175,7 @@ AND auth_key_id NOT IN (
tsaddr.SortPrefixes(routes)
routes = slices.Compact(routes)
data, err := json.Marshal(routes)
data, _ := json.Marshal(routes)
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
if err != nil {
@@ -180,7 +184,7 @@ AND auth_key_id NOT IN (
}
// Drop the old table.
_ = tx.Migrator().DropTable(&types.Route{})
_ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
return nil
},
@@ -256,10 +260,13 @@ AND auth_key_id NOT IN (
// Check if routes table exists and drop it (should have been migrated already)
var routesExists bool
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
if err == nil && routesExists {
log.Info().Msg("dropping leftover routes table")
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
err := tx.Exec("DROP TABLE routes").Error
if err != nil {
return fmt.Errorf("dropping routes table: %w", err)
}
}
@@ -281,6 +288,7 @@ AND auth_key_id NOT IN (
for _, table := range tablesToRename {
// Check if table exists before renaming
var exists bool
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
if err != nil {
return fmt.Errorf("checking if table %s exists: %w", table, err)
@@ -291,7 +299,8 @@ AND auth_key_id NOT IN (
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
// Rename current table to _old
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
if err != nil {
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
}
}
@@ -365,7 +374,8 @@ AND auth_key_id NOT IN (
}
for _, createSQL := range tableCreationSQL {
if err := tx.Exec(createSQL).Error; err != nil {
err := tx.Exec(createSQL).Error
if err != nil {
return fmt.Errorf("creating new table: %w", err)
}
}
@@ -394,7 +404,8 @@ AND auth_key_id NOT IN (
}
for _, copySQL := range dataCopySQL {
if err := tx.Exec(copySQL).Error; err != nil {
err := tx.Exec(copySQL).Error
if err != nil {
return fmt.Errorf("copying data: %w", err)
}
}
@@ -417,14 +428,16 @@ AND auth_key_id NOT IN (
}
for _, indexSQL := range indexes {
if err := tx.Exec(indexSQL).Error; err != nil {
err := tx.Exec(indexSQL).Error
if err != nil {
return fmt.Errorf("creating index: %w", err)
}
}
// Drop old tables only after everything succeeds
for _, table := range tablesToRename {
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
if err != nil {
log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded")
}
}
@@ -760,6 +773,7 @@ AND auth_key_id NOT IN (
// or else it blocks...
sqlConn.SetMaxIdleConns(maxIdleConns)
sqlConn.SetMaxOpenConns(maxOpenConns)
defer sqlConn.SetMaxIdleConns(1)
defer sqlConn.SetMaxOpenConns(1)
@@ -777,7 +791,7 @@ AND auth_key_id NOT IN (
},
}
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr
return nil, fmt.Errorf("validating schema: %w", err)
}
}
@@ -803,6 +817,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
switch cfg.Type {
case types.DatabaseSqlite:
dir := filepath.Dir(cfg.Sqlite.Path)
err := util.EnsureDir(dir)
if err != nil {
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
@@ -856,7 +871,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
Str("path", dbString).
Msg("Opening database")
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr
if !sslEnabled {
dbString += " sslmode=disable"
}
@@ -911,7 +926,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
// Get the current foreign key status
var fkOriginallyEnabled int
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr
return fmt.Errorf("checking foreign key status: %w", err)
}
@@ -940,28 +955,31 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
if needsFKDisabled {
// Disable foreign keys for this migration
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
if err != nil {
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
}
} else {
// Ensure foreign keys are enabled for this migration
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
if err != nil {
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
}
}
// Run up to this specific migration (will only run the next pending migration)
if err := migrations.MigrateTo(migrationID); err != nil {
err := migrations.MigrateTo(migrationID)
if err != nil {
return fmt.Errorf("running migration %s: %w", migrationID, err)
}
}
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr
return fmt.Errorf("restoring foreign keys: %w", err)
}
// Run the rest of the migrations
if err := migrations.Migrate(); err != nil {
if err := migrations.Migrate(); err != nil { //nolint:noinlineerr
return err
}
@@ -979,16 +997,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var violation constraintViolation
if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil {
err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex)
if err != nil {
return err
}
violatedConstraints = append(violatedConstraints, violation)
}
_ = rows.Close()
if err := rows.Err(); err != nil { //nolint:noinlineerr
return err
}
if len(violatedConstraints) > 0 {
for _, violation := range violatedConstraints {
@@ -1003,7 +1027,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
}
} else {
// PostgreSQL can run all migrations in one block - no foreign key issues
if err := migrations.Migrate(); err != nil {
err := migrations.Migrate()
if err != nil {
return err
}
}
@@ -1014,6 +1039,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
sqlDB, err := hsdb.DB.DB()
if err != nil {
return err
@@ -1029,7 +1055,7 @@ func (hsdb *HSDatabase) Close() error {
}
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
db.Exec("VACUUM")
db.Exec("VACUUM") //nolint:errcheck,noctx
}
return db.Close()
@@ -1038,12 +1064,14 @@ func (hsdb *HSDatabase) Close() error {
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
rx := hsdb.DB.Begin()
defer rx.Rollback()
return fn(rx)
}
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
rx := db.Begin()
defer rx.Rollback()
ret, err := fn(rx)
if err != nil {
var no T
@@ -1056,7 +1084,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
tx := hsdb.DB.Begin()
defer tx.Rollback()
if err := fn(tx); err != nil {
err := fn(tx)
if err != nil {
return err
}
@@ -1066,6 +1096,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
tx := db.Begin()
defer tx.Rollback()
ret, err := fn(tx)
if err != nil {
var no T

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"os"
"os/exec"
@@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
// Verify api_keys data preservation
var apiKeyCount int
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
require.NoError(t, err)
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
@@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
return err
}
_, err = db.Exec(string(schemaContent))
_, err = db.ExecContext(context.Background(), string(schemaContent))
return err
}
@@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
func requireConstraintFailed(t *testing.T, err error) {
t.Helper()
require.Error(t, err)
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
}
@@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) {
}{
{
name: "no-duplicate-username-if-no-oidc",
run: func(t *testing.T, db *gorm.DB) {
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
_, err := CreateUser(db, types.User{Name: "user1"})
require.NoError(t, err)
_, err = CreateUser(db, types.User{Name: "user1"})
@@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) {
},
{
name: "no-oidc-duplicate-username-and-id",
run: func(t *testing.T, db *gorm.DB) {
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
@@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) {
},
{
name: "no-oidc-duplicate-id",
run: func(t *testing.T, db *gorm.DB) {
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
@@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) {
},
{
name: "allow-duplicate-username-cli-then-oidc",
run: func(t *testing.T, db *gorm.DB) {
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
require.NoError(t, err)
@@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) {
},
{
name: "allow-duplicate-username-oidc-then-cli",
run: func(t *testing.T, db *gorm.DB) {
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
user := types.User{
Name: "user1",
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
@@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
}
// Construct the pg_restore command
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
// Set the output streams
cmd.Stdout = os.Stdout
@@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
// skip already-applied migrations and only run new ones.
func TestSQLiteAllTestdataMigrations(t *testing.T) {
t.Parallel()
schemas, err := os.ReadDir("testdata/sqlite")
require.NoError(t, err)

View File

@@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Basic deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var deletionWg sync.WaitGroup
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
deletionWg sync.WaitGroup
)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionWg.Done()
}
@@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
go gc.Start()
// Schedule several nodes for deletion with short expiry
const expiry = fifty
const numNodes = 100
const (
expiry = fifty
numNodes = 100
)
// Set up wait group for expected deletions
deletionWg.Add(numNodes)
for i := 1; i <= numNodes; i++ {
gc.Schedule(types.NodeID(i), expiry)
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test
}
// Wait for all scheduled deletions to complete
@@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
// Schedule and immediately cancel to test that part of the code
for i := numNodes + 1; i <= numNodes*2; i++ {
nodeID := types.NodeID(i)
nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test
gc.Schedule(nodeID, time.Hour)
gc.Cancel(nodeID)
}
@@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
@@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Start GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
defer gc.Close()
const shortExpiry = fifty
const longExpiry = 1 * time.Hour
const (
shortExpiry = fifty
longExpiry = 1 * time.Hour
)
nodeID := types.NodeID(1)
@@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// and verifies that the node is deleted only once.
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
}
// Start the GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
defer gc.Close()
nodeID := types.NodeID(1)
const expiry = fifty
// Schedule node for deletion
@@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
// Deletion tracking
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
@@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Deletion tracking
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
nodeDeleted := make(chan struct{})
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
close(nodeDeleted) // Signal that deletion happened
}
@@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Use a WaitGroup to ensure the GC has started
var startWg sync.WaitGroup
startWg.Add(1)
go func() {
startWg.Done() // Signal that the goroutine has started
gc.Start()
}()
startWg.Wait() // Wait for the GC to start
// Close GC right away
@@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Check no node was deleted
deleteMutex.Lock()
nodesDeleted := len(deletedIDs)
deleteMutex.Unlock()
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
@@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
}
@@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
go gc.Start()
// Number of concurrent scheduling goroutines
const numSchedulers = 10
const nodesPerScheduler = 50
const (
numSchedulers = 10
nodesPerScheduler = 50
)
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
@@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
case <-stopScheduling:
return
default:
nodeID := types.NodeID(baseNodeID + j + 1)
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
atomic.AddInt64(&scheduledCount, 1)
// Yield to other goroutines to introduce variability

View File

@@ -17,7 +17,11 @@ import (
"tailscale.com/net/tsaddr"
)
var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
var (
errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix")
errIPAllocatorNil = errors.New("ip allocator was nil")
)
// IPAllocator is a singleton responsible for allocating
// IP addresses for nodes and making sure the same
@@ -62,8 +66,10 @@ func NewIPAllocator(
strategy: strategy,
}
var v4s []sql.NullString
var v6s []sql.NullString
var (
v4s []sql.NullString
v6s []sql.NullString
)
if db != nil {
err := db.Read(func(rx *gorm.DB) error {
@@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
i.mu.Lock()
defer i.mu.Unlock()
var err error
var ret4 *netip.Addr
var ret6 *netip.Addr
var (
err error
ret4 *netip.Addr
ret6 *netip.Addr
)
if i.prefix4 != nil {
ret4, err = i.next(i.prev4, i.prefix4)
if err != nil {
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
}
i.prev4 = *ret4
}
@@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
if err != nil {
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
}
i.prev6 = *ret6
}
@@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.
}
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
var err error
var ip netip.Addr
var (
err error
ip netip.Addr
)
switch i.strategy {
case types.IPAllocationStrategySequential:
@@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {
if !pfx.Contains(ip) {
return netip.Addr{}, fmt.Errorf(
"generated ip(%s) not in prefix(%s)",
"%w: ip(%s) not in prefix(%s)",
errGeneratedIPNotInPrefix,
ip.String(),
pfx.String(),
)
@@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool {
// If a prefix type has been removed (IPv4 or IPv6), it
// will remove the IPs in that family from the node.
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
var err error
var ret []string
var (
err error
ret []string
)
err = db.Write(func(tx *gorm.DB) error {
if i == nil {
return errors.New("backfilling IPs: ip allocator was nil")
return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil)
}
log.Trace().Caller().Msgf("starting to backfill IPs")
@@ -295,6 +311,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
node.IPv4 = ret4
changed = true
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
}
@@ -307,6 +324,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
node.IPv6 = ret6
changed = true
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
}

View File

@@ -21,9 +21,7 @@ var mpp = func(pref string) *netip.Prefix {
return &p
}
var na = func(pref string) netip.Addr {
return netip.MustParseAddr(pref)
}
var na = netip.MustParseAddr
var nap = func(pref string) *netip.Addr {
n := na(pref)
@@ -158,8 +156,10 @@ func TestIPAllocatorSequential(t *testing.T) {
types.IPAllocationStrategySequential,
)
var got4s []netip.Addr
var got6s []netip.Addr
var (
got4s []netip.Addr
got6s []netip.Addr
)
for range tt.getCount {
got4, got6, err := alloc.Next()
@@ -175,6 +175,7 @@ func TestIPAllocatorSequential(t *testing.T) {
got6s = append(got6s, *got6)
}
}
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
}
@@ -288,6 +289,7 @@ func TestBackfillIPAddresses(t *testing.T) {
fullNodeP := func(i int) *types.Node {
v4 := fmt.Sprintf("100.64.0.%d", i)
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
return &types.Node{
IPv4: nap(v4),
IPv6: nap(v6),
@@ -484,6 +486,7 @@ func TestBackfillIPAddresses(t *testing.T) {
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
defer db.Close()
alloc, err := NewIPAllocator(

View File

@@ -27,8 +27,14 @@ import (
const (
NodeGivenNameHashLength = 8
NodeGivenNameTrimSize = 2
// defaultTestNodePrefix is the default hostname prefix for nodes created in tests.
defaultTestNodePrefix = "testnode"
)
// ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique")
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var (
@@ -52,12 +58,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID)
// If at least one peer ID is given, only these peer nodes will be returned.
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Where("id <> ?", nodeID).
Where(peerIDs).Find(&nodes).Error; err != nil {
Where(peerIDs).Find(&nodes).Error
if err != nil {
return types.Nodes{}, err
}
@@ -76,11 +84,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error)
// or for the given nodes if at least one node ID is given as parameter.
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Where(nodeIDs).Find(&nodes).Error; err != nil {
Where(nodeIDs).Find(&nodes).Error
if err != nil {
return nil, err
}
@@ -90,7 +100,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
nodes := types.Nodes{}
if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
if err != nil {
return nil, err
}
@@ -222,7 +234,7 @@ func SetTags(
return nil
}
// SetTags takes a Node struct pointer and update the forced tags.
// SetApprovedRoutes takes a Node struct pointer and updates the approved routes.
func SetApprovedRoutes(
tx *gorm.DB,
nodeID types.NodeID,
@@ -254,7 +266,7 @@ func SetApprovedRoutes(
return err
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr
return fmt.Errorf("updating approved routes: %w", err)
}
@@ -294,10 +306,10 @@ func RenameNode(tx *gorm.DB,
}
if count > 0 {
return errors.New("name is not unique")
return ErrNodeNameNotUnique
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr
return fmt.Errorf("renaming node in database: %w", err)
}
@@ -329,7 +341,8 @@ func DeleteNode(tx *gorm.DB,
node *types.Node,
) error {
// Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error
if err != nil {
return err
}
@@ -343,9 +356,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
nodeID types.NodeID,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error
if err != nil {
return err
}
return nil
})
}
@@ -395,7 +410,8 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
// so we store the node.Expire and node.Nodekey that has been set when
// adding it to the registrationCache
if node.IPv4 != nil || node.IPv6 != nil {
if err := tx.Save(&node).Error; err != nil {
err := tx.Save(&node).Error
if err != nil {
return nil, fmt.Errorf("registering existing node in database: %w", err)
}
@@ -431,7 +447,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
node.GivenName = givenName
}
if err := tx.Save(&node).Error; err != nil {
if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr
return nil, fmt.Errorf("saving node to database: %w", err)
}
@@ -656,7 +672,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
panic("CreateNodeForTest requires a valid user")
}
nodeName := "testnode"
nodeName := defaultTestNodePrefix
if len(hostname) > 0 && hostname[0] != "" {
nodeName = hostname[0]
}
@@ -728,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname
panic("CreateNodesForTest requires a valid user")
}
prefix := "testnode"
prefix := defaultTestNodePrefix
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
prefix = hostnamePrefix[0]
}
@@ -751,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
panic("CreateRegisteredNodesForTest requires a valid user")
}
prefix := "testnode"
prefix := defaultTestNodePrefix
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
prefix = hostnamePrefix[0]
}

View File

@@ -187,6 +187,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
suppliedName string
randomSuffix bool
}
tests := []struct {
name string
args args
@@ -467,10 +468,10 @@ func TestAutoApproveRoutes(t *testing.T) {
require.NoError(t, err)
users, err := adb.ListUsers()
assert.NoError(t, err)
require.NoError(t, err)
nodes, err := adb.ListNodes()
assert.NoError(t, err)
require.NoError(t, err)
pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err)
@@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) {
if len(expectedRoutes1) == 0 {
expectedRoutes1 = nil
}
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
@@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) {
if len(expectedRoutes2) == 0 {
expectedRoutes2 = nil
}
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
@@ -520,6 +523,7 @@ func TestAutoApproveRoutes(t *testing.T) {
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
want := []types.NodeID{1, 3}
got := []types.NodeID{}
var mu sync.Mutex
deletionCount := make(chan struct{}, 10)
@@ -527,6 +531,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
mu.Lock()
defer mu.Unlock()
got = append(got, ni)
deletionCount <- struct{}{}
@@ -576,8 +581,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
}
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
var got []types.NodeID
var mu sync.Mutex
var (
got []types.NodeID
mu sync.Mutex
)
want := 1000
@@ -589,6 +596,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
// Yield to other goroutines to introduce variability
runtime.Gosched()
got = append(got, ni)
atomic.AddInt64(&deletedCount, 1)
@@ -616,9 +624,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
}
}
func generateRandomNumber(t *testing.T, max int64) int64 {
//nolint:unused
func generateRandomNumber(t *testing.T, maxVal int64) int64 {
t.Helper()
maxB := big.NewInt(max)
maxB := big.NewInt(maxVal)
n, err := rand.Int(rand.Reader, maxB)
if err != nil {
t.Fatalf("getting random number: %s", err)
@@ -722,7 +733,7 @@ func TestNodeNaming(t *testing.T) {
nodeInvalidHostname := types.Node{
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "我的电脑",
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
@@ -746,12 +757,15 @@ func TestNodeNaming(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
_, _ = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
return err
})
require.NoError(t, err)
@@ -810,25 +824,25 @@ func TestNodeNaming(t *testing.T) {
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "test")
})
assert.ErrorContains(t, err, "name is not unique")
require.ErrorContains(t, err, "name is not unique")
// Rename invalid chars
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[2].ID, "我的电脑")
return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data
})
assert.ErrorContains(t, err, "invalid characters")
require.ErrorContains(t, err, "invalid characters")
// Rename too short
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[3].ID, "a")
})
assert.ErrorContains(t, err, "at least 2 characters")
require.ErrorContains(t, err, "at least 2 characters")
// Rename with emoji
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
})
assert.ErrorContains(t, err, "invalid characters")
require.ErrorContains(t, err, "invalid characters")
// Rename with only emoji
err = db.Write(func(tx *gorm.DB) error {
@@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
},
{
name: "chinese_chars_with_dash_rejected",
newName: "server-北京-01",
newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
wantErr: "invalid characters",
},
{
name: "chinese_only_rejected",
newName: "我的电脑",
newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
wantErr: "invalid characters",
},
{
@@ -911,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
},
{
name: "mixed_chinese_emoji_rejected",
newName: "测试💻机器",
newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
wantErr: "invalid characters",
},
{
@@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
return err
@@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
return err

View File

@@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
Data: policy,
}
if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil {
err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error
if err != nil {
return nil, err
}

View File

@@ -138,7 +138,7 @@ func CreatePreAuthKey(
Hash: hash, // Store hash
}
if err := tx.Save(&key).Error; err != nil {
if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr
return nil, fmt.Errorf("creating key in database: %w", err)
}
@@ -155,9 +155,7 @@ func CreatePreAuthKey(
}
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeys(rx)
})
return Read(hsdb.DB, ListPreAuthKeys)
}
// ListPreAuthKeys returns all PreAuthKeys in the database.
@@ -329,10 +327,11 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
}
k.Used = true
return nil
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
// ExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error

View File

@@ -362,7 +362,8 @@ func (c *Config) Validate() error {
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
// compatible with modernc.org/sqlite driver.
func (c *Config) ToURL() (string, error) {
if err := c.Validate(); err != nil {
err := c.Validate()
if err != nil {
return "", fmt.Errorf("invalid config: %w", err)
}
@@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) {
if c.BusyTimeout > 0 {
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
}
if c.JournalMode != "" {
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
}
if c.AutoVacuum != "" {
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
}
if c.WALAutocheckpoint >= 0 {
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
}
if c.Synchronous != "" {
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
}
if c.ForeignKeys {
pragmas = append(pragmas, "foreign_keys=ON")
}

View File

@@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
t.Errorf("Config.ToURL() error = %v", err)
return
}
if got != tt.want {
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
}
@@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
Path: "",
BusyTimeout: -1,
}
_, err := config.ToURL()
if err == nil {
t.Error("Config.ToURL() with invalid config should return error")

View File

@@ -1,6 +1,7 @@
package sqliteconfig
import (
"context"
"database/sql"
"path/filepath"
"strings"
@@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
defer db.Close()
// Test connection
if err := db.Ping(); err != nil {
ctx := context.Background()
err = db.PingContext(ctx)
if err != nil {
t.Fatalf("Failed to ping database: %v", err)
}
@@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
for pragma, expectedValue := range tt.expected {
t.Run("pragma_"+pragma, func(t *testing.T) {
var actualValue any
query := "PRAGMA " + pragma
err := db.QueryRow(query).Scan(&actualValue)
err := db.QueryRowContext(ctx, query).Scan(&actualValue)
if err != nil {
t.Fatalf("Failed to query %s: %v", query, err)
}
@@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
}
defer db.Close()
ctx := context.Background()
// Create test tables with foreign key relationship
schema := `
CREATE TABLE parent (
@@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
);
`
if _, err := db.Exec(schema); err != nil {
_, err = db.ExecContext(ctx, schema)
if err != nil {
t.Fatalf("Failed to create schema: %v", err)
}
// Insert parent record
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
_, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')")
if err != nil {
t.Fatalf("Failed to insert parent: %v", err)
}
// Test 1: Valid foreign key should work
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
if err != nil {
t.Fatalf("Valid foreign key insert failed: %v", err)
}
// Test 2: Invalid foreign key should fail
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
if err == nil {
t.Error("Expected foreign key constraint violation, but insert succeeded")
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
@@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
}
// Test 3: Deleting referenced parent should fail
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
_, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1")
if err == nil {
t.Error("Expected foreign key constraint violation when deleting referenced parent")
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
@@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) {
defer db.Close()
var actualMode string
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
if err != nil {
t.Fatalf("Failed to query journal_mode: %v", err)
}

View File

@@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL {
t.Helper()
ctx := t.Context()
srv, err := postgrestest.Start(ctx)
if err != nil {
t.Fatal(err)
}
t.Cleanup(srv.Cleanup)
u, err := srv.CreateDatabase(ctx)
if err != nil {
t.Fatal(err)
}
t.Logf("created local postgres: %s", u)
pu, _ := url.Parse(u)

View File

@@ -3,12 +3,19 @@ package db
import (
"context"
"encoding"
"errors"
"fmt"
"reflect"
"gorm.io/gorm/schema"
)
var (
errUnmarshalTextValue = errors.New("unmarshalling text value")
errUnsupportedType = errors.New("unsupported type")
errTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
)
// Got from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
@@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("unmarshalling text value: %#v", dbValue)
return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue)
}
if isTextUnmarshaler(fieldValue) {
maybeInstantiatePtr(fieldValue)
f := fieldValue.MethodByName("UnmarshalText")
args := []reflect.Value{reflect.ValueOf(bytes)}
ret := f.Call(args)
if !ret[0].IsNil() {
return decodingError(field.Name, ret[0].Interface().(error))
if err, ok := ret[0].Interface().(error); ok {
return decodingError(field.Name, err)
}
}
// If the underlying field is to a pointer type, we need to
@@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
return nil
} else {
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface())
}
}
@@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
// always comparable, particularly when reflection is involved:
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
return nil, nil
return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer
}
b, err := v.MarshalText()
if err != nil {
return nil, err
@@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
return string(b), nil
default:
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v)
}
}

View File

@@ -12,9 +12,11 @@ import (
)
var (
ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
ErrUserNotUnique = errors.New("expected exactly one user")
)
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
@@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
if err := util.ValidateHostname(user.Name); err != nil {
err := util.ValidateHostname(user.Name)
if err != nil {
return nil, err
}
if err := tx.Create(&user).Error; err != nil {
err = tx.Create(&user).Error
if err != nil {
return nil, fmt.Errorf("creating user: %w", err)
}
@@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
if err != nil {
return err
}
if len(nodes) > 0 {
return ErrUserStillHasNodes
}
@@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
if err != nil {
return err
}
for _, key := range keys {
err = DestroyPreAuthKey(tx, key.ID)
if err != nil {
@@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
// not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
var err error
oldUser, err := GetUserByID(tx, uid)
if err != nil {
return err
}
if err = util.ValidateHostname(newName); err != nil {
if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
return err
}
@@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
// ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
if len(where) > 1 {
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
}
var user *types.User
@@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
}
users := []types.User{}
if err := tx.Where(user).Find(&users).Error; err != nil {
err := tx.Where(user).Find(&users).Error
if err != nil {
return nil, err
}
@@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
}
if len(users) != 1 {
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
}
return &users[0], nil