cli,hscontrol: use ID-based preauthkey operations

This commit is contained in:
Kristoffer Dalby
2026-01-07 13:36:51 +01:00
committed by Kristoffer Dalby
parent 8631581852
commit 1325fd8b27
6 changed files with 43 additions and 112 deletions

View File

@@ -57,10 +57,6 @@ func CreatePreAuthKey(
return nil, ErrPreAuthKeyNotTaggedOrOwned
}
// If uid != nil && len(aclTags) > 0:
// Both are allowed: UserID tracks "created by", tags define node ownership
// This is valid per the new model
var (
user *types.User
userID *uint
@@ -158,22 +154,17 @@ func CreatePreAuthKey(
}, nil
}
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeysByUser(rx, uid)
return ListPreAuthKeys(rx)
})
}
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
user, err := GetUserByID(tx, uid)
if err != nil {
return nil, err
}
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeys(tx *gorm.DB) ([]types.PreAuthKey, error) {
var keys []types.PreAuthKey
keys := []types.PreAuthKey{}
err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error
err := tx.Preload("User").Find(&keys).Error
if err != nil {
return nil, err
}
@@ -298,34 +289,35 @@ func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist. This also clears the auth_key_id on any nodes that reference
// this key.
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
return tx.Transaction(func(db *gorm.DB) error {
// First, clear the foreign key reference on any nodes using this key
err := db.Model(&types.Node{}).
Where("auth_key_id = ?", pak.ID).
Where("auth_key_id = ?", id).
Update("auth_key_id", nil).Error
if err != nil {
return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
}
// Then delete the pre-auth key
if result := db.Unscoped().Delete(pak); result.Error != nil {
return result.Error
err = tx.Unscoped().Delete(&types.PreAuthKey{}, id).Error
if err != nil {
return err
}
return nil
})
}
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
func (hsdb *HSDatabase) ExpirePreAuthKey(id uint64) error {
return hsdb.Write(func(tx *gorm.DB) error {
return ExpirePreAuthKey(tx, k)
return ExpirePreAuthKey(tx, id)
})
}
func (hsdb *HSDatabase) DeletePreAuthKey(k *types.PreAuthKey) error {
func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
return hsdb.Write(func(tx *gorm.DB) error {
return DestroyPreAuthKey(tx, *k)
return DestroyPreAuthKey(tx, id)
})
}
@@ -341,7 +333,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
}

View File

@@ -41,7 +41,7 @@ func TestCreatePreAuthKey(t *testing.T) {
assert.NotEmpty(t, key.Key)
// List keys for the user
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
keys, err := db.ListPreAuthKeys()
require.NoError(t, err)
assert.Len(t, keys, 1)
@@ -49,15 +49,6 @@ func TestCreatePreAuthKey(t *testing.T) {
assert.Equal(t, user.ID, keys[0].User.ID)
},
},
{
name: "error_list_invalid_user_id",
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
_, err := db.ListPreAuthKeys(1000000)
assert.Error(t, err)
},
},
}
for _, tt := range tests {
@@ -101,7 +92,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate)
require.NoError(t, err)
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
listedPaks, err := db.ListPreAuthKeys()
require.NoError(t, err)
require.Len(t, listedPaks, 1)

View File

@@ -58,12 +58,12 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
return ErrUserStillHasNodes
}
keys, err := ListPreAuthKeysByUser(tx, uid)
keys, err := ListPreAuthKeys(tx)
if err != nil {
return err
}
for _, key := range keys {
err = DestroyPreAuthKey(tx, key)
err = DestroyPreAuthKey(tx, key.ID)
if err != nil {
return err
}

View File

@@ -184,16 +184,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context,
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
if err != nil {
return nil, err
}
if uint64(preAuthKey.User.ID) != request.GetUser() {
return nil, fmt.Errorf("preauth key does not belong to user")
}
err = api.h.state.ExpirePreAuthKey(preAuthKey)
err := api.h.state.ExpirePreAuthKey(request.GetId())
if err != nil {
return nil, err
}
@@ -205,16 +196,7 @@ func (api headscaleV1APIServer) DeletePreAuthKey(
ctx context.Context,
request *v1.DeletePreAuthKeyRequest,
) (*v1.DeletePreAuthKeyResponse, error) {
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
if err != nil {
return nil, err
}
if uint64(preAuthKey.User.ID) != request.GetUser() {
return nil, fmt.Errorf("preauth key does not belong to user")
}
err = api.h.state.DeletePreAuthKey(preAuthKey)
err := api.h.state.DeletePreAuthKey(request.GetId())
if err != nil {
return nil, err
}
@@ -226,12 +208,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
ctx context.Context,
request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) {
user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID))
preAuthKeys, err := api.h.state.ListPreAuthKeys()
if err != nil {
return nil, err
}

View File

@@ -1036,18 +1036,18 @@ func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) {
}
// ListPreAuthKeys returns all pre-authentication keys for a user.
func (s *State) ListPreAuthKeys(userID types.UserID) ([]types.PreAuthKey, error) {
return s.db.ListPreAuthKeys(userID)
func (s *State) ListPreAuthKeys() ([]types.PreAuthKey, error) {
return s.db.ListPreAuthKeys()
}
// ExpirePreAuthKey marks a pre-authentication key as expired.
func (s *State) ExpirePreAuthKey(preAuthKey *types.PreAuthKey) error {
return s.db.ExpirePreAuthKey(preAuthKey)
func (s *State) ExpirePreAuthKey(id uint64) error {
return s.db.ExpirePreAuthKey(id)
}
// DeletePreAuthKey permanently deletes a pre-authentication key.
func (s *State) DeletePreAuthKey(preAuthKey *types.PreAuthKey) error {
return s.db.DeletePreAuthKey(preAuthKey)
func (s *State) DeletePreAuthKey(id uint64) error {
return s.db.DeletePreAuthKey(id)
}
// GetRegistrationCacheEntry retrieves a node registration from cache.