mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-24 17:48:49 +02:00
all: fix golangci-lint issues (#3064)
This commit is contained in:
@@ -77,7 +77,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
if err := hsdb.DB.Save(&key).Error; err != nil {
|
||||
if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr
|
||||
return "", nil, fmt.Errorf("saving API key to database: %w", err)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||
keys := []types.APIKey{}
|
||||
if err := hsdb.DB.Find(&keys).Error; err != nil {
|
||||
|
||||
err := hsdb.DB.Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
||||
|
||||
// ExpireAPIKey marks a ApiKey as expired.
|
||||
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||
err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ type HSDatabase struct {
|
||||
|
||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||
// It accepts the full configuration to allow migrations access to policy settings.
|
||||
//
|
||||
//nolint:gocyclo // complex database initialization with many migrations
|
||||
func NewHeadscaleDatabase(
|
||||
cfg *types.Config,
|
||||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
|
||||
@@ -76,7 +78,7 @@ func NewHeadscaleDatabase(
|
||||
ID: "202501221827",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Remove any invalid routes associated with a node that does not exist.
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -84,14 +86,14 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
|
||||
// Remove any invalid routes without a node_id.
|
||||
if tx.Migrator().HasTable(&types.Route{}) {
|
||||
if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
err := tx.Exec("delete from routes where node_id is null").Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := tx.AutoMigrate(&types.Route{})
|
||||
err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Route: %w", err)
|
||||
}
|
||||
@@ -109,6 +111,7 @@ func NewHeadscaleDatabase(
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
|
||||
}
|
||||
|
||||
err = tx.AutoMigrate(&types.Node{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Node: %w", err)
|
||||
@@ -155,7 +158,8 @@ AND auth_key_id NOT IN (
|
||||
|
||||
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||
|
||||
var routes []types.Route
|
||||
var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
|
||||
err = tx.Find(&routes).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching routes: %w", err)
|
||||
@@ -171,7 +175,7 @@ AND auth_key_id NOT IN (
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
|
||||
data, err := json.Marshal(routes)
|
||||
data, _ := json.Marshal(routes)
|
||||
|
||||
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
||||
if err != nil {
|
||||
@@ -180,7 +184,7 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
// Drop the old table.
|
||||
_ = tx.Migrator().DropTable(&types.Route{})
|
||||
_ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
|
||||
|
||||
return nil
|
||||
},
|
||||
@@ -256,10 +260,13 @@ AND auth_key_id NOT IN (
|
||||
|
||||
// Check if routes table exists and drop it (should have been migrated already)
|
||||
var routesExists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
||||
if err == nil && routesExists {
|
||||
log.Info().Msg("dropping leftover routes table")
|
||||
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
|
||||
|
||||
err := tx.Exec("DROP TABLE routes").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("dropping routes table: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -281,6 +288,7 @@ AND auth_key_id NOT IN (
|
||||
for _, table := range tablesToRename {
|
||||
// Check if table exists before renaming
|
||||
var exists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
||||
@@ -291,7 +299,8 @@ AND auth_key_id NOT IN (
|
||||
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
|
||||
// Rename current table to _old
|
||||
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
|
||||
}
|
||||
}
|
||||
@@ -365,7 +374,8 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, createSQL := range tableCreationSQL {
|
||||
if err := tx.Exec(createSQL).Error; err != nil {
|
||||
err := tx.Exec(createSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating new table: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -394,7 +404,8 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, copySQL := range dataCopySQL {
|
||||
if err := tx.Exec(copySQL).Error; err != nil {
|
||||
err := tx.Exec(copySQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("copying data: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -417,14 +428,16 @@ AND auth_key_id NOT IN (
|
||||
}
|
||||
|
||||
for _, indexSQL := range indexes {
|
||||
if err := tx.Exec(indexSQL).Error; err != nil {
|
||||
err := tx.Exec(indexSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Drop old tables only after everything succeeds
|
||||
for _, table := range tablesToRename {
|
||||
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
if err != nil {
|
||||
log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded")
|
||||
}
|
||||
}
|
||||
@@ -760,6 +773,7 @@ AND auth_key_id NOT IN (
|
||||
|
||||
// or else it blocks...
|
||||
sqlConn.SetMaxIdleConns(maxIdleConns)
|
||||
|
||||
sqlConn.SetMaxOpenConns(maxOpenConns)
|
||||
defer sqlConn.SetMaxIdleConns(1)
|
||||
defer sqlConn.SetMaxOpenConns(1)
|
||||
@@ -777,7 +791,7 @@ AND auth_key_id NOT IN (
|
||||
},
|
||||
}
|
||||
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("validating schema: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -803,6 +817,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
switch cfg.Type {
|
||||
case types.DatabaseSqlite:
|
||||
dir := filepath.Dir(cfg.Sqlite.Path)
|
||||
|
||||
err := util.EnsureDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating directory for sqlite: %w", err)
|
||||
@@ -856,7 +871,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||
Str("path", dbString).
|
||||
Msg("Opening database")
|
||||
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
|
||||
if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr
|
||||
if !sslEnabled {
|
||||
dbString += " sslmode=disable"
|
||||
}
|
||||
@@ -911,7 +926,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
|
||||
// Get the current foreign key status
|
||||
var fkOriginallyEnabled int
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("checking foreign key status: %w", err)
|
||||
}
|
||||
|
||||
@@ -940,28 +955,31 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
|
||||
if needsFKDisabled {
|
||||
// Disable foreign keys for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
} else {
|
||||
// Ensure foreign keys are enabled for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run up to this specific migration (will only run the next pending migration)
|
||||
if err := migrations.MigrateTo(migrationID); err != nil {
|
||||
err := migrations.MigrateTo(migrationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("restoring foreign keys: %w", err)
|
||||
}
|
||||
|
||||
// Run the rest of the migrations
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
if err := migrations.Migrate(); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -979,16 +997,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var violation constraintViolation
|
||||
if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil {
|
||||
|
||||
err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
violatedConstraints = append(violatedConstraints, violation)
|
||||
}
|
||||
_ = rows.Close()
|
||||
|
||||
if err := rows.Err(); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
if len(violatedConstraints) > 0 {
|
||||
for _, violation := range violatedConstraints {
|
||||
@@ -1003,7 +1027,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
}
|
||||
} else {
|
||||
// PostgreSQL can run all migrations in one block - no foreign key issues
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
err := migrations.Migrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1014,6 +1039,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
|
||||
sqlDB, err := hsdb.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1029,7 +1055,7 @@ func (hsdb *HSDatabase) Close() error {
|
||||
}
|
||||
|
||||
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
||||
db.Exec("VACUUM")
|
||||
db.Exec("VACUUM") //nolint:errcheck,noctx
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
@@ -1038,12 +1064,14 @@ func (hsdb *HSDatabase) Close() error {
|
||||
func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
|
||||
rx := hsdb.DB.Begin()
|
||||
defer rx.Rollback()
|
||||
|
||||
return fn(rx)
|
||||
}
|
||||
|
||||
func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
rx := db.Begin()
|
||||
defer rx.Rollback()
|
||||
|
||||
ret, err := fn(rx)
|
||||
if err != nil {
|
||||
var no T
|
||||
@@ -1056,7 +1084,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
|
||||
func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
tx := hsdb.DB.Begin()
|
||||
defer tx.Rollback()
|
||||
if err := fn(tx); err != nil {
|
||||
|
||||
err := fn(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1066,6 +1096,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
|
||||
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
|
||||
tx := db.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
ret, err := fn(tx)
|
||||
if err != nil {
|
||||
var no T
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
|
||||
// Verify api_keys data preservation
|
||||
var apiKeyCount int
|
||||
|
||||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||||
@@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(schemaContent))
|
||||
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
||||
func requireConstraintFailed(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||
}
|
||||
@@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "no-duplicate-username-if-no-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
_, err := CreateUser(db, types.User{Name: "user1"})
|
||||
require.NoError(t, err)
|
||||
_, err = CreateUser(db, types.User{Name: "user1"})
|
||||
@@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-username-and-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
@@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-oidc-duplicate-id",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
@@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-cli-then-oidc",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
_, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "allow-duplicate-username-oidc-then-cli",
|
||||
run: func(t *testing.T, db *gorm.DB) {
|
||||
run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
|
||||
user := types.User{
|
||||
Name: "user1",
|
||||
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||||
@@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Construct the pg_restore command
|
||||
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
|
||||
// Set the output streams
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
||||
// skip already-applied migrations and only run new ones.
|
||||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemas, err := os.ReadDir("testdata/sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Basic deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var deletionWg sync.WaitGroup
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
deletionWg sync.WaitGroup
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
deletionWg.Done()
|
||||
}
|
||||
@@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
go gc.Start()
|
||||
|
||||
// Schedule several nodes for deletion with short expiry
|
||||
const expiry = fifty
|
||||
const numNodes = 100
|
||||
const (
|
||||
expiry = fifty
|
||||
numNodes = 100
|
||||
)
|
||||
|
||||
// Set up wait group for expected deletions
|
||||
|
||||
deletionWg.Add(numNodes)
|
||||
|
||||
for i := 1; i <= numNodes; i++ {
|
||||
gc.Schedule(types.NodeID(i), expiry)
|
||||
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test
|
||||
}
|
||||
|
||||
// Wait for all scheduled deletions to complete
|
||||
@@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
|
||||
// Schedule and immediately cancel to test that part of the code
|
||||
for i := numNodes + 1; i <= numNodes*2; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test
|
||||
gc.Schedule(nodeID, time.Hour)
|
||||
gc.Cancel(nodeID)
|
||||
}
|
||||
@@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
||||
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
@@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
|
||||
// Start GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
const shortExpiry = fifty
|
||||
const longExpiry = 1 * time.Hour
|
||||
const (
|
||||
shortExpiry = fifty
|
||||
longExpiry = 1 * time.Hour
|
||||
)
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
@@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
}
|
||||
|
||||
// Start the GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
const expiry = fifty
|
||||
|
||||
// Schedule node for deletion
|
||||
@@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
||||
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
@@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
nodeDeleted := make(chan struct{})
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
close(nodeDeleted) // Signal that deletion happened
|
||||
}
|
||||
@@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
// Use a WaitGroup to ensure the GC has started
|
||||
var startWg sync.WaitGroup
|
||||
startWg.Add(1)
|
||||
|
||||
go func() {
|
||||
startWg.Done() // Signal that the goroutine has started
|
||||
gc.Start()
|
||||
}()
|
||||
|
||||
startWg.Wait() // Wait for the GC to start
|
||||
|
||||
// Close GC right away
|
||||
@@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
||||
|
||||
// Check no node was deleted
|
||||
deleteMutex.Lock()
|
||||
|
||||
nodesDeleted := len(deletedIDs)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
||||
|
||||
@@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
go gc.Start()
|
||||
|
||||
// Number of concurrent scheduling goroutines
|
||||
const numSchedulers = 10
|
||||
const nodesPerScheduler = 50
|
||||
const (
|
||||
numSchedulers = 10
|
||||
nodesPerScheduler = 50
|
||||
)
|
||||
|
||||
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
||||
|
||||
@@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
||||
case <-stopScheduling:
|
||||
return
|
||||
default:
|
||||
nodeID := types.NodeID(baseNodeID + j + 1)
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
atomic.AddInt64(&scheduledCount, 1)
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
|
||||
@@ -17,7 +17,11 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
|
||||
var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
|
||||
var (
|
||||
errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
|
||||
errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix")
|
||||
errIPAllocatorNil = errors.New("ip allocator was nil")
|
||||
)
|
||||
|
||||
// IPAllocator is a singleton responsible for allocating
|
||||
// IP addresses for nodes and making sure the same
|
||||
@@ -62,8 +66,10 @@ func NewIPAllocator(
|
||||
strategy: strategy,
|
||||
}
|
||||
|
||||
var v4s []sql.NullString
|
||||
var v6s []sql.NullString
|
||||
var (
|
||||
v4s []sql.NullString
|
||||
v6s []sql.NullString
|
||||
)
|
||||
|
||||
if db != nil {
|
||||
err := db.Read(func(rx *gorm.DB) error {
|
||||
@@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
var err error
|
||||
var ret4 *netip.Addr
|
||||
var ret6 *netip.Addr
|
||||
var (
|
||||
err error
|
||||
ret4 *netip.Addr
|
||||
ret6 *netip.Addr
|
||||
)
|
||||
|
||||
if i.prefix4 != nil {
|
||||
ret4, err = i.next(i.prev4, i.prefix4)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
|
||||
}
|
||||
|
||||
i.prev4 = *ret4
|
||||
}
|
||||
|
||||
@@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
|
||||
}
|
||||
|
||||
i.prev6 = *ret6
|
||||
}
|
||||
|
||||
@@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.
|
||||
}
|
||||
|
||||
func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
|
||||
var err error
|
||||
var ip netip.Addr
|
||||
var (
|
||||
err error
|
||||
ip netip.Addr
|
||||
)
|
||||
|
||||
switch i.strategy {
|
||||
case types.IPAllocationStrategySequential:
|
||||
@@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {
|
||||
|
||||
if !pfx.Contains(ip) {
|
||||
return netip.Addr{}, fmt.Errorf(
|
||||
"generated ip(%s) not in prefix(%s)",
|
||||
"%w: ip(%s) not in prefix(%s)",
|
||||
errGeneratedIPNotInPrefix,
|
||||
ip.String(),
|
||||
pfx.String(),
|
||||
)
|
||||
@@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool {
|
||||
// If a prefix type has been removed (IPv4 or IPv6), it
|
||||
// will remove the IPs in that family from the node.
|
||||
func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
var err error
|
||||
var ret []string
|
||||
var (
|
||||
err error
|
||||
ret []string
|
||||
)
|
||||
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
if i == nil {
|
||||
return errors.New("backfilling IPs: ip allocator was nil")
|
||||
return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil)
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("starting to backfill IPs")
|
||||
@@ -295,6 +311,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
|
||||
node.IPv4 = ret4
|
||||
changed = true
|
||||
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
@@ -307,6 +324,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
|
||||
|
||||
node.IPv6 = ret6
|
||||
changed = true
|
||||
|
||||
ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
|
||||
}
|
||||
|
||||
|
||||
@@ -21,9 +21,7 @@ var mpp = func(pref string) *netip.Prefix {
|
||||
return &p
|
||||
}
|
||||
|
||||
var na = func(pref string) netip.Addr {
|
||||
return netip.MustParseAddr(pref)
|
||||
}
|
||||
var na = netip.MustParseAddr
|
||||
|
||||
var nap = func(pref string) *netip.Addr {
|
||||
n := na(pref)
|
||||
@@ -158,8 +156,10 @@ func TestIPAllocatorSequential(t *testing.T) {
|
||||
types.IPAllocationStrategySequential,
|
||||
)
|
||||
|
||||
var got4s []netip.Addr
|
||||
var got6s []netip.Addr
|
||||
var (
|
||||
got4s []netip.Addr
|
||||
got6s []netip.Addr
|
||||
)
|
||||
|
||||
for range tt.getCount {
|
||||
got4, got6, err := alloc.Next()
|
||||
@@ -175,6 +175,7 @@ func TestIPAllocatorSequential(t *testing.T) {
|
||||
got6s = append(got6s, *got6)
|
||||
}
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
|
||||
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -288,6 +289,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||
fullNodeP := func(i int) *types.Node {
|
||||
v4 := fmt.Sprintf("100.64.0.%d", i)
|
||||
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
|
||||
|
||||
return &types.Node{
|
||||
IPv4: nap(v4),
|
||||
IPv6: nap(v6),
|
||||
@@ -484,6 +486,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer db.Close()
|
||||
|
||||
alloc, err := NewIPAllocator(
|
||||
|
||||
@@ -27,8 +27,14 @@ import (
|
||||
const (
|
||||
NodeGivenNameHashLength = 8
|
||||
NodeGivenNameTrimSize = 2
|
||||
|
||||
// defaultTestNodePrefix is the default hostname prefix for nodes created in tests.
|
||||
defaultTestNodePrefix = "testnode"
|
||||
)
|
||||
|
||||
// ErrNodeNameNotUnique is returned when a node name is not unique.
|
||||
var ErrNodeNameNotUnique = errors.New("node name is not unique")
|
||||
|
||||
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
|
||||
var (
|
||||
@@ -52,12 +58,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID)
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
|
||||
err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Where("id <> ?", nodeID).
|
||||
Where(peerIDs).Find(&nodes).Error; err != nil {
|
||||
Where(peerIDs).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return types.Nodes{}, err
|
||||
}
|
||||
|
||||
@@ -76,11 +84,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error)
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
|
||||
err := tx.
|
||||
Preload("AuthKey").
|
||||
Preload("AuthKey.User").
|
||||
Preload("User").
|
||||
Where(nodeIDs).Find(&nodes).Error; err != nil {
|
||||
Where(nodeIDs).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -90,7 +100,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
|
||||
|
||||
err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -222,7 +234,7 @@ func SetTags(
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
// SetApprovedRoutes takes a Node struct pointer and updates the approved routes.
|
||||
func SetApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
@@ -254,7 +266,7 @@ func SetApprovedRoutes(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("updating approved routes: %w", err)
|
||||
}
|
||||
|
||||
@@ -294,10 +306,10 @@ func RenameNode(tx *gorm.DB,
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("name is not unique")
|
||||
return ErrNodeNameNotUnique
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr
|
||||
return fmt.Errorf("renaming node in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -329,7 +341,8 @@ func DeleteNode(tx *gorm.DB,
|
||||
node *types.Node,
|
||||
) error {
|
||||
// Unscoped causes the node to be fully removed from the database.
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
|
||||
err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -343,9 +356,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
|
||||
nodeID types.NodeID,
|
||||
) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
|
||||
err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -395,7 +410,8 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
// so we store the node.Expire and node.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if node.IPv4 != nil || node.IPv6 != nil {
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
err := tx.Save(&node).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registering existing node in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -431,7 +447,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
||||
node.GivenName = givenName
|
||||
}
|
||||
|
||||
if err := tx.Save(&node).Error; err != nil {
|
||||
if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("saving node to database: %w", err)
|
||||
}
|
||||
|
||||
@@ -656,7 +672,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
||||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
nodeName := defaultTestNodePrefix
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
@@ -728,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname
|
||||
panic("CreateNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
prefix := defaultTestNodePrefix
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
@@ -751,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
|
||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
prefix := defaultTestNodePrefix
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
@@ -187,6 +187,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
suppliedName string
|
||||
randomSuffix bool
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -467,10 +468,10 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
@@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -520,6 +523,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
want := []types.NodeID{1, 3}
|
||||
got := []types.NodeID{}
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
deletionCount := make(chan struct{}, 10)
|
||||
@@ -527,6 +531,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
got = append(got, ni)
|
||||
|
||||
deletionCount <- struct{}{}
|
||||
@@ -576,8 +581,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
var got []types.NodeID
|
||||
var mu sync.Mutex
|
||||
var (
|
||||
got []types.NodeID
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
want := 1000
|
||||
|
||||
@@ -589,6 +596,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
runtime.Gosched()
|
||||
|
||||
got = append(got, ni)
|
||||
|
||||
atomic.AddInt64(&deletedCount, 1)
|
||||
@@ -616,9 +624,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateRandomNumber(t *testing.T, max int64) int64 {
|
||||
//nolint:unused
|
||||
func generateRandomNumber(t *testing.T, maxVal int64) int64 {
|
||||
t.Helper()
|
||||
maxB := big.NewInt(max)
|
||||
|
||||
maxB := big.NewInt(maxVal)
|
||||
|
||||
n, err := rand.Int(rand.Reader, maxB)
|
||||
if err != nil {
|
||||
t.Fatalf("getting random number: %s", err)
|
||||
@@ -722,7 +733,7 @@ func TestNodeNaming(t *testing.T) {
|
||||
nodeInvalidHostname := types.Node{
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
UserID: &user2.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
@@ -746,12 +757,15 @@ func TestNodeNaming(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
|
||||
_, _ = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -810,25 +824,25 @@ func TestNodeNaming(t *testing.T) {
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "test")
|
||||
})
|
||||
assert.ErrorContains(t, err, "name is not unique")
|
||||
require.ErrorContains(t, err, "name is not unique")
|
||||
|
||||
// Rename invalid chars
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑")
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename too short
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[3].ID, "a")
|
||||
})
|
||||
assert.ErrorContains(t, err, "at least 2 characters")
|
||||
require.ErrorContains(t, err, "at least 2 characters")
|
||||
|
||||
// Rename with emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename with only emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
@@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "chinese_chars_with_dash_rejected",
|
||||
newName: "server-北京-01",
|
||||
newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "chinese_only_rejected",
|
||||
newName: "我的电脑",
|
||||
newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
@@ -911,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "mixed_chinese_emoji_rejected",
|
||||
newName: "测试💻机器",
|
||||
newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
@@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
@@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
||||
@@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
|
||||
Data: policy,
|
||||
}
|
||||
|
||||
if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil {
|
||||
err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ func CreatePreAuthKey(
|
||||
Hash: hash, // Store hash
|
||||
}
|
||||
|
||||
if err := tx.Save(&key).Error; err != nil {
|
||||
if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr
|
||||
return nil, fmt.Errorf("creating key in database: %w", err)
|
||||
}
|
||||
|
||||
@@ -155,9 +155,7 @@ func CreatePreAuthKey(
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
return ListPreAuthKeys(rx)
|
||||
})
|
||||
return Read(hsdb.DB, ListPreAuthKeys)
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns all PreAuthKeys in the database.
|
||||
@@ -329,10 +327,11 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
}
|
||||
|
||||
k.Used = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
// ExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
||||
|
||||
@@ -362,7 +362,8 @@ func (c *Config) Validate() error {
|
||||
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
|
||||
// compatible with modernc.org/sqlite driver.
|
||||
func (c *Config) ToURL() (string, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
err := c.Validate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
@@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) {
|
||||
if c.BusyTimeout > 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
|
||||
}
|
||||
|
||||
if c.JournalMode != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
|
||||
}
|
||||
|
||||
if c.AutoVacuum != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
|
||||
}
|
||||
|
||||
if c.WALAutocheckpoint >= 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
|
||||
}
|
||||
|
||||
if c.Synchronous != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
|
||||
}
|
||||
|
||||
if c.ForeignKeys {
|
||||
pragmas = append(pragmas, "foreign_keys=ON")
|
||||
}
|
||||
|
||||
@@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
|
||||
t.Errorf("Config.ToURL() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
@@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
|
||||
Path: "",
|
||||
BusyTimeout: -1,
|
||||
}
|
||||
|
||||
_, err := config.ToURL()
|
||||
if err == nil {
|
||||
t.Error("Config.ToURL() with invalid config should return error")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqliteconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
ctx := context.Background()
|
||||
|
||||
err = db.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
@@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
||||
for pragma, expectedValue := range tt.expected {
|
||||
t.Run("pragma_"+pragma, func(t *testing.T) {
|
||||
var actualValue any
|
||||
|
||||
query := "PRAGMA " + pragma
|
||||
err := db.QueryRow(query).Scan(&actualValue)
|
||||
|
||||
err := db.QueryRowContext(ctx, query).Scan(&actualValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query %s: %v", query, err)
|
||||
}
|
||||
@@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test tables with foreign key relationship
|
||||
schema := `
|
||||
CREATE TABLE parent (
|
||||
@@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
_, err = db.ExecContext(ctx, schema)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create schema: %v", err)
|
||||
}
|
||||
|
||||
// Insert parent record
|
||||
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert parent: %v", err)
|
||||
}
|
||||
|
||||
// Test 1: Valid foreign key should work
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Valid foreign key insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Test 2: Invalid foreign key should fail
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
@@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test 3: Deleting referenced parent should fail
|
||||
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
|
||||
_, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
@@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
var actualMode string
|
||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
||||
|
||||
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||
}
|
||||
|
||||
@@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL {
|
||||
t.Helper()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
srv, err := postgrestest.Start(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Cleanup(srv.Cleanup)
|
||||
|
||||
u, err := srv.CreateDatabase(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("created local postgres: %s", u)
|
||||
pu, _ := url.Parse(u)
|
||||
|
||||
|
||||
@@ -3,12 +3,19 @@ package db
|
||||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnmarshalTextValue = errors.New("unmarshalling text value")
|
||||
errUnsupportedType = errors.New("unsupported type")
|
||||
errTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
|
||||
)
|
||||
|
||||
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
||||
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
|
||||
|
||||
@@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("unmarshalling text value: %#v", dbValue)
|
||||
return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue)
|
||||
}
|
||||
|
||||
if isTextUnmarshaler(fieldValue) {
|
||||
maybeInstantiatePtr(fieldValue)
|
||||
f := fieldValue.MethodByName("UnmarshalText")
|
||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||
|
||||
ret := f.Call(args)
|
||||
if !ret[0].IsNil() {
|
||||
return decodingError(field.Name, ret[0].Interface().(error))
|
||||
if err, ok := ret[0].Interface().(error); ok {
|
||||
return decodingError(field.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If the underlying field is to a pointer type, we need to
|
||||
@@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||||
return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
||||
// always comparable, particularly when reflection is involved:
|
||||
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer
|
||||
}
|
||||
|
||||
b, err := v.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
||||
|
||||
return string(b), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||||
return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
|
||||
ErrUserNotUnique = errors.New("expected exactly one user")
|
||||
)
|
||||
|
||||
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
||||
@@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
|
||||
if err := util.ValidateHostname(user.Name); err != nil {
|
||||
err := util.ValidateHostname(user.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
|
||||
err = tx.Create(&user).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
||||
@@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
return ErrUserStillHasNodes
|
||||
}
|
||||
@@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
err = DestroyPreAuthKey(tx, key.ID)
|
||||
if err != nil {
|
||||
@@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
||||
// not exist or if another User exists with the new name.
|
||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
var err error
|
||||
|
||||
oldUser, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = util.ValidateHostname(newName); err != nil {
|
||||
|
||||
if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||
// ListUsers gets all the existing users.
|
||||
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
if len(where) > 1 {
|
||||
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
||||
return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
|
||||
}
|
||||
|
||||
var user *types.User
|
||||
@@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
}
|
||||
|
||||
users := []types.User{}
|
||||
if err := tx.Where(user).Find(&users).Error; err != nil {
|
||||
|
||||
err := tx.Where(user).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||
}
|
||||
|
||||
if len(users) != 1 {
|
||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||
return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
|
||||
}
|
||||
|
||||
return &users[0], nil
|
||||
|
||||
Reference in New Issue
Block a user