cmd/headscale/cli: convert remaining commands to RunE

Convert the 10 commands that were still using Run with
ErrorOutput/SuccessOutput or log.Fatal/os.Exit:

- backfillNodeIPsCmd: use grpcRunE-style manual connection with
  error returns; simplify the confirm/force logic
- getPolicy, setPolicy, checkPolicy: replace ErrorOutput with
  fmt.Errorf returns in both the bypass-gRPC and gRPC paths
- serveCmd, configTestCmd: replace log.Fatal with error returns
- mockOidcCmd: replace log.Error+os.Exit with error return
- versionCmd, generatePrivateKeyCmd: replace SuccessOutput with
  printOutput
- dumpConfigCmd: return the error instead of swallowing it
This commit is contained in:
Kristoffer Dalby
2026-02-18 13:46:42 +00:00
parent e4fe216e45
commit 095106f498
8 changed files with 86 additions and 99 deletions

View File

@@ -1,7 +1,8 @@
package cli package cli
import ( import (
"github.com/rs/zerolog/log" "fmt"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@@ -13,10 +14,12 @@ var configTestCmd = &cobra.Command{
Use: "configtest", Use: "configtest",
Short: "Test the configuration.", Short: "Test the configuration.",
Long: "Run a test of the configuration and exit.", Long: "Run a test of the configuration and exit.",
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
_, err := newHeadscaleServerWithConfig() _, err := newHeadscaleServerWithConfig()
if err != nil { if err != nil {
log.Fatal().Caller().Err(err).Msg("error initializing") return fmt.Errorf("configuration error: %w", err)
} }
return nil
}, },
} }

View File

@@ -18,11 +18,12 @@ var dumpConfigCmd = &cobra.Command{
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
err := viper.WriteConfigAs("/etc/headscale/config.dump.yaml") err := viper.WriteConfigAs("/etc/headscale/config.dump.yaml")
if err != nil { if err != nil {
//nolint return fmt.Errorf("dumping config: %w", err)
fmt.Println("Failed to dump config")
} }
return nil
}, },
} }

View File

@@ -21,22 +21,17 @@ var generateCmd = &cobra.Command{
var generatePrivateKeyCmd = &cobra.Command{ var generatePrivateKeyCmd = &cobra.Command{
Use: "private-key", Use: "private-key",
Short: "Generate a private key for the headscale server", Short: "Generate a private key for the headscale server",
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
output, _ := cmd.Flags().GetString("output")
machineKey := key.NewMachine() machineKey := key.NewMachine()
machineKeyStr, err := machineKey.MarshalText() machineKeyStr, err := machineKey.MarshalText()
if err != nil { if err != nil {
ErrorOutput( return fmt.Errorf("marshalling machine key: %w", err)
err,
fmt.Sprintf("Error getting machine key from flag: %s", err),
output,
)
} }
SuccessOutput(map[string]string{ return printOutput(cmd, map[string]string{
"private_key": string(machineKeyStr), "private_key": string(machineKeyStr),
}, },
string(machineKeyStr), output) string(machineKeyStr))
}, },
} }

View File

@@ -34,12 +34,13 @@ var mockOidcCmd = &cobra.Command{
Use: "mockoidc", Use: "mockoidc",
Short: "Runs a mock OIDC server for testing", Short: "Runs a mock OIDC server for testing",
Long: "This internal command runs a OpenID Connect for testing purposes", Long: "This internal command runs a OpenID Connect for testing purposes",
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
err := mockOIDC() err := mockOIDC()
if err != nil { if err != nil {
log.Error().Err(err).Msgf("error running mock OIDC server") return fmt.Errorf("running mock OIDC server: %w", err)
os.Exit(1)
} }
return nil
}, },
} }

View File

