mirror of
https://github.com/juanfont/headscale.git
synced 2026-04-23 17:18:50 +02:00
cli,hscontrol: use ID-based preauthkey operations
This commit is contained in:
committed by
Kristoffer Dalby
parent
8631581852
commit
1325fd8b27
@@ -57,10 +57,6 @@ func CreatePreAuthKey(
|
||||
return nil, ErrPreAuthKeyNotTaggedOrOwned
|
||||
}
|
||||
|
||||
// If uid != nil && len(aclTags) > 0:
|
||||
// Both are allowed: UserID tracks "created by", tags define node ownership
|
||||
// This is valid per the new model
|
||||
|
||||
var (
|
||||
user *types.User
|
||||
userID *uint
|
||||
@@ -158,22 +154,17 @@ func CreatePreAuthKey(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
return ListPreAuthKeysByUser(rx, uid)
|
||||
return ListPreAuthKeys(rx)
|
||||
})
|
||||
}
|
||||
|
||||
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
|
||||
user, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func ListPreAuthKeys(tx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||
var keys []types.PreAuthKey
|
||||
|
||||
keys := []types.PreAuthKey{}
|
||||
|
||||
err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error
|
||||
err := tx.Preload("User").Find(&keys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -298,34 +289,35 @@ func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist. This also clears the auth_key_id on any nodes that reference
|
||||
// this key.
|
||||
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
|
||||
func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
return tx.Transaction(func(db *gorm.DB) error {
|
||||
// First, clear the foreign key reference on any nodes using this key
|
||||
err := db.Model(&types.Node{}).
|
||||
Where("auth_key_id = ?", pak.ID).
|
||||
Where("auth_key_id = ?", id).
|
||||
Update("auth_key_id", nil).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
|
||||
}
|
||||
|
||||
// Then delete the pre-auth key
|
||||
if result := db.Unscoped().Delete(pak); result.Error != nil {
|
||||
return result.Error
|
||||
err = tx.Unscoped().Delete(&types.PreAuthKey{}, id).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(id uint64) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return ExpirePreAuthKey(tx, k)
|
||||
return ExpirePreAuthKey(tx, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeletePreAuthKey(k *types.PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return DestroyPreAuthKey(tx, *k)
|
||||
return DestroyPreAuthKey(tx, id)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -341,7 +333,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestCreatePreAuthKey(t *testing.T) {
|
||||
assert.NotEmpty(t, key.Key)
|
||||
|
||||
// List keys for the user
|
||||
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
keys, err := db.ListPreAuthKeys()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
|
||||
@@ -49,15 +49,6 @@ func TestCreatePreAuthKey(t *testing.T) {
|
||||
assert.Equal(t, user.ID, keys[0].User.ID)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error_list_invalid_user_id",
|
||||
test: func(t *testing.T, db *HSDatabase) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.ListPreAuthKeys(1000000)
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -101,7 +92,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
|
||||
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate)
|
||||
require.NoError(t, err)
|
||||
|
||||
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
listedPaks, err := db.ListPreAuthKeys()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, listedPaks, 1)
|
||||
|
||||
|
||||
@@ -58,12 +58,12 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||
return ErrUserStillHasNodes
|
||||
}
|
||||
|
||||
keys, err := ListPreAuthKeysByUser(tx, uid)
|
||||
keys, err := ListPreAuthKeys(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
err = DestroyPreAuthKey(tx, key)
|
||||
err = DestroyPreAuthKey(tx, key.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -184,16 +184,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
||||
ctx context.Context,
|
||||
request *v1.ExpirePreAuthKeyRequest,
|
||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uint64(preAuthKey.User.ID) != request.GetUser() {
|
||||
return nil, fmt.Errorf("preauth key does not belong to user")
|
||||
}
|
||||
|
||||
err = api.h.state.ExpirePreAuthKey(preAuthKey)
|
||||
err := api.h.state.ExpirePreAuthKey(request.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -205,16 +196,7 @@ func (api headscaleV1APIServer) DeletePreAuthKey(
|
||||
ctx context.Context,
|
||||
request *v1.DeletePreAuthKeyRequest,
|
||||
) (*v1.DeletePreAuthKeyResponse, error) {
|
||||
preAuthKey, err := api.h.state.GetPreAuthKey(request.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if uint64(preAuthKey.User.ID) != request.GetUser() {
|
||||
return nil, fmt.Errorf("preauth key does not belong to user")
|
||||
}
|
||||
|
||||
err = api.h.state.DeletePreAuthKey(preAuthKey)
|
||||
err := api.h.state.DeletePreAuthKey(request.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -226,12 +208,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||
ctx context.Context,
|
||||
request *v1.ListPreAuthKeysRequest,
|
||||
) (*v1.ListPreAuthKeysResponse, error) {
|
||||
user, err := api.h.state.GetUserByID(types.UserID(request.GetUser()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
preAuthKeys, err := api.h.state.ListPreAuthKeys(types.UserID(user.ID))
|
||||
preAuthKeys, err := api.h.state.ListPreAuthKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1036,18 +1036,18 @@ func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) {
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns all pre-authentication keys for a user.
|
||||
func (s *State) ListPreAuthKeys(userID types.UserID) ([]types.PreAuthKey, error) {
|
||||
return s.db.ListPreAuthKeys(userID)
|
||||
func (s *State) ListPreAuthKeys() ([]types.PreAuthKey, error) {
|
||||
return s.db.ListPreAuthKeys()
|
||||
}
|
||||
|
||||
// ExpirePreAuthKey marks a pre-authentication key as expired.
|
||||
func (s *State) ExpirePreAuthKey(preAuthKey *types.PreAuthKey) error {
|
||||
return s.db.ExpirePreAuthKey(preAuthKey)
|
||||
func (s *State) ExpirePreAuthKey(id uint64) error {
|
||||
return s.db.ExpirePreAuthKey(id)
|
||||
}
|
||||
|
||||
// DeletePreAuthKey permanently deletes a pre-authentication key.
|
||||
func (s *State) DeletePreAuthKey(preAuthKey *types.PreAuthKey) error {
|
||||
return s.db.DeletePreAuthKey(preAuthKey)
|
||||
func (s *State) DeletePreAuthKey(id uint64) error {
|
||||
return s.db.DeletePreAuthKey(id)
|
||||
}
|
||||
|
||||
// GetRegistrationCacheEntry retrieves a node registration from cache.
|
||||
|
||||
Reference in New Issue
Block a user