node: implement disable key expiry via CLI and API

Add --disable flag to "headscale nodes expire" CLI command and
disable_expiry field handling in the gRPC API to allow disabling
key expiry for nodes. When disabled, the node's expiry is set to
NULL and IsExpired() returns false.

The CLI follows the new grpcRunE/RunE/printOutput patterns
introduced in the recent CLI refactor.

Also fix NodeSetExpiry to persist directly to the database instead
of going through persistNodeToDB which omits the expiry field.

Fixes #2681

Co-authored-by: Marco Santos <me@marcopsantos.com>
This commit is contained in:
Kristoffer Dalby
2026-02-20 10:58:49 +00:00
parent a8f7fedced
commit f20bd0cf08
9 changed files with 222 additions and 17 deletions

View File

@@ -212,7 +212,9 @@ func (h *Headscale) handleLogout(
// Update the internal state with the nodes new expiry, meaning it is
// logged out.
updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry)
expiry := req.Expiry
updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), &expiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}

View File

@@ -587,7 +587,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Expire the node
expiredTime := time.Now().Add(-1 * time.Hour)
_, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime)
_, _, err = app.state.SetNodeExpiry(node.ID(), &expiredTime)
return "", err
},

View File

@@ -315,16 +315,15 @@ func RenameNode(tx *gorm.DB,
return nil
}
func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error {
func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry *time.Time) error {
return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetExpiry(tx, nodeID, expiry)
})
}
// NodeSetExpiry takes a Node struct and a new expiry time.
func NodeSetExpiry(tx *gorm.DB,
nodeID types.NodeID, expiry time.Time,
) error {
// NodeSetExpiry sets a new expiry time for a node.
// If expiry is nil, the node's expiry is disabled (node will never expire).
func NodeSetExpiry(tx *gorm.DB, nodeID types.NodeID, expiry *time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
}

View File

@@ -128,7 +128,7 @@ func TestExpireNode(t *testing.T) {
assert.False(t, nodeFromDB.IsExpired())
now := time.Now()
err = db.NodeSetExpiry(nodeFromDB.ID, now)
err = db.NodeSetExpiry(nodeFromDB.ID, &now)
require.NoError(t, err)
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
@@ -137,6 +137,48 @@ func TestExpireNode(t *testing.T) {
assert.True(t, nodeFromDB.IsExpired())
}
func TestDisableNodeExpiry(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
pakID := pak.ID
node := &types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: &pakID,
Expiry: &time.Time{},
}
db.DB.Save(node)
// Set an expiry first.
past := time.Now().Add(-time.Hour)
err = db.NodeSetExpiry(node.ID, &past)
require.NoError(t, err)
nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode")
require.NoError(t, err)
assert.True(t, nodeFromDB.IsExpired(), "node should be expired")
// Disable expiry by setting nil.
err = db.NodeSetExpiry(node.ID, nil)
require.NoError(t, err)
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
require.NoError(t, err)
assert.False(t, nodeFromDB.IsExpired(), "node should not be expired after disabling expiry")
assert.Nil(t, nodeFromDB.Expiry, "expiry should be nil after disabling")
}
func TestSetTags(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)

View File

@@ -451,12 +451,40 @@ func (api headscaleV1APIServer) ExpireNode(
ctx context.Context,
request *v1.ExpireNodeRequest,
) (*v1.ExpireNodeResponse, error) {
if request.GetDisableExpiry() && request.GetExpiry() != nil {
return nil, status.Error(
codes.InvalidArgument,
"cannot set both disable_expiry and expiry",
)
}
// Handle disable expiry request - node will never expire.
if request.GetDisableExpiry() {
node, nodeChange, err := api.h.state.SetNodeExpiry(
types.NodeID(request.GetNodeId()), nil,
)
if err != nil {
return nil, err
}
api.h.Change(nodeChange)
log.Trace().
Caller().
EmbedObject(node).
Msg("node expiry disabled")
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
}
expiry := time.Now()
if request.GetExpiry() != nil {
expiry = request.GetExpiry().AsTime()
}
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry)
node, nodeChange, err := api.h.state.SetNodeExpiry(
types.NodeID(request.GetNodeId()), &expiry,
)
if err != nil {
return nil, err
}
@@ -467,7 +495,7 @@ func (api headscaleV1APIServer) ExpireNode(
log.Trace().
Caller().
EmbedObject(node).
Time(zf.ExpiresAt, *node.AsStruct().Expiry).
Time(zf.ExpiresAt, expiry).
Msg("node expired")
return &v1.ExpireNodeResponse{Node: node.Proto()}, nil

View File

@@ -638,22 +638,38 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] {
}
// SetNodeExpiry updates the expiration time for a node.
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) {
// If expiry is nil, the node's expiry is disabled (node will never expire).
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry *time.Time) (types.NodeView, change.Change, error) {
// Update NodeStore before database to ensure consistency. The NodeStore update is
// blocking and will be the source of truth for the batcher. The database update must
// make the exact same change. If the database update fails, the NodeStore change will
// remain, but since we return an error, no change notification will be sent to the
// batcher, preventing inconsistent state propagation.
expiryPtr := expiry
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
node.Expiry = &expiryPtr
node.Expiry = expiry
})
if !ok {
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
}
return s.persistNodeToDB(n)
// Persist expiry change to database directly since persistNodeToDB omits expiry.
err := s.db.NodeSetExpiry(nodeID, expiry)
if err != nil {
return types.NodeView{}, change.Change{}, fmt.Errorf("setting node expiry in database: %w", err)
}
// Update policy manager and generate change notification.
c, err := s.updatePolicyManagerNodes()
if err != nil {
return n, change.Change{}, fmt.Errorf("updating policy manager after setting expiry: %w", err)
}
if c.IsEmpty() {
c = change.NodeAdded(n.ID())
}
return n, c, nil
}
// SetNodeTags assigns tags to a node, making it a "tagged node".