mirror of
https://github.com/juanfont/headscale.git
synced 2026-03-19 16:21:23 +01:00
Tagged nodes are owned by their tags, not a user. Enforce this invariant at every write path:

- createAndSaveNewNode: do not set UserID for tagged PreAuthKey registration; clear UserID when advertise-tags are applied during OIDC/CLI registration
- SetNodeTags: clear UserID/User when tags are assigned
- processReauthTags: clear UserID/User when tags are applied during re-authentication
- validateNodeOwnership: reject tagged nodes with non-nil UserID
- NodeStore: skip nodesByUser indexing for tagged nodes since they have no owning user
- HandleNodeFromPreAuthKey: add fallback lookup for tagged PAK re-registration (tagged nodes indexed under UserID(0)); guard against nil User deref for tagged nodes in different-user check

Since tagged nodes now have user_id = NULL, ListNodesByUser will not return them and DestroyUser naturally allows deleting users whose nodes have all been tagged. The ON DELETE CASCADE FK cannot reach tagged nodes through a NULL foreign key.

Also tone down shouty comments throughout state.go.

Fixes #3077
251 lines
5.6 KiB
Go
251 lines
5.6 KiB
Go
package db
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
var (
|
|
ErrUserExists = errors.New("user already exists")
|
|
ErrUserNotFound = errors.New("user not found")
|
|
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
|
ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
|
|
ErrUserNotUnique = errors.New("expected exactly one user")
|
|
)
|
|
|
|
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
|
|
return CreateUser(tx, user)
|
|
})
|
|
}
|
|
|
|
// CreateUser creates a new User. Returns error if could not be created
|
|
// or another user already exists.
|
|
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
|
|
err := util.ValidateHostname(user.Name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
err = tx.Create(&user).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating user: %w", err)
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error {
|
|
return hsdb.Write(func(tx *gorm.DB) error {
|
|
return DestroyUser(tx, uid)
|
|
})
|
|
}
|
|
|
|
// DestroyUser destroys a User. Returns error if the User does
|
|
// not exist or if there are user-owned nodes associated with it.
|
|
// Tagged nodes have user_id = NULL so they do not block deletion.
|
|
func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
|
user, err := GetUserByID(tx, uid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
nodes, err := ListNodesByUser(tx, uid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(nodes) > 0 {
|
|
return ErrUserStillHasNodes
|
|
}
|
|
|
|
keys, err := ListPreAuthKeys(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, key := range keys {
|
|
err = DestroyPreAuthKey(tx, key.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if result := tx.Unscoped().Delete(&user); result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error {
|
|
return hsdb.Write(func(tx *gorm.DB) error {
|
|
return RenameUser(tx, uid, newName)
|
|
})
|
|
}
|
|
|
|
// ErrCannotChangeOIDCUser is returned by RenameUser when the user's Provider
// is util.RegisterMethodOIDC.
var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
|
|
|
// RenameUser renames a User. Returns error if the User does
|
|
// not exist or if another User exists with the new name.
|
|
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
|
var err error
|
|
|
|
oldUser, err := GetUserByID(tx, uid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
|
|
return err
|
|
}
|
|
|
|
if oldUser.Provider == util.RegisterMethodOIDC {
|
|
return ErrCannotChangeOIDCUser
|
|
}
|
|
|
|
oldUser.Name = newName
|
|
|
|
err = tx.Updates(&oldUser).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
|
return GetUserByID(hsdb.DB, uid)
|
|
}
|
|
|
|
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
|
user := types.User{}
|
|
if result := tx.First(&user, "id = ?", uid); errors.Is(
|
|
result.Error,
|
|
gorm.ErrRecordNotFound,
|
|
) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) GetUserByOIDCIdentifier(id string) (*types.User, error) {
|
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
|
return GetUserByOIDCIdentifier(rx, id)
|
|
})
|
|
}
|
|
|
|
func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
|
user := types.User{}
|
|
if result := tx.First(&user, "provider_identifier = ?", id); errors.Is(
|
|
result.Error,
|
|
gorm.ErrRecordNotFound,
|
|
) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
|
return ListUsers(hsdb.DB, where...)
|
|
}
|
|
|
|
// ListUsers gets all the existing users.
|
|
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
|
if len(where) > 1 {
|
|
return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
|
|
}
|
|
|
|
var user *types.User
|
|
if len(where) == 1 {
|
|
user = where[0]
|
|
}
|
|
|
|
users := []types.User{}
|
|
|
|
err := tx.Where(user).Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// GetUserByName returns a user if the provided username is
|
|
// unique, and otherwise an error.
|
|
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
|
users, err := hsdb.ListUsers(&types.User{Name: name})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(users) == 0 {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
|
|
if len(users) != 1 {
|
|
return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
|
|
}
|
|
|
|
return &users[0], nil
|
|
}
|
|
|
|
// ListNodesByUser gets all the nodes in a given user.
|
|
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
|
|
nodes := types.Nodes{}
|
|
|
|
uidPtr := uint(uid)
|
|
|
|
err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: &uidPtr}).Find(&nodes).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return nodes, nil
|
|
}
|
|
|
|
func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
|
|
if !testing.Testing() {
|
|
panic("CreateUserForTest can only be called during tests")
|
|
}
|
|
|
|
userName := "testuser"
|
|
if len(name) > 0 && name[0] != "" {
|
|
userName = name[0]
|
|
}
|
|
|
|
user, err := hsdb.CreateUser(types.User{Name: userName})
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to create test user: %v", err))
|
|
}
|
|
|
|
return user
|
|
}
|
|
|
|
func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
|
|
if !testing.Testing() {
|
|
panic("CreateUsersForTest can only be called during tests")
|
|
}
|
|
|
|
prefix := "testuser"
|
|
if len(namePrefix) > 0 && namePrefix[0] != "" {
|
|
prefix = namePrefix[0]
|
|
}
|
|
|
|
users := make([]*types.User, count)
|
|
for i := range count {
|
|
name := prefix + "-" + strconv.Itoa(i)
|
|
users[i] = hsdb.CreateUserForTest(name)
|
|
}
|
|
|
|
return users
|
|
}
|