cmd/headscale/cli: add mustMarkRequired helper for init-time flag validation

Replace three inconsistent MarkFlagRequired error-handling styles
(stdlib log.Fatal, zerolog log.Fatal, silently discarded) with a
single mustMarkRequired helper that panics on programmer error.

Also fixes a bug where renameNodeCmd.MarkFlagRequired("new-name")
targeted the wrong command (should be renameUserCmd), making the
--new-name flag effectively never required on "headscale users rename".
This commit is contained in:
Kristoffer Dalby
2026-02-18 14:49:04 +00:00
parent d6c39e65a5
commit 7b7b270126
5 changed files with 22 additions and 64 deletions

View File

@@ -6,7 +6,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)
@@ -19,12 +18,6 @@ func init() {
rootCmd.AddCommand(debugCmd)
createNodeCmd.Flags().StringP("name", "", "", "Name")
err := createNodeCmd.MarkFlagRequired("name")
if err != nil {
log.Fatal().Err(err).Msg("")
}
createNodeCmd.Flags().StringP("user", "u", "", "User")
createNodeCmd.Flags().StringP("namespace", "n", "", "User")
@@ -32,17 +25,8 @@ func init() {
createNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
createNodeNamespaceFlag.Hidden = true
err = createNodeCmd.MarkFlagRequired("user")
if err != nil {
log.Fatal().Err(err).Msg("")
}
createNodeCmd.Flags().StringP("key", "k", "", "Key")
err = createNodeCmd.MarkFlagRequired("key")
if err != nil {
log.Fatal().Err(err).Msg("")
}
mustMarkRequired(createNodeCmd, "name", "user", "key")
createNodeCmd.Flags().
StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")

View File

@@ -3,7 +3,6 @@ package cli
import (
"context"
"fmt"
"log"
"net/netip"
"strconv"
"strings"
@@ -39,55 +38,30 @@ func init() {
registerNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
registerNodeNamespaceFlag.Hidden = true
err := registerNodeCmd.MarkFlagRequired("user")
if err != nil {
log.Fatal(err.Error())
}
registerNodeCmd.Flags().StringP("key", "k", "", "Key")
err = registerNodeCmd.MarkFlagRequired("key")
if err != nil {
log.Fatal(err.Error())
}
mustMarkRequired(registerNodeCmd, "user", "key")
nodeCmd.AddCommand(registerNodeCmd)
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
err = expireNodeCmd.MarkFlagRequired("identifier")
if err != nil {
log.Fatal(err.Error())
}
mustMarkRequired(expireNodeCmd, "identifier")
nodeCmd.AddCommand(expireNodeCmd)
renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
err = renameNodeCmd.MarkFlagRequired("identifier")
if err != nil {
log.Fatal(err.Error())
}
mustMarkRequired(renameNodeCmd, "identifier")
nodeCmd.AddCommand(renameNodeCmd)
deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
err = deleteNodeCmd.MarkFlagRequired("identifier")
if err != nil {
log.Fatal(err.Error())
}
mustMarkRequired(deleteNodeCmd, "identifier")
nodeCmd.AddCommand(deleteNodeCmd)
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
_ = tagCmd.MarkFlagRequired("identifier")
mustMarkRequired(tagCmd, "identifier")
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
nodeCmd.AddCommand(tagCmd)
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
_ = approveRoutesCmd.MarkFlagRequired("identifier")
mustMarkRequired(approveRoutesCmd, "identifier")
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
nodeCmd.AddCommand(approveRoutesCmd)

View File

@@ -11,7 +11,6 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"tailscale.com/types/views"
)
@@ -29,22 +28,12 @@ func init() {
policyCmd.AddCommand(getPolicy)
setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
err := setPolicy.MarkFlagRequired("file")
if err != nil {
log.Fatal().Err(err).Msg("")
}
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
mustMarkRequired(setPolicy, "file")
policyCmd.AddCommand(setPolicy)
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
err = checkPolicy.MarkFlagRequired("file")
if err != nil {
log.Fatal().Err(err).Msg("")
}
mustMarkRequired(checkPolicy, "file")
policyCmd.AddCommand(checkPolicy)
}

View File

@@ -58,8 +58,7 @@ func init() {
userCmd.AddCommand(renameUserCmd)
usernameAndIDFlag(renameUserCmd)
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
_ = renameNodeCmd.MarkFlagRequired("new-name")
mustMarkRequired(renameUserCmd, "new-name")
}
var errMissingParameter = errors.New("missing parameters")

View File

@@ -32,6 +32,18 @@ const (
var errAPIKeyNotSet = errors.New("HEADSCALE_CLI_API_KEY environment variable needs to be set")
// mustMarkRequired marks the named flags as required on cmd, panicking
// if any name does not match a registered flag. This is only called
// from init() where a failure indicates a programming error.
func mustMarkRequired(cmd *cobra.Command, names ...string) {
for _, n := range names {
err := cmd.MarkFlagRequired(n)
if err != nil {
panic(fmt.Sprintf("marking flag %q required on %q: %v", n, cmd.Name(), err))
}
}
}
func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
cfg, err := types.LoadServerConfig()
if err != nil {