From 7b7b270126d992e7f4c3e8e677523240726708d7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 18 Feb 2026 14:49:04 +0000 Subject: [PATCH] cmd/headscale/cli: add mustMarkRequired helper for init-time flag validation Replace three inconsistent MarkFlagRequired error-handling styles (stdlib log.Fatal, zerolog log.Fatal, silently discarded) with a single mustMarkRequired helper that panics on programmer error. Also fixes a bug where renameNodeCmd.MarkFlagRequired("new-name") targeted the wrong command (should be renameUserCmd), making the --new-name flag effectively never required on "headscale users rename". --- cmd/headscale/cli/debug.go | 18 +----------------- cmd/headscale/cli/nodes.go | 38 ++++++------------------------------- cmd/headscale/cli/policy.go | 15 ++------------- cmd/headscale/cli/users.go | 3 +-- cmd/headscale/cli/utils.go | 12 ++++++++++++ 5 files changed, 22 insertions(+), 64 deletions(-) diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 49a479f6..e9eff4ce 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -6,7 +6,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" - "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -19,12 +18,6 @@ func init() { rootCmd.AddCommand(debugCmd) createNodeCmd.Flags().StringP("name", "", "", "Name") - - err := createNodeCmd.MarkFlagRequired("name") - if err != nil { - log.Fatal().Err(err).Msg("") - } - createNodeCmd.Flags().StringP("user", "u", "", "User") createNodeCmd.Flags().StringP("namespace", "n", "", "User") @@ -32,17 +25,8 @@ func init() { createNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage createNodeNamespaceFlag.Hidden = true - err = createNodeCmd.MarkFlagRequired("user") - if err != nil { - log.Fatal().Err(err).Msg("") - } - createNodeCmd.Flags().StringP("key", "k", "", "Key") - - err = createNodeCmd.MarkFlagRequired("key") - if err != nil { - log.Fatal().Err(err).Msg("") - } + mustMarkRequired(createNodeCmd, "name", "user", "key") createNodeCmd.Flags(). StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise") diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index a13f8b56..878144c2 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -3,7 +3,6 @@ package cli import ( "context" "fmt" - "log" "net/netip" "strconv" "strings" @@ -39,55 +38,30 @@ func init() { registerNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage registerNodeNamespaceFlag.Hidden = true - err := registerNodeCmd.MarkFlagRequired("user") - if err != nil { - log.Fatal(err.Error()) - } - registerNodeCmd.Flags().StringP("key", "k", "", "Key") - - err = registerNodeCmd.MarkFlagRequired("key") - if err != nil { - log.Fatal(err.Error()) - } - + mustMarkRequired(registerNodeCmd, "user", "key") nodeCmd.AddCommand(registerNodeCmd) expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.") - - err = expireNodeCmd.MarkFlagRequired("identifier") - if err != nil { - log.Fatal(err.Error()) - } - + mustMarkRequired(expireNodeCmd, "identifier") nodeCmd.AddCommand(expireNodeCmd) renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - - err = renameNodeCmd.MarkFlagRequired("identifier") - if err != nil { - log.Fatal(err.Error()) - } - + mustMarkRequired(renameNodeCmd, "identifier") nodeCmd.AddCommand(renameNodeCmd) deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - - err = deleteNodeCmd.MarkFlagRequired("identifier") - if err != nil { - log.Fatal(err.Error()) - } - + mustMarkRequired(deleteNodeCmd, "identifier") nodeCmd.AddCommand(deleteNodeCmd) tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - _ = tagCmd.MarkFlagRequired("identifier") + mustMarkRequired(tagCmd, "identifier") tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node") nodeCmd.AddCommand(tagCmd) approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - _ = approveRoutesCmd.MarkFlagRequired("identifier") + mustMarkRequired(approveRoutesCmd, "identifier") approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`) nodeCmd.AddCommand(approveRoutesCmd) diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 00fc9945..1708eec2 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -11,7 +11,6 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" "github.com/spf13/cobra" "tailscale.com/types/views" ) @@ -29,22 +28,12 @@ func init() { policyCmd.AddCommand(getPolicy) setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - - err := setPolicy.MarkFlagRequired("file") - if err != nil { - log.Fatal().Err(err).Msg("") - } - setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") + mustMarkRequired(setPolicy, "file") policyCmd.AddCommand(setPolicy) checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - - err = checkPolicy.MarkFlagRequired("file") - if err != nil { - log.Fatal().Err(err).Msg("") - } - + mustMarkRequired(checkPolicy, "file") policyCmd.AddCommand(checkPolicy) } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 61ea5b16..ea8b1b24 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -58,8 +58,7 @@ func init() { userCmd.AddCommand(renameUserCmd) usernameAndIDFlag(renameUserCmd) renameUserCmd.Flags().StringP("new-name", "r", "", "New username") - - _ = renameNodeCmd.MarkFlagRequired("new-name") + mustMarkRequired(renameUserCmd, "new-name") } var errMissingParameter = errors.New("missing parameters") diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 781a9085..52f47215 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -32,6 +32,18 @@ const ( var errAPIKeyNotSet = errors.New("HEADSCALE_CLI_API_KEY environment variable needs to be set") +// mustMarkRequired marks the named flags as required on cmd, panicking +// if any name does not match a registered flag. This is only called +// from init() where a failure indicates a programming error. +func mustMarkRequired(cmd *cobra.Command, names ...string) { + for _, n := range names { + err := cmd.MarkFlagRequired(n) + if err != nil { + panic(fmt.Sprintf("marking flag %q required on %q: %v", n, cmd.Name(), err)) + } + } +} + func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) { cfg, err := types.LoadServerConfig() if err != nil {