@@ -14,7 +14,6 @@ import (
"github.com/pterm/pterm" "github.com/pterm/pterm"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
@@ -357,9 +356,7 @@ all nodes that are missing.
If you remove IPv4 or IPv6 prefixes from the config, If you remove IPv4 or IPv6 prefixes from the config,
it can be run to remove the IPs that should no longer it can be run to remove the IPs that should no longer
be assigned to nodes.`, be assigned to nodes.`,
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
output, _ := cmd.Flags().GetString("output")
confirm := false confirm := false
force, _ := cmd.Flags().GetBool("force") force, _ := cmd.Flags().GetBool("force")
@@ -367,25 +364,23 @@ be assigned to nodes.`,
confirm = util.YesNo("Are you sure that you want to assign/remove IPs to/from nodes?") confirm = util.YesNo("Are you sure that you want to assign/remove IPs to/from nodes?")
} }
if confirm || force { if !confirm && !force {
ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig() return nil
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error connecting: %s", err), output)
}
defer cancel()
defer conn.Close()
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force})
if err != nil {
ErrorOutput(
err,
"Error backfilling IPs: "+status.Convert(err).Message(),
output,
)
}
SuccessOutput(changes, "Node IPs backfilled successfully", output)
} }
ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig()
if err != nil {
return fmt.Errorf("connecting to headscale: %w", err)
}
defer cancel()
defer conn.Close()
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: true})
if err != nil {
return fmt.Errorf("backfilling IPs: %w", err)
}
return printOutput(cmd, changes, "Node IPs backfilled successfully")
}, },
} }

View File

@@ -1,6 +1,7 @@
package cli package cli
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -19,6 +20,8 @@ const (
bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential
) )
var errAborted = errors.New("command aborted by user")
func init() { func init() {
rootCmd.AddCommand(policyCmd) rootCmd.AddCommand(policyCmd)
@@ -54,11 +57,8 @@ var getPolicy = &cobra.Command{
Use: "get", Use: "get",
Short: "Print the current ACL Policy", Short: "Print the current ACL Policy",
Aliases: []string{"show", "view", "fetch"}, Aliases: []string{"show", "view", "fetch"},
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
output, _ := cmd.Flags().GetString("output") var policyData string
var policy string
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass { if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
confirm := false confirm := false
@@ -68,51 +68,46 @@ var getPolicy = &cobra.Command{
} }
if !confirm && !force { if !confirm && !force {
ErrorOutput(nil, "Aborting command", output) return errAborted
return
} }
cfg, err := types.LoadServerConfig() cfg, err := types.LoadServerConfig()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output) return fmt.Errorf("loading config: %w", err)
} }
d, err := db.NewHeadscaleDatabase( d, err := db.NewHeadscaleDatabase(cfg, nil)
cfg,
nil,
)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output) return fmt.Errorf("opening database: %w", err)
} }
pol, err := d.GetPolicy() pol, err := d.GetPolicy()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed loading Policy from database: %s", err), output) return fmt.Errorf("loading policy from database: %w", err)
} }
policy = pol.Data policyData = pol.Data
} else { } else {
ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig() ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error connecting: %s", err), output) return fmt.Errorf("connecting to headscale: %w", err)
} }
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
request := &v1.GetPolicyRequest{} response, err := client.GetPolicy(ctx, &v1.GetPolicyRequest{})
response, err := client.GetPolicy(ctx, request)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output) return fmt.Errorf("loading ACL policy: %w", err)
} }
policy = response.GetPolicy() policyData = response.GetPolicy()
} }
// TODO(pallabpain): Maybe print this better? // This does not pass output format as we don't support yaml, json or
// This does not pass output as we dont support yaml, json or json-line // json-line output for this command. It is HuJSON already.
// output for this command. It is HuJSON already. fmt.Println(policyData)
SuccessOutput("", policy, "")
return nil
}, },
} }
@@ -123,19 +118,18 @@ var setPolicy = &cobra.Command{
Updates the existing ACL Policy with the provided policy. The policy must be a valid HuJSON object. Updates the existing ACL Policy with the provided policy. The policy must be a valid HuJSON object.
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`, This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
Aliases: []string{"put", "update"}, Aliases: []string{"put", "update"},
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
output, _ := cmd.Flags().GetString("output")
policyPath, _ := cmd.Flags().GetString("file") policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath) f, err := os.Open(policyPath)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output) return fmt.Errorf("opening policy file: %w", err)
} }
defer f.Close() defer f.Close()
policyBytes, err := io.ReadAll(f) policyBytes, err := io.ReadAll(f)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) return fmt.Errorf("reading policy file: %w", err)
} }
if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass { if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
@@ -147,80 +141,79 @@ var setPolicy = &cobra.Command{
} }
if !confirm && !force { if !confirm && !force {
ErrorOutput(nil, "Aborting command", output) return errAborted
return
} }
cfg, err := types.LoadServerConfig() cfg, err := types.LoadServerConfig()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output) return fmt.Errorf("loading config: %w", err)
} }
d, err := db.NewHeadscaleDatabase( d, err := db.NewHeadscaleDatabase(cfg, nil)
cfg,
nil,
)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output) return fmt.Errorf("opening database: %w", err)
} }
users, err := d.ListUsers() users, err := d.ListUsers()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to load users for policy validation: %s", err), output) return fmt.Errorf("loading users for policy validation: %w", err)
} }
_, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{}) _, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{})
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) return fmt.Errorf("parsing policy file: %w", err)
return
} }
_, err = d.SetPolicy(string(policyBytes)) _, err = d.SetPolicy(string(policyBytes))
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) return fmt.Errorf("setting ACL policy: %w", err)
} }
} else { } else {
request := &v1.SetPolicyRequest{Policy: string(policyBytes)} request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig() ctx, client, conn, cancel, err := newHeadscaleCLIWithConfig()
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error connecting: %s", err), output) return fmt.Errorf("connecting to headscale: %w", err)
} }
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()
if _, err := client.SetPolicy(ctx, request); err != nil { //nolint:noinlineerr _, err = client.SetPolicy(ctx, request)
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) if err != nil {
return fmt.Errorf("setting ACL policy: %w", err)
} }
} }
SuccessOutput(nil, "Policy updated.", "") fmt.Println("Policy updated.")
return nil
}, },
} }
var checkPolicy = &cobra.Command{ var checkPolicy = &cobra.Command{
Use: "check", Use: "check",
Short: "Check the Policy file for errors", Short: "Check the Policy file for errors",
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
output, _ := cmd.Flags().GetString("output")
policyPath, _ := cmd.Flags().GetString("file") policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath) f, err := os.Open(policyPath)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output) return fmt.Errorf("opening policy file: %w", err)
} }
defer f.Close() defer f.Close()
policyBytes, err := io.ReadAll(f) policyBytes, err := io.ReadAll(f)
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) return fmt.Errorf("reading policy file: %w", err)
} }
_, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{}) _, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{})
if err != nil { if err != nil {
ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) return fmt.Errorf("parsing policy file: %w", err)
} }
SuccessOutput(nil, "Policy is valid", "") fmt.Println("Policy is valid")
return nil
}, },
} }

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/tailscale/squibble" "github.com/tailscale/squibble"
) )
@@ -20,7 +19,7 @@ var serveCmd = &cobra.Command{
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
app, err := newHeadscaleServerWithConfig() app, err := newHeadscaleServerWithConfig()
if err != nil { if err != nil {
if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok { if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok {
@@ -28,12 +27,14 @@ var serveCmd = &cobra.Command{
fmt.Println(squibbleErr.Diff) fmt.Println(squibbleErr.Diff)
} }
log.Fatal().Caller().Err(err).Msg("error initializing") return fmt.Errorf("initializing: %w", err)
} }
err = app.Serve() err = app.Serve()
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal().Caller().Err(err).Msg("headscale ran into an error and had to shut down") return fmt.Errorf("headscale ran into an error and had to shut down: %w", err)
} }
return nil
}, },
} }

View File

@@ -14,11 +14,9 @@ var versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "Print the version.", Short: "Print the version.",
Long: "The version of headscale.", Long: "The version of headscale.",
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
output, _ := cmd.Flags().GetString("output")
info := types.GetVersionInfo() info := types.GetVersionInfo()
SuccessOutput(info, info.String(), output) return printOutput(cmd, info, info.String())
}, },
} }