From ce580f824536f6c250dafd1c23e5558ca603130d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 6 Feb 2026 21:45:32 +0100 Subject: [PATCH] all: fix golangci-lint issues (#3064) --- .golangci.yaml | 1 + cmd/headscale/cli/api_key.go | 2 +- cmd/headscale/cli/debug.go | 5 + cmd/headscale/cli/mockoidc.go | 17 +- cmd/headscale/cli/nodes.go | 46 +- cmd/headscale/cli/policy.go | 14 +- cmd/headscale/cli/root.go | 6 +- cmd/headscale/cli/users.go | 21 +- cmd/headscale/cli/utils.go | 14 +- cmd/headscale/headscale.go | 1 + cmd/headscale/headscale_test.go | 8 +- cmd/hi/cleanup.go | 23 +- cmd/hi/docker.go | 96 ++- cmd/hi/doctor.go | 36 +- cmd/hi/main.go | 11 +- cmd/hi/run.go | 14 +- cmd/hi/stats.go | 62 +- cmd/mapresponses/main.go | 9 +- flake.nix | 2 +- hscontrol/app.go | 81 ++- hscontrol/auth.go | 13 +- hscontrol/auth_test.go | 369 ++++++---- hscontrol/db/api_key.go | 9 +- hscontrol/db/db.go | 81 ++- hscontrol/db/db_test.go | 18 +- .../db/ephemeral_garbage_collector_test.go | 87 ++- hscontrol/db/ip.go | 42 +- hscontrol/db/ip_test.go | 13 +- hscontrol/db/node.go | 48 +- hscontrol/db/node_test.go | 48 +- hscontrol/db/policy.go | 3 +- hscontrol/db/preauth_keys.go | 9 +- hscontrol/db/sqliteconfig/config.go | 8 +- hscontrol/db/sqliteconfig/config_test.go | 2 + hscontrol/db/sqliteconfig/integration_test.go | 27 +- hscontrol/db/suite_test.go | 3 + hscontrol/db/text_serialiser.go | 22 +- hscontrol/db/users.go | 29 +- hscontrol/debug.go | 68 +- hscontrol/derp/derp.go | 12 +- hscontrol/derp/derp_test.go | 2 + hscontrol/derp/server/derp_server.go | 45 +- hscontrol/dns/extrarecords.go | 14 +- hscontrol/grpcv1_test.go | 4 +- hscontrol/handlers.go | 27 +- hscontrol/mapper/batcher.go | 14 +- hscontrol/mapper/batcher_lockfree.go | 49 +- hscontrol/mapper/batcher_test.go | 148 ++-- hscontrol/mapper/builder.go | 21 +- hscontrol/mapper/builder_test.go | 4 +- hscontrol/mapper/mapper.go | 17 +- hscontrol/mapper/mapper_test.go | 91 --- hscontrol/mapper/tail_test.go | 7 +- hscontrol/metrics.go | 3 + hscontrol/noise.go | 29 +- hscontrol/oidc.go | 31 +- hscontrol/platform_config.go | 18 +- hscontrol/policy/matcher/matcher.go | 9 +- hscontrol/policy/pm.go | 18 +- hscontrol/policy/policy.go | 1 + hscontrol/policy/policy_autoapprove_test.go | 12 +- hscontrol/policy/policy_test.go | 28 +- hscontrol/policy/policyutil/reduce.go | 1 + hscontrol/policy/policyutil/reduce_test.go | 8 +- hscontrol/policy/route_approval_test.go | 2 + hscontrol/policy/v2/filter.go | 41 +- hscontrol/policy/v2/filter_test.go | 36 +- hscontrol/policy/v2/policy.go | 29 +- hscontrol/policy/v2/policy_test.go | 70 +- hscontrol/policy/v2/tailscale_compat_test.go | 4 +- hscontrol/policy/v2/types.go | 677 +++++++++++------- hscontrol/policy/v2/types_test.go | 103 +-- hscontrol/policy/v2/utils.go | 31 +- hscontrol/policy/v2/utils_test.go | 11 +- hscontrol/poll.go | 20 +- hscontrol/routes/primary.go | 8 + hscontrol/routes/primary_test.go | 19 +- hscontrol/state/debug.go | 12 +- hscontrol/state/ephemeral_test.go | 30 +- hscontrol/state/maprequest.go | 1 + hscontrol/state/maprequest_test.go | 26 - hscontrol/state/node_store.go | 29 +- hscontrol/state/node_store_test.go | 167 +++-- hscontrol/state/state.go | 13 +- hscontrol/tailsql.go | 29 +- hscontrol/types/common.go | 13 +- hscontrol/types/config.go | 91 ++- hscontrol/types/config_test.go | 39 +- hscontrol/types/node.go | 44 +- hscontrol/types/node_test.go | 8 +- hscontrol/types/preauth_key.go | 2 +- hscontrol/types/preauth_key_test.go | 1 + hscontrol/types/users.go | 31 +- hscontrol/types/users_test.go | 12 +- hscontrol/types/version.go | 4 +- hscontrol/util/addr.go | 1 + hscontrol/util/addr_test.go | 2 + hscontrol/util/dns.go | 77 +- hscontrol/util/dns_test.go | 3 + hscontrol/util/file.go | 12 +- hscontrol/util/log.go | 1 + hscontrol/util/prompt.go | 5 +- hscontrol/util/prompt_test.go | 22 +- hscontrol/util/string.go | 12 +- hscontrol/util/test.go | 2 +- hscontrol/util/util.go | 53 +- hscontrol/util/util_test.go | 69 +- integration/acl_test.go | 3 +- integration/api_auth_test.go | 98 ++- integration/auth_key_test.go | 40 +- integration/auth_oidc_test.go | 41 +- integration/auth_web_flow_test.go | 27 +- integration/cli_test.go | 38 +- integration/control.go | 8 +- integration/derp_verify_endpoint_test.go | 11 +- integration/dns_test.go | 23 +- integration/dockertestutil/config.go | 1 + integration/dockertestutil/execute.go | 8 +- integration/dockertestutil/logs.go | 3 + integration/dockertestutil/network.go | 5 +- integration/dsic/dsic.go | 8 + integration/helpers.go | 125 +++- integration/hsic/hsic.go | 77 +- integration/integrationutil/util.go | 9 +- integration/route_test.go | 113 ++- integration/scenario.go | 122 +++- integration/scenario_test.go | 2 + integration/ssh_test.go | 2 +- integration/tags_test.go | 4 +- integration/tsic/tsic.go | 76 +- swagger.go | 4 +- 131 files changed, 3131 insertions(+), 1560 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index 7e1ab297..5ebd698a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -18,6 +18,7 @@ linters: - lll - maintidx - makezero + - mnd - musttag - nestif - nolintlint diff --git a/cmd/headscale/cli/api_key.go b/cmd/headscale/cli/api_key.go index d821b290..36cd30e2 100644 --- a/cmd/headscale/cli/api_key.go +++ b/cmd/headscale/cli/api_key.go @@ -14,7 +14,7 @@ import ( ) const ( - // 90 days. + // DefaultAPIKeyExpiry is 90 days. DefaultAPIKeyExpiry = "90d" ) diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 75187ddd..c0a7d3d0 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -19,10 +19,12 @@ func init() { rootCmd.AddCommand(debugCmd) createNodeCmd.Flags().StringP("name", "", "", "Name") + err := createNodeCmd.MarkFlagRequired("name") if err != nil { log.Fatal().Err(err).Msg("") } + createNodeCmd.Flags().StringP("user", "u", "", "User") createNodeCmd.Flags().StringP("namespace", "n", "", "User") @@ -34,11 +36,14 @@ func init() { if err != nil { log.Fatal().Err(err).Msg("") } + createNodeCmd.Flags().StringP("key", "k", "", "Key") + err = createNodeCmd.MarkFlagRequired("key") if err != nil { log.Fatal().Err(err).Msg("") } + createNodeCmd.Flags(). StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise") diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 9668f880..8204ecc2 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -1,8 +1,8 @@ package cli import ( + "context" "encoding/json" - "errors" "fmt" "net" "net/http" @@ -20,6 +20,7 @@ const ( errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined") errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined") errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined") + errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined") refreshTTL = 60 * time.Minute ) @@ -47,33 +48,39 @@ func mockOIDC() error { if clientID == "" { return errMockOidcClientIDNotDefined } + clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET") if clientSecret == "" { return errMockOidcClientSecretNotDefined } + addrStr := os.Getenv("MOCKOIDC_ADDR") if addrStr == "" { return errMockOidcPortNotDefined } + portStr := os.Getenv("MOCKOIDC_PORT") if portStr == "" { return errMockOidcPortNotDefined } + accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL") if accessTTLOverride != "" { newTTL, err := time.ParseDuration(accessTTLOverride) if err != nil { return err } + accessTTL = newTTL } userStr := os.Getenv("MOCKOIDC_USERS") if userStr == "" { - return errors.New("MOCKOIDC_USERS not defined") + return errMockOidcUsersNotDefined } var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) if err != nil { return fmt.Errorf("unmarshalling users: %w", err) @@ -93,7 +100,7 @@ func mockOIDC() error { return err } - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port)) + listener, err := new(net.ListenConfig).Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", addrStr, port)) if err != nil { return err } @@ -105,6 +112,7 @@ func mockOIDC() error { log.Info().Msgf("mock OIDC server listening on %s", listener.Addr().String()) log.Info().Msgf("issuer: %s", mock.Issuer()) + c := make(chan struct{}) <-c @@ -135,10 +143,11 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser ErrorQueue: &mockoidc.ErrorQueue{}, } - mock.AddMiddleware(func(h http.Handler) http.Handler { + _ = mock.AddMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Info().Msgf("request: %+v", r) h.ServeHTTP(w, r) + if r.Response != nil { log.Info().Msgf("response: %+v", r.Response) } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index fe5d9af5..827d72e7 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -26,6 +26,7 @@ func init() { listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace") listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage listNodesNamespaceFlag.Hidden = true + nodeCmd.AddCommand(listNodesCmd) listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") @@ -42,42 +43,51 @@ func init() { if err != nil { log.Fatal(err.Error()) } + registerNodeCmd.Flags().StringP("key", "k", "", "Key") + err = registerNodeCmd.MarkFlagRequired("key") if err != nil { log.Fatal(err.Error()) } + nodeCmd.AddCommand(registerNodeCmd) expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.") + err = expireNodeCmd.MarkFlagRequired("identifier") if err != nil { log.Fatal(err.Error()) } + nodeCmd.AddCommand(expireNodeCmd) renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + err = renameNodeCmd.MarkFlagRequired("identifier") if err != nil { log.Fatal(err.Error()) } + nodeCmd.AddCommand(renameNodeCmd) deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + err = deleteNodeCmd.MarkFlagRequired("identifier") if err != nil { log.Fatal(err.Error()) } + nodeCmd.AddCommand(deleteNodeCmd) tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - tagCmd.MarkFlagRequired("identifier") + _ = tagCmd.MarkFlagRequired("identifier") tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node") nodeCmd.AddCommand(tagCmd) approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - approveRoutesCmd.MarkFlagRequired("identifier") + _ = approveRoutesCmd.MarkFlagRequired("identifier") approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`) nodeCmd.AddCommand(approveRoutesCmd) @@ -233,10 +243,7 @@ var listNodeRoutesCmd = &cobra.Command{ return } - tableData, err := nodeRoutesToPtables(nodes) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - } + tableData := nodeRoutesToPtables(nodes) err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { @@ -506,15 +513,21 @@ func nodesToPtables( ephemeral = true } - var lastSeen time.Time - var lastSeenTime string + var ( + lastSeen time.Time + lastSeenTime string + ) + if node.GetLastSeen() != nil { lastSeen = node.GetLastSeen().AsTime() lastSeenTime = lastSeen.Format("2006-01-02 15:04:05") } - var expiry time.Time - var expiryTime string + var ( + expiry time.Time + expiryTime string + ) + if node.GetExpiry() != nil { expiry = node.GetExpiry().AsTime() expiryTime = expiry.Format("2006-01-02 15:04:05") @@ -523,6 +536,7 @@ func nodesToPtables( } var machineKey key.MachinePublic + err := machineKey.UnmarshalText( []byte(node.GetMachineKey()), ) @@ -531,6 +545,7 @@ func nodesToPtables( } var nodeKey key.NodePublic + err = nodeKey.UnmarshalText( []byte(node.GetNodeKey()), ) @@ -572,8 +587,11 @@ func nodesToPtables( user = pterm.LightYellow(node.GetUser().GetName()) } - var IPV4Address string - var IPV6Address string + var ( + IPV4Address string + IPV6Address string + ) + for _, addr := range node.GetIpAddresses() { if netip.MustParseAddr(addr).Is4() { IPV4Address = addr @@ -608,7 +626,7 @@ func nodesToPtables( func nodeRoutesToPtables( nodes []*v1.Node, -) (pterm.TableData, error) { +) pterm.TableData { tableHeader := []string{ "ID", "Hostname", @@ -632,7 +650,7 @@ func nodeRoutesToPtables( ) } - return tableData, nil + return tableData } var tagCmd = &cobra.Command{ diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 2aaebcfa..98e50b1d 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -16,7 +16,7 @@ import ( ) const ( - bypassFlag = "bypass-grpc-and-access-database-directly" + bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential ) func init() { @@ -26,16 +26,22 @@ func init() { policyCmd.AddCommand(getPolicy) setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - if err := setPolicy.MarkFlagRequired("file"); err != nil { + + err := setPolicy.MarkFlagRequired("file") + if err != nil { log.Fatal().Err(err).Msg("") } + setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") policyCmd.AddCommand(setPolicy) checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - if err := checkPolicy.MarkFlagRequired("file"); err != nil { + + err = checkPolicy.MarkFlagRequired("file") + if err != nil { log.Fatal().Err(err).Msg("") } + policyCmd.AddCommand(checkPolicy) } @@ -173,7 +179,7 @@ var setPolicy = &cobra.Command{ defer cancel() defer conn.Close() - if _, err := client.SetPolicy(ctx, request); err != nil { + if _, err := client.SetPolicy(ctx, request); err != nil { //nolint:noinlineerr ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) } } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index 4eaf0586..7f84fb8a 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -45,6 +45,7 @@ func initConfig() { if cfgFile == "" { cfgFile = os.Getenv("HEADSCALE_CONFIG") } + if cfgFile != "" { err := types.LoadConfig(cfgFile, true) if err != nil { @@ -80,6 +81,7 @@ func initConfig() { Repository: "headscale", TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }), } + res, err := latest.Check(githubTag, versionInfo.Version) if err == nil && res.Outdated { //nolint @@ -101,6 +103,7 @@ func isPreReleaseVersion(version string) bool { return true } } + return false } @@ -140,7 +143,8 @@ https://github.com/juanfont/headscale`, } func Execute() { - if err := rootCmd.Execute(); err != nil { + err := rootCmd.Execute() + if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 55e5c3db..c1139725 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -15,6 +15,12 @@ import ( "google.golang.org/grpc/status" ) +// CLI user errors. +var ( + errFlagRequired = errors.New("--name or --identifier flag is required") + errMultipleUsersMatch = errors.New("multiple users match query, specify an ID") +) + func usernameAndIDFlag(cmd *cobra.Command) { cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)") cmd.Flags().StringP("name", "n", "", "Username") @@ -24,12 +30,12 @@ func usernameAndIDFlag(cmd *cobra.Command) { // If both are empty, it will exit the program with an error. func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { username, _ := cmd.Flags().GetString("name") + identifier, _ := cmd.Flags().GetInt64("identifier") if username == "" && identifier < 0 { - err := errors.New("--name or --identifier flag is required") ErrorOutput( - err, - "Cannot rename user: "+status.Convert(err).Message(), + errFlagRequired, + "Cannot rename user: "+status.Convert(errFlagRequired).Message(), "", ) } @@ -51,7 +57,8 @@ func init() { userCmd.AddCommand(renameUserCmd) usernameAndIDFlag(renameUserCmd) renameUserCmd.Flags().StringP("new-name", "r", "", "New username") - renameNodeCmd.MarkFlagRequired("new-name") + + _ = renameNodeCmd.MarkFlagRequired("new-name") } var errMissingParameter = errors.New("missing parameters") @@ -95,7 +102,7 @@ var createUserCmd = &cobra.Command{ } if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" { - if _, err := url.Parse(pictureURL); err != nil { + if _, err := url.Parse(pictureURL); err != nil { //nolint:noinlineerr ErrorOutput( err, fmt.Sprintf( @@ -149,7 +156,7 @@ var destroyUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("multiple users match query, specify an ID") + err := errMultipleUsersMatch ErrorOutput( err, "Error: "+status.Convert(err).Message(), @@ -277,7 +284,7 @@ var renameUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("multiple users match query, specify an ID") + err := errMultipleUsersMatch ErrorOutput( err, "Error: "+status.Convert(err).Message(), diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index c5f5923d..2cfbc466 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -58,7 +58,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout) grpcOptions := []grpc.DialOption{ - grpc.WithBlock(), + grpc.WithBlock(), //nolint:staticcheck // SA1019: deprecated but supported in 1.x } address := cfg.CLI.Address @@ -82,6 +82,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g Msgf("Unable to read/write to headscale socket, do you have the correct permissions?") } } + socket.Close() grpcOptions = append( @@ -95,6 +96,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g if apiKey == "" { log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set") } + grpcOptions = append(grpcOptions, grpc.WithPerRPCCredentials(tokenAuth{ token: apiKey, @@ -120,7 +122,8 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g } log.Trace().Caller().Str(zf.Address, address).Msg("connecting via gRPC") - conn, err := grpc.DialContext(ctx, address, grpcOptions...) + + conn, err := grpc.DialContext(ctx, address, grpcOptions...) //nolint:staticcheck // SA1019: deprecated but supported in 1.x if err != nil { log.Fatal().Caller().Err(err).Msgf("could not connect: %v", err) os.Exit(-1) // we get here if logging is suppressed (i.e., json output) @@ -132,8 +135,11 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g } func output(result any, override string, outputFormat string) string { - var jsonBytes []byte - var err error + var ( + jsonBytes []byte + err error + ) + switch outputFormat { case "json": jsonBytes, err = json.MarshalIndent(result, "", "\t") diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index fa17bf6d..679f082e 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -12,6 +12,7 @@ import ( func main() { var colors bool + switch l := termcolor.SupportLevel(os.Stderr); l { case termcolor.Level16M: colors = true diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 2a9fbce6..01eb09b2 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -14,9 +14,7 @@ import ( ) func TestConfigFileLoading(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "headscale") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() path, err := os.Getwd() require.NoError(t, err) @@ -48,9 +46,7 @@ func TestConfigFileLoading(t *testing.T) { } func TestConfigLoading(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "headscale") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() path, err := os.Getwd() require.NoError(t, err) diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 0a95180f..26db49ae 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -25,7 +25,7 @@ func cleanupBeforeTest(ctx context.Context) error { return fmt.Errorf("cleaning stale test containers: %w", err) } - if err := pruneDockerNetworks(ctx); err != nil { + if err := pruneDockerNetworks(ctx); err != nil { //nolint:noinlineerr return fmt.Errorf("pruning networks: %w", err) } @@ -55,7 +55,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI // killTestContainers terminates and removes all test containers. func killTestContainers(ctx context.Context) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -69,8 +69,10 @@ func killTestContainers(ctx context.Context) error { } removed := 0 + for _, cont := range containers { shouldRemove := false + for _, name := range cont.Names { if strings.Contains(name, "headscale-test-suite") || strings.Contains(name, "hs-") || @@ -107,7 +109,7 @@ func killTestContainers(ctx context.Context) error { // This function filters containers by the hi.run-id label to only affect containers // belonging to the specified test run, leaving other concurrent test runs untouched. func killTestContainersByRunID(ctx context.Context, runID string) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -149,7 +151,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error { // This is useful for cleaning up leftover containers from previous crashed or interrupted test runs // without interfering with currently running concurrent tests. func cleanupStaleTestContainers(ctx context.Context) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -223,7 +225,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container // pruneDockerNetworks removes unused Docker networks. func pruneDockerNetworks(ctx context.Context) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -245,7 +247,7 @@ func pruneDockerNetworks(ctx context.Context) error { // cleanOldImages removes test-related and old dangling Docker images. func cleanOldImages(ctx context.Context) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -259,8 +261,10 @@ func cleanOldImages(ctx context.Context) error { } removed := 0 + for _, img := range images { shouldRemove := false + for _, tag := range img.RepoTags { if strings.Contains(tag, "hs-") || strings.Contains(tag, "headscale-integration") || @@ -295,18 +299,19 @@ func cleanOldImages(ctx context.Context) error { // cleanCacheVolume removes the Docker volume used for Go module cache. func cleanCacheVolume(ctx context.Context) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } defer cli.Close() volumeName := "hs-integration-go-cache" + err = cli.VolumeRemove(ctx, volumeName, true) if err != nil { - if errdefs.IsNotFound(err) { + if errdefs.IsNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional fmt.Printf("Go module cache volume not found: %s\n", volumeName) - } else if errdefs.IsConflict(err) { + } else if errdefs.IsConflict(err) { //nolint:staticcheck // SA1019: deprecated but functional fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName) } else { fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 5cccd50b..060057a9 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -22,15 +22,20 @@ import ( "github.com/juanfont/headscale/integration/dockertestutil" ) +const defaultDirPerm = 0o755 + var ( ErrTestFailed = errors.New("test failed") ErrUnexpectedContainerWait = errors.New("unexpected end of container wait") ErrNoDockerContext = errors.New("no docker context found") + ErrMemoryLimitViolations = errors.New("container(s) exceeded memory limits") ) // runTestContainer executes integration tests in a Docker container. +// +//nolint:gocyclo // complex test orchestration function func runTestContainer(ctx context.Context, config *RunConfig) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -52,7 +57,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } const dirPerm = 0o755 - if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { + if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { //nolint:noinlineerr return fmt.Errorf("creating logs directory: %w", err) } @@ -60,7 +65,9 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if config.Verbose { log.Printf("Running pre-test cleanup...") } - if err := cleanupBeforeTest(ctx); err != nil && config.Verbose { + + err := cleanupBeforeTest(ctx) + if err != nil && config.Verbose { log.Printf("Warning: pre-test cleanup failed: %v", err) } } @@ -71,7 +78,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } imageName := "golang:" + config.GoVersion - if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { + if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { //nolint:noinlineerr return fmt.Errorf("ensuring image availability: %w", err) } @@ -84,7 +91,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { log.Printf("Created container: %s", resp.ID) } - if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { + if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { //nolint:noinlineerr return fmt.Errorf("starting container: %w", err) } @@ -95,13 +102,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Start stats collection for container resource monitoring (if enabled) var statsCollector *StatsCollector + if config.Stats { var err error - statsCollector, err = NewStatsCollector() + + statsCollector, err = NewStatsCollector(ctx) if err != nil { if config.Verbose { log.Printf("Warning: failed to create stats collector: %v", err) } + statsCollector = nil } @@ -110,7 +120,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Start stats collection immediately - no need for complex retry logic // The new implementation monitors Docker events and will catch containers as they start - if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil { + err := statsCollector.StartCollection(ctx, runID, config.Verbose) + if err != nil { if config.Verbose { log.Printf("Warning: failed to start stats collection: %v", err) } @@ -122,12 +133,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { exitCode, err := streamAndWait(ctx, cli, resp.ID) // Ensure all containers have finished and logs are flushed before extracting artifacts - if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose { + waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose) + if waitErr != nil && config.Verbose { log.Printf("Warning: failed to wait for container finalization: %v", waitErr) } // Extract artifacts from test containers before cleanup - if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { + if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { //nolint:noinlineerr log.Printf("Warning: failed to extract artifacts from containers: %v", err) } @@ -140,12 +152,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if len(violations) > 0 { log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:") log.Printf("=================================") + for _, violation := range violations { log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB", violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) } - return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations)) + return fmt.Errorf("test failed: %d %w", len(violations), ErrMemoryLimitViolations) } } @@ -347,6 +360,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC maxWaitTime := 10 * time.Second checkInterval := 500 * time.Millisecond timeout := time.After(maxWaitTime) + ticker := time.NewTicker(checkInterval) defer ticker.Stop() @@ -356,6 +370,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction") } + return nil case <-ticker.C: allFinalized := true @@ -366,12 +381,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err) } + continue } // Check if container is in a final state if !isContainerFinalized(inspect.State) { allFinalized = false + if verbose { log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status) } @@ -384,6 +401,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("All test containers finalized, ready for artifact extraction") } + return nil } } @@ -400,13 +418,15 @@ func isContainerFinalized(state *container.State) bool { func findProjectRoot(startPath string) string { current := startPath for { - if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { + if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { //nolint:noinlineerr return current } + parent := filepath.Dir(current) if parent == current { return startPath } + current = parent } } @@ -416,6 +436,7 @@ func boolToInt(b bool) int { if b { return 1 } + return 0 } @@ -428,13 +449,14 @@ type DockerContext struct { } // createDockerClient creates a Docker client with context detection. -func createDockerClient() (*client.Client, error) { - contextInfo, err := getCurrentDockerContext() +func createDockerClient(ctx context.Context) (*client.Client, error) { + contextInfo, err := getCurrentDockerContext(ctx) if err != nil { return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) } var clientOpts []client.Opt + clientOpts = append(clientOpts, client.WithAPIVersionNegotiation()) if contextInfo != nil { @@ -444,6 +466,7 @@ func createDockerClient() (*client.Client, error) { if runConfig.Verbose { log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host) } + clientOpts = append(clientOpts, client.WithHost(host)) } } @@ -458,15 +481,16 @@ func createDockerClient() (*client.Client, error) { } // getCurrentDockerContext retrieves the current Docker context information. -func getCurrentDockerContext() (*DockerContext, error) { - cmd := exec.Command("docker", "context", "inspect") +func getCurrentDockerContext(ctx context.Context) (*DockerContext, error) { + cmd := exec.CommandContext(ctx, "docker", "context", "inspect") + output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("getting docker context: %w", err) } var contexts []DockerContext - if err := json.Unmarshal(output, &contexts); err != nil { + if err := json.Unmarshal(output, &contexts); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("parsing docker context: %w", err) } @@ -486,11 +510,12 @@ func getDockerSocketPath() string { // checkImageAvailableLocally checks if the specified Docker image is available locally. func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) { - _, _, err := cli.ImageInspectWithRaw(ctx, imageName) + _, _, err := cli.ImageInspectWithRaw(ctx, imageName) //nolint:staticcheck // SA1019: deprecated but functional if err != nil { - if client.IsErrNotFound(err) { + if client.IsErrNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional return false, nil } + return false, fmt.Errorf("inspecting image %s: %w", imageName, err) } @@ -509,6 +534,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if verbose { log.Printf("Image %s is available locally", imageName) } + return nil } @@ -533,6 +559,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if err != nil { return fmt.Errorf("reading pull output: %w", err) } + log.Printf("Image %s pulled successfully", imageName) } @@ -547,9 +574,11 @@ func listControlFiles(logsDir string) { return } - var logFiles []string - var dataFiles []string - var dataDirs []string + var ( + logFiles []string + dataFiles []string + dataDirs []string + ) for _, entry := range entries { name := entry.Name() @@ -578,6 +607,7 @@ func listControlFiles(logsDir string) { if len(logFiles) > 0 { log.Printf("Headscale logs:") + for _, file := range logFiles { log.Printf(" %s", file) } @@ -585,9 +615,11 @@ func listControlFiles(logsDir string) { if len(dataFiles) > 0 || len(dataDirs) > 0 { log.Printf("Headscale data:") + for _, file := range dataFiles { log.Printf(" %s", file) } + for _, dir := range dataDirs { log.Printf(" %s/", dir) } @@ -596,7 +628,7 @@ func listControlFiles(logsDir string) { // extractArtifactsFromContainers collects container logs and files from the specific test run. func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return fmt.Errorf("creating Docker client: %w", err) } @@ -612,9 +644,11 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose) extractedCount := 0 + for _, cont := range currentTestContainers { // Extract container logs and tar files - if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil { + err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose) + if err != nil { if verbose { log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err) } @@ -622,6 +656,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi if verbose { log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12]) } + extractedCount++ } } @@ -645,11 +680,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // Find the test container to get its run ID label var runID string + for _, cont := range containers { if cont.ID == testContainerID { if cont.Labels != nil { runID = cont.Labels["hi.run-id"] } + break } } @@ -690,18 +727,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // extractContainerArtifacts saves logs and tar files from a container. func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { // Ensure the logs directory exists - if err := os.MkdirAll(logsDir, 0o755); err != nil { + err := os.MkdirAll(logsDir, defaultDirPerm) + if err != nil { return fmt.Errorf("creating logs directory: %w", err) } // Extract container logs - if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { + err = extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose) + if err != nil { return fmt.Errorf("extracting logs: %w", err) } // Extract tar files for headscale containers only if strings.HasPrefix(containerName, "hs-") { - if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { + err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose) + if err != nil { if verbose { log.Printf("Warning: failed to extract files from %s: %v", containerName, err) } @@ -741,12 +781,12 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID, } // Write stdout logs - if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { + if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable return fmt.Errorf("writing stdout log: %w", err) } // Write stderr logs - if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { + if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable return fmt.Errorf("writing stderr log: %w", err) } diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 8af6051f..1791d66d 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -38,13 +38,13 @@ func runDoctorCheck(ctx context.Context) error { } // Check 3: Go installation - results = append(results, checkGoInstallation()) + results = append(results, checkGoInstallation(ctx)) // Check 4: Git repository - results = append(results, checkGitRepository()) + results = append(results, checkGitRepository(ctx)) // Check 5: Required files - results = append(results, checkRequiredFiles()) + results = append(results, checkRequiredFiles(ctx)) // Display results displayDoctorResults(results) @@ -86,7 +86,7 @@ func checkDockerBinary() DoctorResult { // checkDockerDaemon verifies Docker daemon is running and accessible. func checkDockerDaemon(ctx context.Context) DoctorResult { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return DoctorResult{ Name: "Docker Daemon", @@ -124,8 +124,8 @@ func checkDockerDaemon(ctx context.Context) DoctorResult { } // checkDockerContext verifies Docker context configuration. -func checkDockerContext(_ context.Context) DoctorResult { - contextInfo, err := getCurrentDockerContext() +func checkDockerContext(ctx context.Context) DoctorResult { + contextInfo, err := getCurrentDockerContext(ctx) if err != nil { return DoctorResult{ Name: "Docker Context", @@ -155,7 +155,7 @@ func checkDockerContext(_ context.Context) DoctorResult { // checkDockerSocket verifies Docker socket accessibility. func checkDockerSocket(ctx context.Context) DoctorResult { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return DoctorResult{ Name: "Docker Socket", @@ -192,7 +192,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult { // checkGolangImage verifies the golang Docker image is available locally or can be pulled. func checkGolangImage(ctx context.Context) DoctorResult { - cli, err := createDockerClient() + cli, err := createDockerClient(ctx) if err != nil { return DoctorResult{ Name: "Golang Image", @@ -251,7 +251,7 @@ func checkGolangImage(ctx context.Context) DoctorResult { } // checkGoInstallation verifies Go is installed and working. -func checkGoInstallation() DoctorResult { +func checkGoInstallation(ctx context.Context) DoctorResult { _, err := exec.LookPath("go") if err != nil { return DoctorResult{ @@ -265,7 +265,8 @@ func checkGoInstallation() DoctorResult { } } - cmd := exec.Command("go", "version") + cmd := exec.CommandContext(ctx, "go", "version") + output, err := cmd.Output() if err != nil { return DoctorResult{ @@ -285,8 +286,9 @@ func checkGoInstallation() DoctorResult { } // checkGitRepository verifies we're in a git repository. -func checkGitRepository() DoctorResult { - cmd := exec.Command("git", "rev-parse", "--git-dir") +func checkGitRepository(ctx context.Context) DoctorResult { + cmd := exec.CommandContext(ctx, "git", "rev-parse", "--git-dir") + err := cmd.Run() if err != nil { return DoctorResult{ @@ -308,7 +310,7 @@ func checkGitRepository() DoctorResult { } // checkRequiredFiles verifies required files exist. -func checkRequiredFiles() DoctorResult { +func checkRequiredFiles(ctx context.Context) DoctorResult { requiredFiles := []string{ "go.mod", "integration/", @@ -316,9 +318,12 @@ func checkRequiredFiles() DoctorResult { } var missingFiles []string + for _, file := range requiredFiles { - cmd := exec.Command("test", "-e", file) - if err := cmd.Run(); err != nil { + cmd := exec.CommandContext(ctx, "test", "-e", file) + + err := cmd.Run() + if err != nil { missingFiles = append(missingFiles, file) } } @@ -350,6 +355,7 @@ func displayDoctorResults(results []DoctorResult) { for _, result := range results { var icon string + switch result.Status { case "PASS": icon = "✅" diff --git a/cmd/hi/main.go b/cmd/hi/main.go index baecc6f3..2bbfefe0 100644 --- a/cmd/hi/main.go +++ b/cmd/hi/main.go @@ -79,13 +79,18 @@ func main() { } func cleanAll(ctx context.Context) error { - if err := killTestContainers(ctx); err != nil { + err := killTestContainers(ctx) + if err != nil { return err } - if err := pruneDockerNetworks(ctx); err != nil { + + err = pruneDockerNetworks(ctx) + if err != nil { return err } - if err := cleanOldImages(ctx); err != nil { + + err = cleanOldImages(ctx) + if err != nil { return err } diff --git a/cmd/hi/run.go b/cmd/hi/run.go index 1694399d..132feb89 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error { if runConfig.Verbose { log.Printf("Running pre-flight system checks...") } - if err := runDoctorCheck(env.Context()); err != nil { + + err := runDoctorCheck(env.Context()) + if err != nil { return fmt.Errorf("pre-flight checks failed: %w", err) } @@ -66,9 +68,9 @@ func runIntegrationTest(env *command.Env) error { func detectGoVersion() string { goModPath := filepath.Join("..", "..", "go.mod") - if _, err := os.Stat("go.mod"); err == nil { + if _, err := os.Stat("go.mod"); err == nil { //nolint:noinlineerr goModPath = "go.mod" - } else if _, err := os.Stat("../../go.mod"); err == nil { + } else if _, err := os.Stat("../../go.mod"); err == nil { //nolint:noinlineerr goModPath = "../../go.mod" } @@ -94,8 +96,10 @@ func detectGoVersion() string { // splitLines splits a string into lines without using strings.Split. func splitLines(s string) []string { - var lines []string - var current string + var ( + lines []string + current string + ) for _, char := range s { if char == '\n' { diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index dcc4cc73..8ab8b7b3 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -18,6 +18,9 @@ import ( "github.com/docker/docker/client" ) +// ErrStatsCollectionAlreadyStarted is returned when trying to start stats collection that is already running. +var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started") + // ContainerStats represents statistics for a single container. type ContainerStats struct { ContainerID string @@ -44,8 +47,8 @@ type StatsCollector struct { } // NewStatsCollector creates a new stats collector instance. -func NewStatsCollector() (*StatsCollector, error) { - cli, err := createDockerClient() +func NewStatsCollector(ctx context.Context) (*StatsCollector, error) { + cli, err := createDockerClient(ctx) if err != nil { return nil, fmt.Errorf("creating Docker client: %w", err) } @@ -63,17 +66,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver defer sc.mutex.Unlock() if sc.collectionStarted { - return errors.New("stats collection already started") + return ErrStatsCollectionAlreadyStarted } sc.collectionStarted = true // Start monitoring existing containers sc.wg.Add(1) + go sc.monitorExistingContainers(ctx, runID, verbose) // Start Docker events monitoring for new containers sc.wg.Add(1) + go sc.monitorDockerEvents(ctx, runID, verbose) if verbose { @@ -87,10 +92,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver func (sc *StatsCollector) StopCollection() { // Check if already stopped without holding lock sc.mutex.RLock() + if !sc.collectionStarted { sc.mutex.RUnlock() return } + sc.mutex.RUnlock() // Signal stop to all goroutines @@ -114,6 +121,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s if verbose { log.Printf("Failed to list existing containers: %v", err) } + return } @@ -147,13 +155,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, case event := <-events: if event.Type == "container" && event.Action == "start" { // Get container details - containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) + containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) //nolint:staticcheck // SA1019: use Actor.ID if err != nil { continue } // Convert to types.Container format for consistency - cont := types.Container{ + cont := types.Container{ //nolint:staticcheck // SA1019: use container.Summary ID: containerInfo.ID, Names: []string{containerInfo.Name}, Labels: containerInfo.Config.Labels, @@ -167,13 +175,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, if verbose { log.Printf("Error in Docker events stream: %v", err) } + return } } } // shouldMonitorContainer determines if a container should be monitored. -func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { +func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { //nolint:staticcheck // SA1019: use container.Summary // Check if it has the correct run ID label if cont.Labels == nil || cont.Labels["hi.run-id"] != runID { return false @@ -213,6 +222,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI } sc.wg.Add(1) + go sc.collectStatsForContainer(ctx, containerID, verbose) } @@ -226,12 +236,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe if verbose { log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err) } + return } defer statsResponse.Body.Close() decoder := json.NewDecoder(statsResponse.Body) - var prevStats *container.Stats + + var prevStats *container.Stats //nolint:staticcheck // SA1019: use StatsResponse for { select { @@ -240,12 +252,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe case <-ctx.Done(): return default: - var stats container.Stats - if err := decoder.Decode(&stats); err != nil { + var stats container.Stats //nolint:staticcheck // SA1019: use StatsResponse + + err := decoder.Decode(&stats) + if err != nil { // EOF is expected when container stops or stream ends if err.Error() != "EOF" && verbose { log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) } + return } @@ -261,8 +276,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe // Store the sample (skip first sample since CPU calculation needs previous stats) if prevStats != nil { // Get container stats reference without holding the main mutex - var containerStats *ContainerStats - var exists bool + var ( + containerStats *ContainerStats + exists bool + ) sc.mutex.RLock() containerStats, exists = sc.containers[containerID] @@ -286,7 +303,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe } // calculateCPUPercent calculates CPU usage percentage from Docker stats. -func calculateCPUPercent(prevStats, stats *container.Stats) float64 { +func calculateCPUPercent(prevStats, stats *container.Stats) float64 { //nolint:staticcheck // SA1019: use StatsResponse // CPU calculation based on Docker's implementation cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage) systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage) @@ -331,10 +348,12 @@ type StatsSummary struct { func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { // Take snapshot of container references without holding main lock long sc.mutex.RLock() + containerRefs := make([]*ContainerStats, 0, len(sc.containers)) for _, containerStats := range sc.containers { containerRefs = append(containerRefs, containerStats) } + sc.mutex.RUnlock() summaries := make([]ContainerStatsSummary, 0, len(containerRefs)) @@ -384,23 +403,25 @@ func calculateStatsSummary(values []float64) StatsSummary { return StatsSummary{} } - min := values[0] - max := values[0] + minVal := values[0] + maxVal := values[0] sum := 0.0 for _, value := range values { - if value < min { - min = value + if value < minVal { + minVal = value } - if value > max { - max = value + + if value > maxVal { + maxVal = value } + sum += value } return StatsSummary{ - Min: min, - Max: max, + Min: minVal, + Max: maxVal, Average: sum / float64(len(values)), } } @@ -434,6 +455,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo } summaries := sc.GetSummary() + var violations []MemoryViolation for _, summary := range summaries { diff --git a/cmd/mapresponses/main.go b/cmd/mapresponses/main.go index 5d7ad07d..cf43a66c 100644 --- a/cmd/mapresponses/main.go +++ b/cmd/mapresponses/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" @@ -15,7 +16,10 @@ type MapConfig struct { Directory string `flag:"directory,Directory to read map responses from"` } -var mapConfig MapConfig +var ( + mapConfig MapConfig + errDirectoryRequired = errors.New("directory is required") +) func main() { root := command.C{ @@ -40,7 +44,7 @@ func main() { // runIntegrationTest executes the integration test workflow. func runOnline(env *command.Env) error { if mapConfig.Directory == "" { - return fmt.Errorf("directory is required") + return errDirectoryRequired } resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory) @@ -57,5 +61,6 @@ func runOnline(env *command.Env) error { os.Stderr.Write(out) os.Stderr.Write([]byte("\n")) + return nil } diff --git a/flake.nix b/flake.nix index f3e8b41e..7e6214a9 100644 --- a/flake.nix +++ b/flake.nix @@ -27,7 +27,7 @@ let pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system}; buildGo = pkgs.buildGo125Module; - vendorHash = "sha256-jkeB9XUTEGt58fPOMpE4/e3+JQoMQTgf0RlthVBmfG0="; + vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0="; in { headscale = buildGo { diff --git a/hscontrol/app.go b/hscontrol/app.go index ed919bf0..abd29a45 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -115,6 +115,7 @@ var ( func NewHeadscale(cfg *types.Config) (*Headscale, error) { var err error + if profilingEnabled { runtime.SetBlockProfileRate(1) } @@ -142,6 +143,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if !ok { log.Error().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed") log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed because node not found in NodeStore") + return } @@ -157,10 +159,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.ephemeralGC = ephemeralGC var authProvider AuthProvider + authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + oidcProvider, err := NewAuthProviderOIDC( ctx, &app, @@ -177,17 +181,18 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { authProvider = oidcProvider } } + app.authProvider = authProvider if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS // TODO(kradalby): revisit why this takes a list. - var magicDNSDomains []dnsname.FQDN if cfg.PrefixV4 != nil { magicDNSDomains = append( magicDNSDomains, util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...) } + if cfg.PrefixV6 != nil { magicDNSDomains = append( magicDNSDomains, @@ -198,6 +203,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if app.cfg.TailcfgDNSConfig.Routes == nil { app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver) } + for _, d := range magicDNSDomains { app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil } @@ -232,6 +238,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if err != nil { return nil, err } + app.DERPServer = embeddedDERPServer } @@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { lastExpiryCheck := time.Unix(0, 0) derpTickerChan := make(<-chan time.Time) + if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 { derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency) defer derpTicker.Stop() + derpTickerChan = derpTicker.C } @@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { return case <-expireTicker.C: - var expiredNodeChanges []change.Change - var changed bool + var ( + expiredNodeChanges []change.Change + changed bool + ) lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) @@ -287,11 +298,13 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { case <-derpTickerChan: log.Info().Msg("fetching DERPMap updates") - derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { + + derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { //nolint:contextcheck derpMap, err := derp.GetDERPMap(h.cfg.DERP) if err != nil { return nil, err } + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { region, _ := h.DERPServer.GenerateRegion() derpMap.Regions[region.RegionID] = ®ion @@ -303,6 +316,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { log.Error().Err(err).Msg("failed to build new DERPMap, retrying later") continue } + h.state.SetDERPMap(derpMap) h.Change(change.DERPMap()) @@ -311,6 +325,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { if !ok { continue } + h.cfg.TailcfgDNSConfig.ExtraRecords = records h.Change(change.ExtraRecords()) @@ -390,7 +405,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writeUnauthorized := func(statusCode int) { writer.WriteHeader(statusCode) - if _, err := writer.Write([]byte("Unauthorized")); err != nil { + + if _, err := writer.Write([]byte("Unauthorized")); err != nil { //nolint:noinlineerr log.Error().Err(err).Msg("writing HTTP response failed") } } @@ -401,6 +417,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler Str("client_address", req.RemoteAddr). Msg(`missing "Bearer " prefix in "Authorization" header`) writeUnauthorized(http.StatusUnauthorized) + return } @@ -412,6 +429,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler Str("client_address", req.RemoteAddr). Msg("failed to validate token") writeUnauthorized(http.StatusUnauthorized) + return } @@ -420,6 +438,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler Str("client_address", req.RemoteAddr). Msg("invalid token") writeUnauthorized(http.StatusUnauthorized) + return } @@ -431,7 +450,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler // and will remove it if it is not. func (h *Headscale) ensureUnixSocketIsAbsent() error { // File does not exist, all fine - if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { + if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { //nolint:noinlineerr return nil } @@ -455,6 +474,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) } + router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). Methods(http.MethodGet) @@ -484,8 +504,11 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { } // Serve launches the HTTP and gRPC server service Headscale and the API. +// +//nolint:gocyclo // complex server startup function func (h *Headscale) Serve() error { var err error + capver.CanOldCodeBeCleanedUp() if profilingEnabled { @@ -512,6 +535,7 @@ func (h *Headscale) Serve() error { Msg("Clients with a lower minimum version will be rejected") h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() defer h.mapBatcher.Close() @@ -545,6 +569,7 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. go h.ephemeralGC.Start() + ephmNodes := h.state.ListEphemeralNodes() for _, node := range ephmNodes.All() { h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) @@ -555,7 +580,9 @@ func (h *Headscale) Serve() error { if err != nil { return fmt.Errorf("setting up extrarecord manager: %w", err) } + h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records() + go h.extraRecordMan.Run() defer h.extraRecordMan.Close() } @@ -564,6 +591,7 @@ func (h *Headscale) Serve() error { // records updates scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) defer scheduleCancel() + go h.scheduledTasks(scheduleCtx) if zl.GlobalLevel() == zl.TraceLevel { @@ -576,6 +604,7 @@ func (h *Headscale) Serve() error { errorGroup := new(errgroup.Group) ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -590,25 +619,26 @@ func (h *Headscale) Serve() error { } socketDir := filepath.Dir(h.cfg.UnixSocket) + err = util.EnsureDir(socketDir) if err != nil { return fmt.Errorf("setting up unix socket: %w", err) } - socketListener, err := net.Listen("unix", h.cfg.UnixSocket) + socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket) if err != nil { return fmt.Errorf("setting up gRPC socket: %w", err) } // Change socket permissions - if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { + if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { //nolint:noinlineerr return fmt.Errorf("changing gRPC socket permission: %w", err) } grpcGatewayMux := grpcRuntime.NewServeMux() // Make the grpc-gateway connect to grpc over socket - grpcGatewayConn, err := grpc.Dial( + grpcGatewayConn, err := grpc.Dial( //nolint:staticcheck // SA1019: deprecated but supported in 1.x h.cfg.UnixSocket, []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -659,8 +689,11 @@ func (h *Headscale) Serve() error { // https://github.com/soheilhy/cmux/issues/68 // https://github.com/soheilhy/cmux/issues/91 - var grpcServer *grpc.Server - var grpcListener net.Listener + var ( + grpcServer *grpc.Server + grpcListener net.Listener + ) + if tlsConfig != nil || h.cfg.GRPCAllowInsecure { log.Info().Msgf("enabling remote gRPC at %s", h.cfg.GRPCAddr) @@ -685,7 +718,7 @@ func (h *Headscale) Serve() error { v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) reflection.Register(grpcServer) - grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) + grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr) if err != nil { return fmt.Errorf("binding to TCP address: %w", err) } @@ -715,12 +748,14 @@ func (h *Headscale) Serve() error { } var httpListener net.Listener + if tlsConfig != nil { httpServer.TLSConfig = tlsConfig httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig) } else { - httpListener, err = net.Listen("tcp", h.cfg.Addr) + httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr) } + if err != nil { return fmt.Errorf("binding to TCP address: %w", err) } @@ -751,19 +786,24 @@ func (h *Headscale) Serve() error { log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)") } - var tailsqlContext context.Context + if tailsqlEnabled { if h.cfg.Database.Type != types.DatabaseSqlite { + //nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start log.Fatal(). Str("type", h.cfg.Database.Type). Msgf("tailsql only support %q", types.DatabaseSqlite) } + if tailsqlTSKey == "" { + //nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") } + tailsqlContext = context.Background() - go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) + + go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) //nolint:errcheck } // Handle common process-killing signals so we can gracefully shut down: @@ -774,6 +814,7 @@ func (h *Headscale) Serve() error { syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP) + sigFunc := func(c chan os.Signal) { // Wait for a SIGINT or SIGKILL: for { @@ -798,6 +839,7 @@ func (h *Headscale) Serve() error { default: info := func(msg string) { log.Info().Msg(msg) } + log.Info(). Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") @@ -854,6 +896,7 @@ func (h *Headscale) Serve() error { if debugHTTPListener != nil { debugHTTPListener.Close() } + httpListener.Close() grpcGatewayConn.Close() @@ -863,6 +906,7 @@ func (h *Headscale) Serve() error { // Close state connections info("closing state and database") + err = h.state.Close() if err != nil { log.Error().Err(err).Msg("failed to close state") @@ -875,6 +919,7 @@ func (h *Headscale) Serve() error { } } } + errorGroup.Go(func() error { sigFunc(sigc) @@ -886,6 +931,7 @@ func (h *Headscale) Serve() error { func (h *Headscale) getTLSSettings() (*tls.Config, error) { var err error + if h.cfg.TLS.LetsEncrypt.Hostname != "" { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn(). @@ -918,7 +964,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { // Configuration via autocert with HTTP-01. This requires listening on // port 80 for the certificate validation in addition to the headscale // service, which can be configured to run on any other port. - server := &http.Server{ Addr: h.cfg.TLS.LetsEncrypt.Listen, Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)), @@ -963,6 +1008,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { dir := filepath.Dir(path) + err := util.EnsureDir(dir) if err != nil { return nil, fmt.Errorf("ensuring private key directory: %w", err) @@ -981,6 +1027,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { err, ) } + err = os.WriteFile(path, machineKeyStr, privateKeyFileMode) if err != nil { return nil, fmt.Errorf( @@ -998,7 +1045,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { trimmedPrivateKey := strings.TrimSpace(string(privateKey)) var machineKey key.MachinePrivate - if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { + if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("parsing private key: %w", err) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index a1ba0a3b..1aa40c7b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -20,8 +20,8 @@ import ( ) type AuthProvider interface { - RegisterHandler(http.ResponseWriter, *http.Request) - AuthURL(types.RegistrationID) string + RegisterHandler(w http.ResponseWriter, r *http.Request) + AuthURL(regID types.RegistrationID) string } func (h *Headscale) handleRegister( @@ -51,6 +51,7 @@ func (h *Headscale) handleRegister( if err != nil { return nil, fmt.Errorf("handling logout: %w", err) } + if resp != nil { return resp, nil } @@ -132,7 +133,7 @@ func (h *Headscale) handleRegister( } // handleLogout checks if the [tailcfg.RegisterRequest] is a -// logout attempt from a node. If the node is not attempting to +// logout attempt from a node. If the node is not attempting to. func (h *Headscale) handleLogout( node types.NodeView, req tailcfg.RegisterRequest, @@ -159,6 +160,7 @@ func (h *Headscale) handleLogout( Interface("reg.req", req). Bool("unexpected", true). Msg("Node key expired, forcing re-authentication") + return &tailcfg.RegisterResponse{ NodeKeyExpired: true, MachineAuthorized: false, @@ -275,6 +277,7 @@ func (h *Headscale) waitForFollowup( // registration is expired in the cache, instruct the client to try a new registration return h.reqToNewRegisterResponse(req, machineKey) } + return nodeToRegisterResponse(node.View()), nil } } @@ -340,6 +343,7 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } + var perr types.PAKError if errors.As(err, &perr) { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) @@ -351,7 +355,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // If node is not valid, it means an ephemeral node was deleted during logout if !node.Valid() { h.Change(changed) - return nil, nil + return nil, nil //nolint:nilnil // intentional: no node to return when ephemeral deleted } // This is a bit of a back and forth, but we have a bit of a chicken and egg @@ -430,6 +434,7 @@ func (h *Headscale) handleRegisterInteractive( Str("generated.hostname", hostname). Msg("Received registration request with empty hostname, generated default") } + hostinfo.Hostname = hostname nodeToRegister := types.NewRegisterNode( diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 7967eee3..d28ed565 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -2,6 +2,7 @@ package hscontrol import ( "context" + "errors" "fmt" "net/url" "strings" @@ -16,14 +17,16 @@ import ( "tailscale.com/types/key" ) -// Interactive step type constants +// Interactive step type constants. const ( stepTypeInitialRequest = "initial_request" stepTypeAuthCompletion = "auth_completion" stepTypeFollowupRequest = "followup_request" ) -// interactiveStep defines a step in the interactive authentication workflow +var errNodeNotFoundAfterSetup = errors.New("node not found after setup") + +// interactiveStep defines a step in the interactive authentication workflow. type interactiveStep struct { stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest expectAuthURL bool @@ -31,6 +34,7 @@ type interactiveStep struct { callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked } +//nolint:gocyclo // comprehensive test function with many scenarios func TestAuthenticationFlows(t *testing.T) { // Shared test keys for consistent behavior across test cases machineKey1 := key.NewMachine() @@ -68,13 +72,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Pre-auth keys enable automated/headless node registration without user interaction { name: "preauth_key_valid_new_node", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper // not a test helper, inline closure user := app.state.CreateUserForTest("preauth-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -89,9 +94,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper // not a test helper, inline closure assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) assert.NotEmpty(t, resp.User.DisplayName) @@ -110,7 +115,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Reusable keys allow multiple machines to join using one key (useful for fleet deployments) { name: "preauth_key_reusable_multiple_nodes", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("reusable-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -129,6 +134,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -154,15 +160,16 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) // Verify both nodes exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.True(t, found2) assert.Equal(t, "reusable-node-1", node1.Hostname()) @@ -177,7 +184,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Single-use keys provide security by preventing key reuse after initial registration { name: "preauth_key_single_use_exhausted", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("single-use-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) @@ -196,6 +203,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -221,12 +229,13 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, wantError: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper // First node should exist, second should not _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) _, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.False(t, found2) }, @@ -239,7 +248,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Invalid keys must be rejected to prevent unauthorized node registration { name: "preauth_key_invalid", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper return "invalid-key-12345", nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -254,7 +263,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -265,13 +274,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Ephemeral nodes auto-cleanup when disconnected, useful for temporary/CI environments { name: "preauth_key_ephemeral_node", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("ephemeral-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -286,9 +296,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -311,7 +321,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Interactive flow is the standard user-facing authentication method for new nodes { name: "full_interactive_workflow_new_node", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -323,7 +333,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -339,7 +349,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates handling of requests without Auth field, same as empty auth { name: "interactive_workflow_no_auth_struct", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -352,7 +362,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -372,7 +382,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Nodes signal logout by setting expiry to past time; system updates node state accordingly { name: "existing_node_logout", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("logout-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -391,6 +401,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -400,8 +411,10 @@ func TestAuthenticationFlows(t *testing.T) { // Wait for node to be available in NodeStore with debug info var attemptCount int + require.EventuallyWithT(t, func(c *assert.CollectT) { attemptCount++ + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) if assert.True(c, found, "node should be available in NodeStore") { t.Logf("Node found in NodeStore after %d attempts", attemptCount) @@ -417,10 +430,10 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, wantExpired: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.True(t, resp.NodeKeyExpired) }, @@ -432,7 +445,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Machine key must match to prevent node hijacking/impersonation { name: "existing_node_machine_key_mismatch", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("mismatch-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -451,6 +464,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -471,7 +485,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key + machineKey: machineKey2.Public, // Different machine key wantError: true, }, // TEST: Existing node cannot extend expiry without re-auth @@ -481,7 +495,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Prevents nodes from extending their own lifetime; must re-authenticate { name: "existing_node_key_extension_not_allowed", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("extend-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -500,6 +514,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -520,7 +535,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Expired node must re-authenticate @@ -530,7 +545,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Expired nodes must go through authentication again for security { name: "existing_node_expired_forces_reauth", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("reauth-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -549,25 +564,31 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err } // Wait for node to be available in NodeStore - var node types.NodeView - var found bool + var ( + node types.NodeView + found bool + ) + require.EventuallyWithT(t, func(c *assert.CollectT) { node, found = app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(c, found, "node should be available in NodeStore") }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + if !found { - return "", fmt.Errorf("node not found after setup") + return "", errNodeNotFoundAfterSetup } // Expire the node expiredTime := time.Now().Add(-1 * time.Hour) _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + return "", err }, request: func(_ string) tailcfg.RegisterRequest { @@ -577,9 +598,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), // Future expiry } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantExpired: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.NodeKeyExpired) assert.False(t, resp.MachineAuthorized) }, @@ -591,7 +612,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Ephemeral nodes should not persist after logout; auto-cleanup { name: "ephemeral_node_logout_deletion", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("ephemeral-logout-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) @@ -610,6 +631,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -630,9 +652,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), // Logout } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantExpired: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.NodeKeyExpired) assert.False(t, resp.MachineAuthorized) @@ -653,7 +675,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Followup mechanism allows nodes to poll/wait for auth completion { name: "followup_registration_success", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper regID, err := types.NewRegistrationID() if err != nil { return "", err @@ -673,6 +695,7 @@ func TestAuthenticationFlows(t *testing.T) { // and handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") + node := app.state.CreateNodeForTest(user, "followup-success-node") registered <- node }() @@ -685,9 +708,9 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) }, @@ -699,7 +722,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Prevents indefinite waiting; nodes must retry if auth takes too long { name: "followup_registration_timeout", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper regID, err := types.NewRegistrationID() if err != nil { return "", err @@ -723,7 +746,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Invalid followup URL is rejected @@ -733,7 +756,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates URL format to prevent errors and potential exploits { name: "followup_invalid_url", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper return "invalid://url[malformed", nil }, request: func(followupURL string) tailcfg.RegisterRequest { @@ -742,7 +765,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Non-existent registration ID is rejected @@ -752,7 +775,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Registration must exist in cache; prevents invalid/expired registrations { name: "followup_registration_not_found", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper return "http://localhost:8080/register/nonexistent-id", nil }, request: func(followupURL string) tailcfg.RegisterRequest { @@ -761,7 +784,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -775,13 +798,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Defensive code prevents errors from missing hostnames; generates sensible default { name: "empty_hostname", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("empty-hostname-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -796,9 +820,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.MachineAuthorized) // Node should be created with generated hostname @@ -814,13 +838,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Defensive code prevents nil pointer panics; creates valid default hostinfo { name: "nil_hostinfo", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("nil-hostinfo-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -833,9 +858,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper //nolint:thelper assert.True(t, resp.MachineAuthorized) // Node should be created with generated hostname from defensive code @@ -857,7 +882,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Expired keys must be rejected to maintain security and key lifecycle management { name: "preauth_key_expired", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("expired-pak-user") expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago @@ -865,6 +890,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -879,7 +905,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -890,7 +916,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Pre-auth keys can enforce ACL policies on nodes during registration { name: "preauth_key_with_acl_tags", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper //nolint:thelper user := app.state.CreateUserForTest("tagged-pak-user") tags := []string{"tag:server", "tag:database"} @@ -898,6 +924,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -912,9 +939,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -922,6 +949,7 @@ func TestAuthenticationFlows(t *testing.T) { node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found) assert.Equal(t, "tagged-pak-node", node.Hostname()) + if node.AuthKey().Valid() { assert.NotEmpty(t, node.AuthKey().Tags()) } @@ -938,7 +966,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: PreAuthKey nodes get their tags from the key itself, not from client requests { name: "preauth_key_rejects_request_tags", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper t.Helper() user := app.state.CreateUserForTest("pak-requesttags-user") @@ -974,7 +1002,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Tags-as-identity: PreAuthKey tags are authoritative, client cannot override { name: "tagged_preauth_key_rejects_client_request_tags", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper t.Helper() user := app.state.CreateUserForTest("tagged-pak-clienttags-user") @@ -1012,7 +1040,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Allows nodes to refresh authentication using pre-auth keys { name: "existing_node_reauth_with_new_authkey", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("reauth-user") // First, register with initial auth key @@ -1031,6 +1059,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1047,6 +1076,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(newAuthKey string) tailcfg.RegisterRequest { @@ -1061,9 +1091,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -1080,7 +1110,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Allows expired nodes to re-authenticate without pre-auth keys { name: "existing_node_reauth_interactive_flow", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("interactive-reauth-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -1099,6 +1129,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1124,9 +1155,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuthURL: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.Contains(t, resp.AuthURL, "register/") assert.False(t, resp.MachineAuthorized) }, @@ -1142,7 +1173,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Same machine key means same physical device; node key rotation updates, doesn't duplicate { name: "node_key_rotation_same_machine", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("rotation-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -1161,6 +1192,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1177,6 +1209,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pakRotation.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1191,9 +1224,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -1219,13 +1252,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Zero time is valid Go default; should be handled gracefully { name: "malformed_expiry_zero_time", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("zero-expiry-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1240,9 +1274,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Time{}, // Zero time } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) // Node should be created with default expiry handling @@ -1258,13 +1292,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Defensive code enforces DNS label limit (RFC 1123); prevents errors { name: "malformed_hostinfo_invalid_data", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("invalid-hostinfo-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1286,9 +1321,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) // Node should be created even with malformed hostinfo @@ -1309,7 +1344,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Nil response means cache expired - give client new AuthURL instead of error { name: "followup_registration_node_nil_response", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper regID, err := types.NewRegistrationID() if err != nil { return "", err @@ -1342,9 +1377,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: false, // Should not be authorized yet - needs to use new AuthURL - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Should get a new AuthURL, not an error assert.NotEmpty(t, resp.AuthURL, "should receive new AuthURL when cache returns nil") assert.Contains(t, resp.AuthURL, "/register/", "AuthURL should contain registration path") @@ -1358,7 +1393,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Path validation prevents processing of corrupted/invalid URLs { name: "followup_registration_malformed_path", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "http://localhost:8080/register/", nil // Missing registration ID }, request: func(followupURL string) tailcfg.RegisterRequest { @@ -1367,7 +1402,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Wrong followup path format is rejected @@ -1377,7 +1412,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Strict path validation ensures only valid registration URLs accepted { name: "followup_registration_wrong_path_format", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "http://localhost:8080/wrong/path/format", nil }, request: func(followupURL string) tailcfg.RegisterRequest { @@ -1386,7 +1421,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -1402,7 +1437,7 @@ func TestAuthenticationFlows(t *testing.T) { // claim tags via RequestTags - they must use a tagged PreAuthKey instead. { name: "interactive_workflow_with_custom_hostinfo", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -1417,7 +1452,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1425,10 +1460,11 @@ func TestAuthenticationFlows(t *testing.T) { }, validateCompleteResponse: true, expectedAuthURLPattern: "/register/", - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Verify custom hostinfo was preserved through interactive workflow node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be found after interactive registration") + if found { assert.Equal(t, "custom-interactive-node", node.Hostname()) assert.Equal(t, "linux", node.Hostinfo().OS()) @@ -1448,13 +1484,14 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Usage tracking enables monitoring and auditing of pre-auth key usage { name: "preauth_key_usage_count_tracking", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("usage-count-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // Single use if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1469,9 +1506,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -1495,7 +1532,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Registration IDs must be unique and valid for cache lookup { name: "interactive_workflow_registration_id_generation", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -1508,7 +1545,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1516,10 +1553,11 @@ func TestAuthenticationFlows(t *testing.T) { }, validateCompleteResponse: true, expectedAuthURLPattern: "/register/", - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Verify registration ID was properly generated and used node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered after interactive workflow") + if found { assert.Equal(t, "registration-id-test-node", node.Hostname()) assert.Equal(t, "test-os", node.Hostinfo().OS()) @@ -1528,13 +1566,14 @@ func TestAuthenticationFlows(t *testing.T) { }, { name: "concurrent_registration_same_node_key", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("concurrent-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1549,9 +1588,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -1568,7 +1607,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Request expiry overrides key expiry; allows logout with valid key { name: "auth_key_with_future_expiry_past_request_expiry", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("future-expiry-user") // Auth key expires in the future expiry := time.Now().Add(48 * time.Hour) @@ -1577,6 +1616,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1592,9 +1632,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(12 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -1611,7 +1651,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates device reassignment scenarios where a machine moves between users { name: "reauth_existing_node_different_user_auth_key", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Create two users user1 := app.state.CreateUserForTest("user1-context") user2 := app.state.CreateUserForTest("user2-context") @@ -1632,6 +1672,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1648,6 +1689,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(user2AuthKey string) tailcfg.RegisterRequest { @@ -1662,9 +1704,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.False(t, resp.NodeKeyExpired) @@ -1692,7 +1734,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Same physical machine can have separate node identities per user { name: "interactive_reauth_existing_node_different_user_creates_new_node", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Create user1 and register a node with auth key user1 := app.state.CreateUserForTest("interactive-user-1") @@ -1712,6 +1754,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -1735,14 +1778,14 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key + machineKey: machineKey1.Public, // Same machine key requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, }, validateCompleteResponse: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // User1's original node should STILL exist (not transferred) node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) require.True(t, found1, "user1's original node should still exist") @@ -1769,7 +1812,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates new reqToNewRegisterResponse functionality - prevents client getting stuck { name: "followup_request_after_cache_expiry", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Generate a registration ID that doesn't exist in cache // This simulates an expired/missing cache entry regID, err := types.NewRegistrationID() @@ -1789,9 +1832,9 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: false, // Should not be authorized yet - needs to use new AuthURL - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Should get a new AuthURL, not an error assert.NotEmpty(t, resp.AuthURL, "should receive new AuthURL when registration expired") assert.Contains(t, resp.AuthURL, "/register/", "AuthURL should contain registration path") @@ -1799,13 +1842,13 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the response contains a valid registration URL authURL, err := url.Parse(resp.AuthURL) - assert.NoError(t, err, "AuthURL should be a valid URL") + assert.NoError(t, err, "AuthURL should be a valid URL") //nolint:testifylint // inside closure, uses assert pattern assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/") // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") newRegID, err := types.RegistrationIDFromString(newRegIDStr) - assert.NoError(t, err, "should be able to parse new registration ID") + assert.NoError(t, err, "should be able to parse new registration ID") //nolint:testifylint // inside closure // Verify new registration entry exists in cache _, found := app.state.GetRegistrationCacheEntry(newRegID) @@ -1819,7 +1862,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Edge case: current time should be treated as expired { name: "logout_with_exactly_now_expiry", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper user := app.state.CreateUserForTest("exact-now-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -1838,6 +1881,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1858,10 +1902,10 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now(), // Exactly now (edge case between past and future) } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, wantExpired: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper assert.True(t, resp.MachineAuthorized) assert.True(t, resp.NodeKeyExpired) @@ -1878,7 +1922,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Prevents cache bloat from abandoned registrations { name: "interactive_workflow_timeout_cleanup", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -1890,14 +1934,14 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, // NOTE: No auth_completion step - simulates timeout scenario }, validateRegistrationCache: true, // should be cleaned up eventually - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Verify AuthURL was generated but registration not completed assert.Contains(t, resp.AuthURL, "/register/") assert.False(t, resp.MachineAuthorized) @@ -1912,7 +1956,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Same physical machine can have separate node identities per user { name: "interactive_workflow_with_existing_node_different_user_creates_new_node", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // First create a node under user1 user1 := app.state.CreateUserForTest("existing-user-1") @@ -1932,6 +1976,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -1955,14 +2000,14 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, }, validateCompleteResponse: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // User1's original node with nodeKey1 should STILL exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) require.True(t, found1, "user1's original node with nodeKey1 should still exist") @@ -1989,7 +2034,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates followup URLs to prevent errors { name: "interactive_workflow_malformed_followup_url", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -2001,12 +2046,12 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, }, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Test malformed followup URLs after getting initial AuthURL authURL := resp.AuthURL assert.Contains(t, authURL, "/register/") @@ -2043,7 +2088,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: System should handle concurrent interactive flows without conflicts { name: "interactive_workflow_concurrent_registrations", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -2055,8 +2100,8 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + machineKey: machineKey1.Public, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // This test validates concurrent interactive registration attempts assert.Contains(t, resp.AuthURL, "/register/") @@ -2097,6 +2142,7 @@ func TestAuthenticationFlows(t *testing.T) { // Collect results - at least one should succeed successCount := 0 + for range numConcurrent { select { case err := <-results: @@ -2119,7 +2165,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Interactive flow creates new nodes with new users; doesn't rotate existing nodes { name: "interactive_workflow_node_key_rotation", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper // Register initial node user := app.state.CreateUserForTest("rotation-user") @@ -2162,14 +2208,14 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, }, validateCompleteResponse: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // User1's original node with nodeKey1 should STILL exist oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public()) require.True(t, foundOld, "user1's original node with nodeKey1 should still exist") @@ -2196,7 +2242,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Defensive code handles nil hostinfo in interactive flow { name: "interactive_workflow_with_nil_hostinfo", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -2206,17 +2252,18 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, }, validateCompleteResponse: true, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Should handle nil hostinfo gracefully node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered despite nil hostinfo") + if found { // Should have some default hostname or handle nil gracefully hostname := node.Hostname() @@ -2231,7 +2278,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Failed registrations should clean up to prevent stale cache entries { name: "interactive_workflow_registration_cache_cleanup_on_error", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -2243,8 +2290,8 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + machineKey: machineKey1.Public, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Get initial AuthURL and extract registration ID authURL := resp.AuthURL assert.Contains(t, authURL, "/register/") @@ -2265,7 +2312,7 @@ func TestAuthenticationFlows(t *testing.T) { nil, "error-test-method", ) - assert.Error(t, err, "should fail with invalid user ID") + assert.Error(t, err, "should fail with invalid user ID") //nolint:testifylint // inside closure, uses assert pattern // Cache entry should still exist after auth error (for retry scenarios) _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) @@ -2284,7 +2331,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates that multiple pending registrations don't interfere with each other { name: "interactive_workflow_multiple_steps_same_node", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -2297,8 +2344,8 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + machineKey: machineKey1.Public, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper // Test multiple interactive registration attempts for the same node can coexist authURL1 := resp.AuthURL assert.Contains(t, authURL1, "/register/") @@ -2315,12 +2362,14 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL assert.Contains(t, authURL2, "/register/") // Both should have different registration IDs regID1, err1 := extractRegistrationIDFromAuthURL(authURL1) regID2, err2 := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err1) require.NoError(t, err2) assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") @@ -2328,6 +2377,7 @@ func TestAuthenticationFlows(t *testing.T) { // Both cache entries should exist simultaneously _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2342,7 +2392,7 @@ func TestAuthenticationFlows(t *testing.T) { // WHY: Validates that you can complete any pending registration, not just the first { name: "interactive_workflow_complete_second_of_multiple_pending", - setupFunc: func(t *testing.T, app *Headscale) (string, error) { + setupFunc: func(t *testing.T, app *Headscale) (string, error) { //nolint:thelper return "", nil }, request: func(_ string) tailcfg.RegisterRequest { @@ -2354,8 +2404,8 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, - validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + machineKey: machineKey1.Public, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { //nolint:thelper authURL1 := resp.AuthURL regID1, err := extractRegistrationIDFromAuthURL(authURL1) require.NoError(t, err) @@ -2371,6 +2421,7 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL regID2, err := extractRegistrationIDFromAuthURL(authURL2) require.NoError(t, err) @@ -2378,6 +2429,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify both exist _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2403,6 +2455,7 @@ func TestAuthenticationFlows(t *testing.T) { errorChan <- err return } + responseChan <- resp }() @@ -2430,6 +2483,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the node was created with the second registration's data node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered") + if found { assert.Equal(t, "pending-node-2", node.Hostname()) assert.Equal(t, "second-registration-user", node.User().Name()) @@ -2463,8 +2517,10 @@ func TestAuthenticationFlows(t *testing.T) { // Set up context with timeout for followup tests ctx := context.Background() + if req.Followup != "" { var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() } @@ -2516,7 +2572,7 @@ func TestAuthenticationFlows(t *testing.T) { } } -// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow +// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow. func runInteractiveWorkflowTest(t *testing.T, tt struct { name string setupFunc func(*testing.T, *Headscale) (string, error) @@ -2535,6 +2591,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { validateCompleteResponse bool }, app *Headscale, dynamicValue string, ) { + t.Helper() // Build initial request req := tt.request(dynamicValue) machineKey := tt.machineKey() @@ -2597,6 +2654,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { errorChan <- err return } + responseChan <- resp }() @@ -2650,25 +2708,29 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if responseToValidate == nil { responseToValidate = initialResp } + tt.validate(t, responseToValidate, app) } } -// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL +// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" + idx := strings.LastIndex(authURL, registerPrefix) if idx == -1 { - return "", fmt.Errorf("invalid AuthURL format: %s", authURL) + return "", fmt.Errorf("invalid AuthURL format: %s", authURL) //nolint:err113 } idStr := authURL[idx+len(registerPrefix):] + return types.RegistrationIDFromString(idStr) } -// validateCompleteRegistrationResponse performs comprehensive validation of a registration response -func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { +// validateCompleteRegistrationResponse performs comprehensive validation of a registration response. +func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, _ tailcfg.RegisterRequest) { + t.Helper() // Basic response validation require.NotNil(t, resp, "response should not be nil") require.True(t, resp.MachineAuthorized, "machine should be authorized") @@ -2681,7 +2743,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe // Additional validation can be added here as needed } -// Simple test to validate basic node creation and lookup +// Simple test to validate basic node creation and lookup. func TestNodeStoreLookup(t *testing.T) { app := createTestApp(t) @@ -2713,8 +2775,10 @@ func TestNodeStoreLookup(t *testing.T) { // Wait for node to be available in NodeStore var node types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + node, found = app.state.GetNodeByNodeKey(nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") @@ -2783,8 +2847,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Get the node ID var registeredNode types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available") @@ -2796,6 +2862,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify initial state: user1 has 2 nodes, user2 has 2 nodes user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID)) user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID)) + require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially") require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially") @@ -2876,6 +2943,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify new nodes were created for user1 with the same machine keys t.Logf("Verifying new nodes created for user1 from user2's machine keys...") + for i := 2; i < 4; i++ { node := nodes[i] // Should be able to find a node with user1 and this machine key (the new one) @@ -2899,7 +2967,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Expected behavior: // - User1's original node should STILL EXIST (expired) // - User2 should get a NEW node created (NOT transfer) -// - Both nodes share the same machine key (same physical device) +// - Both nodes share the same machine key (same physical device). func TestWebFlowReauthDifferentUser(t *testing.T) { machineKey := key.NewMachine() nodeKey1 := key.NewNode() @@ -3043,7 +3111,8 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Count nodes per user user1Nodes := 0 user2Nodes := 0 - for i := 0; i < allNodesSlice.Len(); i++ { + + for i := range allNodesSlice.Len() { n := allNodesSlice.At(i) if n.UserID().Get() == user1.ID { user1Nodes++ @@ -3060,7 +3129,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { }) } -// Helper function to create test app +// Helper function to create test app. func createTestApp(t *testing.T) *Headscale { t.Helper() @@ -3147,6 +3216,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { } t.Log("Step 1: Initial registration with pre-auth key") + initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public()) require.NoError(t, err, "initial registration should succeed") require.NotNil(t, initialResp) @@ -3172,6 +3242,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // - System reboots // The Tailscale client persists the pre-auth key in its state and sends it on every registration t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") + restartReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ AuthKey: pakNew.Key, // Same key, now marked as Used=true @@ -3188,10 +3259,12 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public()) // This is the assertion that currently FAILS in v0.27.0 - assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") //nolint:testifylint // intentionally uses assert to show bug + if err != nil { t.Logf("Error received (this is the bug): %v", err) t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially") + return // Stop here to show the bug clearly } @@ -3289,7 +3362,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) { } _, err = app.handleRegister(context.Background(), req, machineKey.Public()) - assert.Error(t, err, "expired pre-auth key should be rejected") + require.Error(t, err, "expired pre-auth key should be rejected") assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration") } diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index b1a14eba..9eeaf7e6 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -77,7 +77,7 @@ func (hsdb *HSDatabase) CreateAPIKey( Expiration: expiration, } - if err := hsdb.DB.Save(&key).Error; err != nil { + if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr return "", nil, fmt.Errorf("saving API key to database: %w", err) } @@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { keys := []types.APIKey{} - if err := hsdb.DB.Find(&keys).Error; err != nil { + + err := hsdb.DB.Find(&keys).Error + if err != nil { return nil, err } @@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { - if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil { + err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error + if err != nil { return err } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index b1a6d52b..0fc6bb68 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -53,6 +53,8 @@ type HSDatabase struct { // NewHeadscaleDatabase creates a new database connection and runs migrations. // It accepts the full configuration to allow migrations access to policy settings. +// +//nolint:gocyclo // complex database initialization with many migrations func NewHeadscaleDatabase( cfg *types.Config, regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], @@ -76,7 +78,7 @@ func NewHeadscaleDatabase( ID: "202501221827", Migrate: func(tx *gorm.DB) error { // Remove any invalid routes associated with a node that does not exist. - if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { + if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error if err != nil { return err @@ -84,14 +86,14 @@ func NewHeadscaleDatabase( } // Remove any invalid routes without a node_id. - if tx.Migrator().HasTable(&types.Route{}) { + if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations err := tx.Exec("delete from routes where node_id is null").Error if err != nil { return err } } - err := tx.AutoMigrate(&types.Route{}) + err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations if err != nil { return fmt.Errorf("automigrating types.Route: %w", err) } @@ -109,6 +111,7 @@ func NewHeadscaleDatabase( if err != nil { return fmt.Errorf("automigrating types.PreAuthKey: %w", err) } + err = tx.AutoMigrate(&types.Node{}) if err != nil { return fmt.Errorf("automigrating types.Node: %w", err) @@ -155,7 +158,8 @@ AND auth_key_id NOT IN ( nodeRoutes := map[uint64][]netip.Prefix{} - var routes []types.Route + var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations + err = tx.Find(&routes).Error if err != nil { return fmt.Errorf("fetching routes: %w", err) @@ -171,7 +175,7 @@ AND auth_key_id NOT IN ( tsaddr.SortPrefixes(routes) routes = slices.Compact(routes) - data, err := json.Marshal(routes) + data, _ := json.Marshal(routes) err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error if err != nil { @@ -180,7 +184,7 @@ AND auth_key_id NOT IN ( } // Drop the old table. - _ = tx.Migrator().DropTable(&types.Route{}) + _ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations return nil }, @@ -256,10 +260,13 @@ AND auth_key_id NOT IN ( // Check if routes table exists and drop it (should have been migrated already) var routesExists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists) if err == nil && routesExists { log.Info().Msg("dropping leftover routes table") - if err := tx.Exec("DROP TABLE routes").Error; err != nil { + + err := tx.Exec("DROP TABLE routes").Error + if err != nil { return fmt.Errorf("dropping routes table: %w", err) } } @@ -281,6 +288,7 @@ AND auth_key_id NOT IN ( for _, table := range tablesToRename { // Check if table exists before renaming var exists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists) if err != nil { return fmt.Errorf("checking if table %s exists: %w", table, err) @@ -291,7 +299,8 @@ AND auth_key_id NOT IN ( _ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error // Rename current table to _old - if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil { + err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error + if err != nil { return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err) } } @@ -365,7 +374,8 @@ AND auth_key_id NOT IN ( } for _, createSQL := range tableCreationSQL { - if err := tx.Exec(createSQL).Error; err != nil { + err := tx.Exec(createSQL).Error + if err != nil { return fmt.Errorf("creating new table: %w", err) } } @@ -394,7 +404,8 @@ AND auth_key_id NOT IN ( } for _, copySQL := range dataCopySQL { - if err := tx.Exec(copySQL).Error; err != nil { + err := tx.Exec(copySQL).Error + if err != nil { return fmt.Errorf("copying data: %w", err) } } @@ -417,14 +428,16 @@ AND auth_key_id NOT IN ( } for _, indexSQL := range indexes { - if err := tx.Exec(indexSQL).Error; err != nil { + err := tx.Exec(indexSQL).Error + if err != nil { return fmt.Errorf("creating index: %w", err) } } // Drop old tables only after everything succeeds for _, table := range tablesToRename { - if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil { + err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error + if err != nil { log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded") } } @@ -760,6 +773,7 @@ AND auth_key_id NOT IN ( // or else it blocks... sqlConn.SetMaxIdleConns(maxIdleConns) + sqlConn.SetMaxOpenConns(maxOpenConns) defer sqlConn.SetMaxIdleConns(1) defer sqlConn.SetMaxOpenConns(1) @@ -777,7 +791,7 @@ AND auth_key_id NOT IN ( }, } - if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { + if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("validating schema: %w", err) } } @@ -803,6 +817,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { switch cfg.Type { case types.DatabaseSqlite: dir := filepath.Dir(cfg.Sqlite.Path) + err := util.EnsureDir(dir) if err != nil { return nil, fmt.Errorf("creating directory for sqlite: %w", err) @@ -856,7 +871,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { Str("path", dbString). Msg("Opening database") - if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { + if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr if !sslEnabled { dbString += " sslmode=disable" } @@ -911,7 +926,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig // Get the current foreign key status var fkOriginallyEnabled int - if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { + if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr return fmt.Errorf("checking foreign key status: %w", err) } @@ -940,28 +955,31 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig if needsFKDisabled { // Disable foreign keys for this migration - if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { + err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error + if err != nil { return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err) } } else { // Ensure foreign keys are enabled for this migration - if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + err := dbConn.Exec("PRAGMA foreign_keys = ON").Error + if err != nil { return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err) } } // Run up to this specific migration (will only run the next pending migration) - if err := migrations.MigrateTo(migrationID); err != nil { + err := migrations.MigrateTo(migrationID) + if err != nil { return fmt.Errorf("running migration %s: %w", migrationID, err) } } - if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr return fmt.Errorf("restoring foreign keys: %w", err) } // Run the rest of the migrations - if err := migrations.Migrate(); err != nil { + if err := migrations.Migrate(); err != nil { //nolint:noinlineerr return err } @@ -979,16 +997,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig if err != nil { return err } + defer rows.Close() for rows.Next() { var violation constraintViolation - if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil { + + err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex) + if err != nil { return err } violatedConstraints = append(violatedConstraints, violation) } - _ = rows.Close() + + if err := rows.Err(); err != nil { //nolint:noinlineerr + return err + } if len(violatedConstraints) > 0 { for _, violation := range violatedConstraints { @@ -1003,7 +1027,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig } } else { // PostgreSQL can run all migrations in one block - no foreign key issues - if err := migrations.Migrate(); err != nil { + err := migrations.Migrate() + if err != nil { return err } } @@ -1014,6 +1039,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() + sqlDB, err := hsdb.DB.DB() if err != nil { return err @@ -1029,7 +1055,7 @@ func (hsdb *HSDatabase) Close() error { } if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { - db.Exec("VACUUM") + db.Exec("VACUUM") //nolint:errcheck,noctx } return db.Close() @@ -1038,12 +1064,14 @@ func (hsdb *HSDatabase) Close() error { func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error { rx := hsdb.DB.Begin() defer rx.Rollback() + return fn(rx) } func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) { rx := db.Begin() defer rx.Rollback() + ret, err := fn(rx) if err != nil { var no T @@ -1056,7 +1084,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) { func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error { tx := hsdb.DB.Begin() defer tx.Rollback() - if err := fn(tx); err != nil { + + err := fn(tx) + if err != nil { return err } @@ -1066,6 +1096,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error { func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) { tx := db.Begin() defer tx.Rollback() + ret, err := fn(tx) if err != nil { var no T diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3cd0d14e..3c687b39 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "os" "os/exec" @@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { // Verify api_keys data preservation var apiKeyCount int + err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error require.NoError(t, err) assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema") @@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { return err } - _, err = db.Exec(string(schemaContent)) + _, err = db.ExecContext(context.Background(), string(schemaContent)) return err } @@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { func requireConstraintFailed(t *testing.T, err error) { t.Helper() require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) } @@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) { }{ { name: "no-duplicate-username-if-no-oidc", - run: func(t *testing.T, db *gorm.DB) { + run: func(t *testing.T, db *gorm.DB) { //nolint:thelper _, err := CreateUser(db, types.User{Name: "user1"}) require.NoError(t, err) _, err = CreateUser(db, types.User{Name: "user1"}) @@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) { }, { name: "no-oidc-duplicate-username-and-id", - run: func(t *testing.T, db *gorm.DB) { + run: func(t *testing.T, db *gorm.DB) { //nolint:thelper user := types.User{ Model: gorm.Model{ID: 1}, Name: "user1", @@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) { }, { name: "no-oidc-duplicate-id", - run: func(t *testing.T, db *gorm.DB) { + run: func(t *testing.T, db *gorm.DB) { //nolint:thelper user := types.User{ Model: gorm.Model{ID: 1}, Name: "user1", @@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) { }, { name: "allow-duplicate-username-cli-then-oidc", - run: func(t *testing.T, db *gorm.DB) { + run: func(t *testing.T, db *gorm.DB) { //nolint:thelper _, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username require.NoError(t, err) @@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) { }, { name: "allow-duplicate-username-oidc-then-cli", - run: func(t *testing.T, db *gorm.DB) { + run: func(t *testing.T, db *gorm.DB) { //nolint:thelper user := types.User{ Name: "user1", ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, @@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { } // Construct the pg_restore command - cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) + cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) // Set the output streams cmd.Stdout = os.Stdout @@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { // skip already-applied migrations and only run new ones. func TestSQLiteAllTestdataMigrations(t *testing.T) { t.Parallel() + schemas, err := os.ReadDir("testdata/sqlite") require.NoError(t, err) diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go index d118b7fd..290a6310 100644 --- a/hscontrol/db/ephemeral_garbage_collector_test.go +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Basic deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex - var deletionWg sync.WaitGroup + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + deletionWg sync.WaitGroup + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionWg.Done() } @@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { go gc.Start() // Schedule several nodes for deletion with short expiry - const expiry = fifty - const numNodes = 100 + const ( + expiry = fifty + numNodes = 100 + ) // Set up wait group for expected deletions + deletionWg.Add(numNodes) for i := 1; i <= numNodes; i++ { - gc.Schedule(types.NodeID(i), expiry) + gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test } // Wait for all scheduled deletions to complete @@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // Schedule and immediately cancel to test that part of the code for i := numNodes + 1; i <= numNodes*2; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test gc.Schedule(nodeID, time.Hour) gc.Cancel(nodeID) } @@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // and then reschedules it with a shorter expiry, and verifies that the node is deleted only once. func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Start GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() - const shortExpiry = fifty - const longExpiry = 1 * time.Hour + const ( + shortExpiry = fifty + longExpiry = 1 * time.Hour + ) nodeID := types.NodeID(1) @@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // and verifies that the node is deleted only once. func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + deletionNotifier <- nodeID } // Start the GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() nodeID := types.NodeID(1) + const expiry = fifty // Schedule node for deletion @@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted. func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) { // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + nodeDeleted := make(chan struct{}) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() close(nodeDeleted) // Signal that deletion happened } @@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Use a WaitGroup to ensure the GC has started var startWg sync.WaitGroup startWg.Add(1) + go func() { startWg.Done() // Signal that the goroutine has started gc.Start() }() + startWg.Wait() // Wait for the GC to start // Close GC right away @@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Check no node was deleted deleteMutex.Lock() + nodesDeleted := len(deletedIDs) + deleteMutex.Unlock() assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close") @@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() } @@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { go gc.Start() // Number of concurrent scheduling goroutines - const numSchedulers = 10 - const nodesPerScheduler = 50 + const ( + numSchedulers = 10 + nodesPerScheduler = 50 + ) const closeAfterNodes = 25 // Close GC after this many nodes per scheduler @@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { case <-stopScheduling: return default: - nodeID := types.NodeID(baseNodeID + j + 1) - gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test + nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test + gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test atomic.AddInt64(&scheduledCount, 1) // Yield to other goroutines to introduce variability diff --git a/hscontrol/db/ip.go b/hscontrol/db/ip.go index c6a2b399..7402f473 100644 --- a/hscontrol/db/ip.go +++ b/hscontrol/db/ip.go @@ -17,7 +17,11 @@ import ( "tailscale.com/net/tsaddr" ) -var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip") +var ( + errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip") + errGeneratedIPNotInPrefix = errors.New("generated ip not in prefix") + errIPAllocatorNil = errors.New("ip allocator was nil") +) // IPAllocator is a singleton responsible for allocating // IP addresses for nodes and making sure the same @@ -62,8 +66,10 @@ func NewIPAllocator( strategy: strategy, } - var v4s []sql.NullString - var v6s []sql.NullString + var ( + v4s []sql.NullString + v6s []sql.NullString + ) if db != nil { err := db.Read(func(rx *gorm.DB) error { @@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) { i.mu.Lock() defer i.mu.Unlock() - var err error - var ret4 *netip.Addr - var ret6 *netip.Addr + var ( + err error + ret4 *netip.Addr + ret6 *netip.Addr + ) if i.prefix4 != nil { ret4, err = i.next(i.prev4, i.prefix4) if err != nil { return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err) } + i.prev4 = *ret4 } @@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) { if err != nil { return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err) } + i.prev6 = *ret6 } @@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip. } func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) { - var err error - var ip netip.Addr + var ( + err error + ip netip.Addr + ) switch i.strategy { case types.IPAllocationStrategySequential: @@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) { if !pfx.Contains(ip) { return netip.Addr{}, fmt.Errorf( - "generated ip(%s) not in prefix(%s)", + "%w: ip(%s) not in prefix(%s)", + errGeneratedIPNotInPrefix, ip.String(), pfx.String(), ) @@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool { // If a prefix type has been removed (IPv4 or IPv6), it // will remove the IPs in that family from the node. func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) { - var err error - var ret []string + var ( + err error + ret []string + ) + err = db.Write(func(tx *gorm.DB) error { if i == nil { - return errors.New("backfilling IPs: ip allocator was nil") + return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil) } log.Trace().Caller().Msgf("starting to backfill IPs") @@ -295,6 +311,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) { node.IPv4 = ret4 changed = true + ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname)) } @@ -307,6 +324,7 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) { node.IPv6 = ret6 changed = true + ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname)) } diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 7ba335e8..35798426 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -21,9 +21,7 @@ var mpp = func(pref string) *netip.Prefix { return &p } -var na = func(pref string) netip.Addr { - return netip.MustParseAddr(pref) -} +var na = netip.MustParseAddr var nap = func(pref string) *netip.Addr { n := na(pref) @@ -158,8 +156,10 @@ func TestIPAllocatorSequential(t *testing.T) { types.IPAllocationStrategySequential, ) - var got4s []netip.Addr - var got6s []netip.Addr + var ( + got4s []netip.Addr + got6s []netip.Addr + ) for range tt.getCount { got4, got6, err := alloc.Next() @@ -175,6 +175,7 @@ func TestIPAllocatorSequential(t *testing.T) { got6s = append(got6s, *got6) } } + if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" { t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff) } @@ -288,6 +289,7 @@ func TestBackfillIPAddresses(t *testing.T) { fullNodeP := func(i int) *types.Node { v4 := fmt.Sprintf("100.64.0.%d", i) v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i) + return &types.Node{ IPv4: nap(v4), IPv6: nap(v6), @@ -484,6 +486,7 @@ func TestBackfillIPAddresses(t *testing.T) { func TestIPAllocatorNextNoReservedIPs(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) + defer db.Close() alloc, err := NewIPAllocator( diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index d6066eba..91276f50 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -27,8 +27,14 @@ import ( const ( NodeGivenNameHashLength = 8 NodeGivenNameTrimSize = 2 + + // defaultTestNodePrefix is the default hostname prefix for nodes created in tests. + defaultTestNodePrefix = "testnode" ) +// ErrNodeNameNotUnique is returned when a node name is not unique. +var ErrNodeNameNotUnique = errors.New("node name is not unique") + var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var ( @@ -52,12 +58,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) // If at least one peer ID is given, only these peer nodes will be returned. func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx. + + err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). Where("id <> ?", nodeID). - Where(peerIDs).Find(&nodes).Error; err != nil { + Where(peerIDs).Find(&nodes).Error + if err != nil { return types.Nodes{}, err } @@ -76,11 +84,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) // or for the given nodes if at least one node ID is given as parameter. func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx. + + err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Where(nodeIDs).Find(&nodes).Error; err != nil { + Where(nodeIDs).Find(&nodes).Error + if err != nil { return nil, err } @@ -90,7 +100,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes := types.Nodes{} - if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil { + + err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error + if err != nil { return nil, err } @@ -222,7 +234,7 @@ func SetTags( return nil } -// SetTags takes a Node struct pointer and update the forced tags. +// SetApprovedRoutes takes a Node struct pointer and updates the approved routes. func SetApprovedRoutes( tx *gorm.DB, nodeID types.NodeID, @@ -254,7 +266,7 @@ func SetApprovedRoutes( return err } - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr return fmt.Errorf("updating approved routes: %w", err) } @@ -294,10 +306,10 @@ func RenameNode(tx *gorm.DB, } if count > 0 { - return errors.New("name is not unique") + return ErrNodeNameNotUnique } - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr return fmt.Errorf("renaming node in database: %w", err) } @@ -329,7 +341,8 @@ func DeleteNode(tx *gorm.DB, node *types.Node, ) error { // Unscoped causes the node to be fully removed from the database. - if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil { + err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error + if err != nil { return err } @@ -343,9 +356,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode( nodeID types.NodeID, ) error { return hsdb.Write(func(tx *gorm.DB) error { - if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil { + err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error + if err != nil { return err } + return nil }) } @@ -395,7 +410,8 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n // so we store the node.Expire and node.Nodekey that has been set when // adding it to the registrationCache if node.IPv4 != nil || node.IPv6 != nil { - if err := tx.Save(&node).Error; err != nil { + err := tx.Save(&node).Error + if err != nil { return nil, fmt.Errorf("registering existing node in database: %w", err) } @@ -431,7 +447,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n node.GivenName = givenName } - if err := tx.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr return nil, fmt.Errorf("saving node to database: %w", err) } @@ -656,7 +672,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) panic("CreateNodeForTest requires a valid user") } - nodeName := "testnode" + nodeName := defaultTestNodePrefix if len(hostname) > 0 && hostname[0] != "" { nodeName = hostname[0] } @@ -728,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname panic("CreateNodesForTest requires a valid user") } - prefix := "testnode" + prefix := defaultTestNodePrefix if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { prefix = hostnamePrefix[0] } @@ -751,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int panic("CreateRegisteredNodesForTest requires a valid user") } - prefix := "testnode" + prefix := defaultTestNodePrefix if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { prefix = hostnamePrefix[0] } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 7e00f9ca..a151baff 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -187,6 +187,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { suppliedName string randomSuffix bool } + tests := []struct { name string args args @@ -467,10 +468,10 @@ func TestAutoApproveRoutes(t *testing.T) { require.NoError(t, err) users, err := adb.ListUsers() - assert.NoError(t, err) + require.NoError(t, err) nodes, err := adb.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) pm, err := pmf(users, nodes.ViewSlice()) require.NoError(t, err) @@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes1) == 0 { expectedRoutes1 = nil } + if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes2) == 0 { expectedRoutes2 = nil } + if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -520,6 +523,7 @@ func TestAutoApproveRoutes(t *testing.T) { func TestEphemeralGarbageCollectorOrder(t *testing.T) { want := []types.NodeID{1, 3} got := []types.NodeID{} + var mu sync.Mutex deletionCount := make(chan struct{}, 10) @@ -527,6 +531,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) { e := NewEphemeralGarbageCollector(func(ni types.NodeID) { mu.Lock() defer mu.Unlock() + got = append(got, ni) deletionCount <- struct{}{} @@ -576,8 +581,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) { } func TestEphemeralGarbageCollectorLoads(t *testing.T) { - var got []types.NodeID - var mu sync.Mutex + var ( + got []types.NodeID + mu sync.Mutex + ) want := 1000 @@ -589,6 +596,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) { // Yield to other goroutines to introduce variability runtime.Gosched() + got = append(got, ni) atomic.AddInt64(&deletedCount, 1) @@ -616,9 +624,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) { } } -func generateRandomNumber(t *testing.T, max int64) int64 { +//nolint:unused +func generateRandomNumber(t *testing.T, maxVal int64) int64 { t.Helper() - maxB := big.NewInt(max) + + maxB := big.NewInt(maxVal) + n, err := rand.Int(rand.Reader, maxB) if err != nil { t.Fatalf("getting random number: %s", err) @@ -722,7 +733,7 @@ func TestNodeNaming(t *testing.T) { nodeInvalidHostname := types.Node{ MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data UserID: &user2.ID, RegisterMethod: util.RegisterMethodAuthKey, } @@ -746,12 +757,15 @@ func TestNodeNaming(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) if err != nil { return err } - _, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil) + + _, _ = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil) _, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil) + return err }) require.NoError(t, err) @@ -810,25 +824,25 @@ func TestNodeNaming(t *testing.T) { err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "test") }) - assert.ErrorContains(t, err, "name is not unique") + require.ErrorContains(t, err, "name is not unique") // Rename invalid chars err = db.Write(func(tx *gorm.DB) error { - return RenameNode(tx, nodes[2].ID, "我的电脑") + return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data }) - assert.ErrorContains(t, err, "invalid characters") + require.ErrorContains(t, err, "invalid characters") // Rename too short err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[3].ID, "a") }) - assert.ErrorContains(t, err, "at least 2 characters") + require.ErrorContains(t, err, "at least 2 characters") // Rename with emoji err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "hostname-with-💩") }) - assert.ErrorContains(t, err, "invalid characters") + require.ErrorContains(t, err, "invalid characters") // Rename with only emoji err = db.Write(func(tx *gorm.DB) error { @@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) { }, { name: "chinese_chars_with_dash_rejected", - newName: "server-北京-01", + newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data wantErr: "invalid characters", }, { name: "chinese_only_rejected", - newName: "我的电脑", + newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data wantErr: "invalid characters", }, { @@ -911,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) { }, { name: "mixed_chinese_emoji_rejected", - newName: "测试💻机器", + newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data wantErr: "invalid characters", }, { @@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err @@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err diff --git a/hscontrol/db/policy.go b/hscontrol/db/policy.go index bdc8af41..83bb4812 100644 --- a/hscontrol/db/policy.go +++ b/hscontrol/db/policy.go @@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) { Data: policy, } - if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil { + err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error + if err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 8196aa92..d88d8aee 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -138,7 +138,7 @@ func CreatePreAuthKey( Hash: hash, // Store hash } - if err := tx.Save(&key).Error; err != nil { + if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr return nil, fmt.Errorf("creating key in database: %w", err) } @@ -155,9 +155,7 @@ func CreatePreAuthKey( } func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) { - return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - return ListPreAuthKeys(rx) - }) + return Read(hsdb.DB, ListPreAuthKeys) } // ListPreAuthKeys returns all PreAuthKeys in the database. @@ -329,10 +327,11 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { } k.Used = true + return nil } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. +// ExpirePreAuthKey marks a PreAuthKey as expired. func ExpirePreAuthKey(tx *gorm.DB, id uint64) error { now := time.Now() return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go index d27977a4..b4c80795 100644 --- a/hscontrol/db/sqliteconfig/config.go +++ b/hscontrol/db/sqliteconfig/config.go @@ -362,7 +362,8 @@ func (c *Config) Validate() error { // ToURL builds a properly encoded SQLite connection string using _pragma parameters // compatible with modernc.org/sqlite driver. func (c *Config) ToURL() (string, error) { - if err := c.Validate(); err != nil { + err := c.Validate() + if err != nil { return "", fmt.Errorf("invalid config: %w", err) } @@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) { if c.BusyTimeout > 0 { pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout)) } + if c.JournalMode != "" { pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode)) } + if c.AutoVacuum != "" { pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum)) } + if c.WALAutocheckpoint >= 0 { pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint)) } + if c.Synchronous != "" { pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous)) } + if c.ForeignKeys { pragmas = append(pragmas, "foreign_keys=ON") } diff --git a/hscontrol/db/sqliteconfig/config_test.go b/hscontrol/db/sqliteconfig/config_test.go index 66955bb9..7829d9e9 100644 --- a/hscontrol/db/sqliteconfig/config_test.go +++ b/hscontrol/db/sqliteconfig/config_test.go @@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) { t.Errorf("Config.ToURL() error = %v", err) return } + if got != tt.want { t.Errorf("Config.ToURL() = %q, want %q", got, tt.want) } @@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) { Path: "", BusyTimeout: -1, } + _, err := config.ToURL() if err == nil { t.Error("Config.ToURL() with invalid config should return error") diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index bb54ea1e..00adaa64 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -1,6 +1,7 @@ package sqliteconfig import ( + "context" "database/sql" "path/filepath" "strings" @@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { defer db.Close() // Test connection - if err := db.Ping(); err != nil { + ctx := context.Background() + + err = db.PingContext(ctx) + if err != nil { t.Fatalf("Failed to ping database: %v", err) } @@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { for pragma, expectedValue := range tt.expected { t.Run("pragma_"+pragma, func(t *testing.T) { var actualValue any + query := "PRAGMA " + pragma - err := db.QueryRow(query).Scan(&actualValue) + + err := db.QueryRowContext(ctx, query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) } @@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } defer db.Close() + ctx := context.Background() + // Create test tables with foreign key relationship schema := ` CREATE TABLE parent ( @@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { ); ` - if _, err := db.Exec(schema); err != nil { + _, err = db.ExecContext(ctx, schema) + if err != nil { t.Fatalf("Failed to create schema: %v", err) } // Insert parent record - if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { + _, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')") + if err != nil { t.Fatalf("Failed to insert parent: %v", err) } // Test 1: Valid foreign key should work - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") + _, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") if err != nil { t.Fatalf("Valid foreign key insert failed: %v", err) } // Test 2: Invalid foreign key should fail - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") + _, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") if err == nil { t.Error("Expected foreign key constraint violation, but insert succeeded") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } // Test 3: Deleting referenced parent should fail - _, err = db.Exec("DELETE FROM parent WHERE id = 1") + _, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1") if err == nil { t.Error("Expected foreign key constraint violation when deleting referenced parent") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) { defer db.Close() var actualMode string - err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) + + err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 15a85cf8..080d080b 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL { t.Helper() ctx := t.Context() + srv, err := postgrestest.Start(ctx) if err != nil { t.Fatal(err) } + t.Cleanup(srv.Cleanup) u, err := srv.CreateDatabase(ctx) if err != nil { t.Fatal(err) } + t.Logf("created local postgres: %s", u) pu, _ := url.Parse(u) diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 30e28453..06a898d0 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -3,12 +3,19 @@ package db import ( "context" "encoding" + "errors" "fmt" "reflect" "gorm.io/gorm/schema" ) +var ( + errUnmarshalTextValue = errors.New("unmarshalling text value") + errUnsupportedType = errors.New("unsupported type") + errTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported") +) + // Got from https://github.com/xdg-go/strum/blob/main/types.go var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() @@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect if dbValue != nil { var bytes []byte + switch v := dbValue.(type) { case []byte: bytes = v case string: bytes = []byte(v) default: - return fmt.Errorf("unmarshalling text value: %#v", dbValue) + return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue) } if isTextUnmarshaler(fieldValue) { maybeInstantiatePtr(fieldValue) f := fieldValue.MethodByName("UnmarshalText") args := []reflect.Value{reflect.ValueOf(bytes)} + ret := f.Call(args) if !ret[0].IsNil() { - return decodingError(field.Name, ret[0].Interface().(error)) + if err, ok := ret[0].Interface().(error); ok { + return decodingError(field.Name, err) + } } // If the underlying field is to a pointer type, we need to @@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect return nil } else { - return fmt.Errorf("unsupported type: %T", fieldValue.Interface()) + return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface()) } } @@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec // always comparable, particularly when reflection is involved: // https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8 if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) { - return nil, nil + return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer } + b, err := v.MarshalText() if err != nil { return nil, err @@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec return string(b), nil default: - return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v) + return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v) } } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 6aff9ed1..36ea50e5 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -12,9 +12,11 @@ import ( ) var ( - ErrUserExists = errors.New("user already exists") - ErrUserNotFound = errors.New("user not found") - ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrUserExists = errors.New("user already exists") + ErrUserNotFound = errors.New("user not found") + ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs") + ErrUserNotUnique = errors.New("expected exactly one user") ) func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { @@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { // CreateUser creates a new User. Returns error if could not be created // or another user already exists. func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) { - if err := util.ValidateHostname(user.Name); err != nil { + err := util.ValidateHostname(user.Name) + if err != nil { return nil, err } - if err := tx.Create(&user).Error; err != nil { + + err = tx.Create(&user).Error + if err != nil { return nil, fmt.Errorf("creating user: %w", err) } @@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error { if err != nil { return err } + if len(nodes) > 0 { return ErrUserStillHasNodes } @@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error { if err != nil { return err } + for _, key := range keys { err = DestroyPreAuthKey(tx, key.ID) if err != nil { @@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user") // not exist or if another User exists with the new name. func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } - if err = util.ValidateHostname(newName); err != nil { + + if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr return err } @@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { // ListUsers gets all the existing users. func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { if len(where) > 1 { - return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where)) } var user *types.User @@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { } users := []types.User{} - if err := tx.Where(user).Find(&users).Error; err != nil { + + err := tx.Where(user).Find(&users).Error + if err != nil { return nil, err } @@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { } if len(users) != 1 { - return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users)) } return &users[0], nil diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 629b7be1..93200b95 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { overview := h.state.DebugOverviewJSON() + overviewJSON, err := json.MarshalIndent(overview, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(overviewJSON) + _, _ = w.Write(overviewJSON) } else { // Default to text/plain for backward compatibility overview := h.state.DebugOverview() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(overview)) + _, _ = w.Write([]byte(overview)) } })) // Configuration endpoint debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { config := h.state.DebugConfig() + configJSON, err := json.MarshalIndent(config, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(configJSON) + _, _ = w.Write(configJSON) })) // Policy endpoint @@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server { } else { w.Header().Set("Content-Type", "text/plain") } + w.WriteHeader(http.StatusOK) - w.Write([]byte(policy)) + _, _ = w.Write([]byte(policy)) })) // Filter rules endpoint @@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + filterJSON, err := json.MarshalIndent(filter, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(filterJSON) + _, _ = w.Write(filterJSON) })) // SSH policies endpoint debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sshPolicies := h.state.DebugSSHPolicies() + sshJSON, err := json.MarshalIndent(sshPolicies, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(sshJSON) + _, _ = w.Write(sshJSON) })) // DERP map endpoint @@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { derpInfo := h.state.DebugDERPJSON() + derpJSON, err := json.MarshalIndent(derpInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(derpJSON) + _, _ = w.Write(derpJSON) } else { // Default to text/plain for backward compatibility derpInfo := h.state.DebugDERPMap() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(derpInfo)) + _, _ = w.Write([]byte(derpInfo)) } })) @@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { nodeStoreNodes := h.state.DebugNodeStoreJSON() + nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(nodeStoreJSON) + _, _ = w.Write(nodeStoreJSON) } else { // Default to text/plain for backward compatibility nodeStoreInfo := h.state.DebugNodeStore() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(nodeStoreInfo)) + _, _ = w.Write([]byte(nodeStoreInfo)) } })) // Registration cache endpoint debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cacheInfo := h.state.DebugRegistrationCache() + cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(cacheJSON) + _, _ = w.Write(cacheJSON) })) // Routes endpoint @@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { routes := h.state.DebugRoutes() + routesJSON, err := json.MarshalIndent(routes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(routesJSON) + _, _ = w.Write(routesJSON) } else { // Default to text/plain for backward compatibility routes := h.state.DebugRoutesString() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(routes)) + _, _ = w.Write([]byte(routes)) } })) @@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { policyManagerInfo := h.state.DebugPolicyManagerJSON() + policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(policyManagerJSON) + _, _ = w.Write(policyManagerJSON) } else { // Default to text/plain for backward compatibility policyManagerInfo := h.state.DebugPolicyManager() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(policyManagerInfo)) + _, _ = w.Write([]byte(policyManagerInfo)) } })) @@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server { if res == nil { w.WriteHeader(http.StatusOK) - w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + _, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + return } @@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(resJSON) + _, _ = w.Write(resJSON) })) // Batcher endpoint @@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(batcherJSON) + _, _ = w.Write(batcherJSON) } else { // Default to text/plain for backward compatibility batcherInfo := h.debugBatcher() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(batcherInfo)) + _, _ = w.Write([]byte(batcherInfo)) } })) @@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string { activeConnections: info.ActiveConnections, }) totalNodes++ + if info.Connected { connectedCount++ } @@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string { activeConnections: 0, }) totalNodes++ + if connected { connectedCount++ } + return true }) } @@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { ActiveConnections: 0, } info.TotalNodes++ + return true }) } diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 42d74abe..3dc06d07 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -28,11 +28,14 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) { return nil, err } defer derpFile.Close() + var derpMap tailcfg.DERPMap + b, err := io.ReadAll(derpFile) if err != nil { return nil, err } + err = yaml.Unmarshal(b, &derpMap) return &derpMap, err @@ -57,12 +60,14 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { } defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } var derpMap tailcfg.DERPMap + err = json.Unmarshal(body, &derpMap) return &derpMap, err @@ -134,6 +139,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) { for id := range dm.Regions { ids = append(ids, id) } + slices.Sort(ids) for _, id := range ids { @@ -160,16 +166,18 @@ func derpRandom() *rand.Rand { derpRandomOnce.Do(func() { seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String()) - rnd := rand.New(rand.NewSource(0)) - rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) + rnd := rand.New(rand.NewSource(0)) //nolint:gosec // weak random is fine for DERP scrambling + rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) //nolint:gosec // safe conversion derpRandomInst = rnd }) + return derpRandomInst } func resetDerpRandomForTesting() { derpRandomMu.Lock() defer derpRandomMu.Unlock() + derpRandomOnce = sync.Once{} derpRandomInst = nil } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go index 91d605a6..445c1044 100644 --- a/hscontrol/derp/derp_test.go +++ b/hscontrol/derp/derp_test.go @@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { viper.Set("dns.base_domain", tt.baseDomain) + defer viper.Reset() + resetDerpRandomForTesting() testMap := tt.derpMap.View().AsStruct() diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 8f1545ba..fdca651a 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -75,9 +75,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { if err != nil { return tailcfg.DERPRegion{}, err } - var host string - var port int - var portStr string + + var ( + host string + port int + portStr string + ) // Extract hostname and port from URL host, portStr, err = net.SplitHostPort(serverURL.Host) @@ -98,12 +101,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { // If debug flag is set, resolve hostname to IP address if debugUseDERPIP { - ips, err := net.LookupIP(host) + ips, err := new(net.Resolver).LookupIPAddr(context.Background(), host) if err != nil { log.Error().Caller().Err(err).Msgf("failed to resolve DERP hostname %s to IP, using hostname", host) } else if len(ips) > 0 { // Use the first IP address - ipStr := ips[0].String() + ipStr := ips[0].IP.String() log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: resolved %s to %s", host, ipStr) host = ipStr } @@ -130,10 +133,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { if err != nil { return tailcfg.DERPRegion{}, err } + portSTUN, err := strconv.Atoi(portSTUNStr) if err != nil { return tailcfg.DERPRegion{}, err } + localDERPregion.Nodes[0].STUNPort = portSTUN log.Info().Caller().Msgf("derp region: %+v", localDERPregion) @@ -155,8 +160,10 @@ func (d *DERPServer) DERPHandler( Caller(). Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.") } + writer.Header().Set("Content-Type", "text/plain") writer.WriteHeader(http.StatusUpgradeRequired) + _, err := writer.Write([]byte("DERP requires connection upgrade")) if err != nil { log.Error(). @@ -206,6 +213,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques return } defer websocketConn.Close(websocket.StatusInternalError, "closing") + if websocketConn.Subprotocol() != "derp" { websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol") @@ -225,6 +233,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) { log.Error().Caller().Msg("derp requires Hijacker interface from Gin") writer.Header().Set("Content-Type", "text/plain") writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("HTTP does not support general TCP support")) if err != nil { log.Error(). @@ -241,6 +250,7 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) { log.Error().Caller().Err(err).Msgf("hijack failed") writer.Header().Set("Content-Type", "text/plain") writer.WriteHeader(http.StatusInternalServerError) + _, err = writer.Write([]byte("HTTP does not support general TCP support")) if err != nil { log.Error(). @@ -281,6 +291,7 @@ func DERPProbeHandler( writer.WriteHeader(http.StatusOK) default: writer.WriteHeader(http.StatusMethodNotAllowed) + _, err := writer.Write([]byte("bogus probe method")) if err != nil { log.Error(). @@ -310,9 +321,11 @@ func DERPBootstrapDNSHandler( resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute) defer cancel() + var resolver net.Resolver - for _, region := range derpMap.Regions().All() { - for _, node := range region.Nodes().All() { // we don't care if we override some nodes + + for _, region := range derpMap.Regions().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator + for _, node := range region.Nodes().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName()) if err != nil { log.Trace(). @@ -322,11 +335,14 @@ func DERPBootstrapDNSHandler( continue } + dnsEntries[node.HostName()] = addrs } } + writer.Header().Set("Content-Type", "application/json") writer.WriteHeader(http.StatusOK) + err := json.NewEncoder(writer).Encode(dnsEntries) if err != nil { log.Error(). @@ -339,7 +355,7 @@ func DERPBootstrapDNSHandler( // ServeSTUN starts a STUN server on the configured addr. func (d *DERPServer) ServeSTUN() { - packetConn, err := net.ListenPacket("udp", d.cfg.STUNAddr) + packetConn, err := new(net.ListenConfig).ListenPacket(context.Background(), "udp", d.cfg.STUNAddr) if err != nil { log.Fatal().Msgf("failed to open STUN listener: %v", err) } @@ -350,16 +366,18 @@ func (d *DERPServer) ServeSTUN() { if !ok { log.Fatal().Msg("stun listener is not a UDP listener") } + serverSTUNListener(context.Background(), udpConn) } func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) { - var buf [64 << 10]byte var ( + buf [64 << 10]byte bytesRead int udpAddr *net.UDPAddr err error ) + for { bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:]) if err != nil { @@ -380,12 +398,14 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) { } log.Trace().Caller().Msgf("stun request from %v", udpAddr) + pkt := buf[:bytesRead] if !stun.Is(pkt) { log.Trace().Caller().Msgf("udp packet is not stun") continue } + txid, err := stun.ParseBindingRequest(pkt) if err != nil { log.Trace().Caller().Err(err).Msgf("stun parse error") @@ -394,7 +414,8 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) { } addr, _ := netip.AddrFromSlice(udpAddr.IP) - res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port))) + res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port))) //nolint:gosec // port is always <=65535 + _, err = packetConn.WriteTo(res, udpAddr) if err != nil { log.Trace().Caller().Err(err).Msgf("issue writing to UDP") @@ -416,7 +437,9 @@ type DERPVerifyTransport struct { func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) { buf := new(bytes.Buffer) - if err := t.handleVerifyRequest(req, buf); err != nil { + + err := t.handleVerifyRequest(req, buf) + if err != nil { log.Error().Caller().Err(err).Msg("failed to handle client verify request") return nil, err diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 82b3078b..525af62b 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "os" "sync" @@ -15,6 +16,9 @@ import ( "tailscale.com/util/set" ) +// ErrPathIsDirectory is returned when a directory path is provided where a file is expected. +var ErrPathIsDirectory = errors.New("path is a directory, only file is supported") + type ExtraRecordsMan struct { mu sync.RWMutex records set.Set[tailcfg.DNSRecord] @@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) { } if fi.IsDir() { - return nil, fmt.Errorf("path is a directory, only file is supported: %s", path) + return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path) } records, hash, err := readExtraRecordsFromPath(path) @@ -85,19 +89,22 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher event channel closing") return } + switch event.Op { case fsnotify.Create, fsnotify.Write, fsnotify.Chmod: log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event") + if event.Name != e.path { continue } + e.updateRecords() // If a file is removed or renamed, fsnotify will loose track of it // and not watch it. We will therefore attempt to re-add it with a backoff. case fsnotify.Remove, fsnotify.Rename: _, err := backoff.Retry(context.Background(), func() (struct{}, error) { - if _, err := os.Stat(e.path); err != nil { + if _, err := os.Stat(e.path); err != nil { //nolint:noinlineerr return struct{}{}, err } @@ -123,6 +130,7 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher error channel closing") return } + log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err) } } @@ -165,6 +173,7 @@ func (e *ExtraRecordsMan) updateRecords() { e.hashes[e.path] = newHash log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len()) + e.updateCh <- e.records.Slice() } @@ -183,6 +192,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error } var records []tailcfg.DNSRecord + err = json.Unmarshal(b, &records) if err != nil { return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err) diff --git a/hscontrol/grpcv1_test.go b/hscontrol/grpcv1_test.go index 4cf5b7d4..626204ec 100644 --- a/hscontrol/grpcv1_test.go +++ b/hscontrol/grpcv1_test.go @@ -17,6 +17,7 @@ func Test_validateTag(t *testing.T) { type args struct { tag string } + tests := []struct { name string args args @@ -45,7 +46,8 @@ func Test_validateTag(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr { + err := validateTag(tt.args.tag) + if (err != nil) != tt.wantErr { t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 9797d271..72891344 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -20,7 +20,7 @@ import ( ) const ( - // The CapabilityVersion is used by Tailscale clients to indicate + // NoiseCapabilityVersion is used by Tailscale clients to indicate // their codebase version. Tailscale clients can communicate over TS2021 // from CapabilityVersion 28, but we only have good support for it // since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port). @@ -56,7 +56,7 @@ type HTTPError struct { func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) } func (e HTTPError) Unwrap() error { return e.Err } -// Error returns an HTTPError containing the given information. +// NewHTTPError returns an HTTPError containing the given information. func NewHTTPError(code int, msg string, err error) HTTPError { return HTTPError{Code: code, Msg: msg, Err: err} } @@ -92,7 +92,7 @@ func (h *Headscale) handleVerifyRequest( } var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest - if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { + if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { //nolint:noinlineerr return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("parsing DERP client request: %w", err)) } @@ -155,7 +155,11 @@ func (h *Headscale) KeyHandler( } writer.Header().Set("Content-Type", "application/json") - json.NewEncoder(writer).Encode(resp) + + err := json.NewEncoder(writer).Encode(resp) + if err != nil { + log.Error().Err(err).Msg("failed to encode public key response") + } return } @@ -180,8 +184,12 @@ func (h *Headscale) HealthHandler( res.Status = "fail" } - json.NewEncoder(writer).Encode(res) + encErr := json.NewEncoder(writer).Encode(res) + if encErr != nil { + log.Error().Err(encErr).Msg("failed to encode health response") + } } + err := h.state.PingDB(req.Context()) if err != nil { respond(err) @@ -218,6 +226,7 @@ func (h *Headscale) VersionHandler( writer.WriteHeader(http.StatusOK) versionInfo := types.GetVersionInfo() + err := json.NewEncoder(writer).Encode(versionInfo) if err != nil { log.Error(). @@ -244,7 +253,7 @@ func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { registrationId.String()) } -// RegisterWebAPI shows a simple message in the browser to point to the CLI +// RegisterHandler shows a simple message in the browser to point to the CLI // Listens in /register/:registration_id. // // This is not part of the Tailscale control API, as we could send whatever URL @@ -267,7 +276,11 @@ func (a *AuthProviderWeb) RegisterHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) + + _, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) + if err != nil { + log.Error().Err(err).Msg("failed to write register response") + } } func FaviconHandler(writer http.ResponseWriter, req *http.Request) { diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 1652b213..1f092a9c 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -16,6 +16,14 @@ import ( "tailscale.com/tailcfg" ) +// Mapper errors. +var ( + ErrInvalidNodeID = errors.New("invalid nodeID") + ErrMapperNil = errors.New("mapper is nil") + ErrNodeConnectionNil = errors.New("nodeConnection is nil") + ErrNodeNotFoundMapper = errors.New("node not found") +) + var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "headscale", Name: "mapresponse_generated_total", @@ -81,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t } if nodeID == 0 { - return nil, fmt.Errorf("invalid nodeID: %d", nodeID) + return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID) } if mapper == nil { - return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) + return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID) } // Handle self-only responses @@ -136,7 +144,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { - return errors.New("nodeConnection is nil") + return ErrNodeConnectionNil } nodeID := nc.nodeID() diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index d4d5dd9b..1be722d4 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -2,6 +2,7 @@ package mapper import ( "crypto/rand" + "encoding/hex" "errors" "fmt" "sync" @@ -18,7 +19,13 @@ import ( "tailscale.com/types/ptr" ) -var errConnectionClosed = errors.New("connection channel already closed") +// LockFreeBatcher errors. +var ( + errConnectionClosed = errors.New("connection channel already closed") + ErrInitialMapSendTimeout = errors.New("sending initial map: timeout") + ErrBatcherShuttingDown = errors.New("batcher shutting down") + ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)") +) // LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. type LockFreeBatcher struct { @@ -81,6 +88,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse if err != nil { nlog.Error().Err(err).Msg("initial map generation failed") nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("generating initial map for node %d: %w", id, err) } @@ -90,11 +98,12 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse case c <- initialMap: // Success case <-time.After(5 * time.Second): //nolint:mnd - nlog.Error().Err(errors.New("timeout")).Msg("initial map send timeout") //nolint:err113 - nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd - Msg("initial map send timed out because channel was blocked or receiver not ready") + nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout") + nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd + Msg("initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) - return fmt.Errorf("sending initial map to node %d: timeout", id) + + return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id) } // Update connection status @@ -135,6 +144,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo nlog.Debug().Caller(). Int("active.connections", nodeConn.getActiveConnectionCount()). Msg("node connection removed but keeping online, other connections remain") + return true // Node still has active connections } @@ -219,10 +229,12 @@ func (b *LockFreeBatcher) worker(workerID int) { // This is used for synchronous map generation. if w.resultCh != nil { var result workResult + if nc, exists := b.nodes.Load(w.nodeID); exists { var err error result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + result.err = err if result.err != nil { b.workErrors.Add(1) @@ -235,7 +247,7 @@ func (b *LockFreeBatcher) worker(workerID int) { nc.updateSentPeers(result.mapResponse) } } else { - result.err = fmt.Errorf("node %d not found", w.nodeID) + result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID) b.workErrors.Add(1) wlog.Error().Err(result.err). @@ -402,6 +414,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() { } } } + return true }) @@ -454,6 +467,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { if nodeConn.hasActiveConnections() { ret.Store(id, true) } + return true }) @@ -469,6 +483,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret.Store(id, false) } } + return true }) @@ -488,7 +503,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang case result := <-resultCh: return result.mapResponse, result.err case <-b.done: - return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id) + return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id) } } @@ -523,8 +538,9 @@ type multiChannelNodeConn struct { // generateConnectionID generates a unique connection identifier. func generateConnectionID() string { bytes := make([]byte, 8) - rand.Read(bytes) - return fmt.Sprintf("%x", bytes) + _, _ = rand.Read(bytes) + + return hex.EncodeToString(bytes) } // newMultiChannelNodeConn creates a new multi-channel node connection. @@ -557,7 +573,9 @@ func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") mc.mutex.Lock() + mutexWaitDur := time.Since(mutexWaitStart) + defer mc.mutex.Unlock() mc.connections = append(mc.connections, entry) @@ -579,9 +597,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)). Int("remaining_connections", len(mc.connections)). Msg("successfully removed connection") + return true } } + return false } @@ -615,6 +635,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { // This is not an error - the node will receive a full map when it reconnects mc.log.Debug().Caller(). Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + return nil // Return success instead of error } @@ -623,7 +644,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: broadcasting to all connections") var lastErr error + successCount := 0 + var failedConnections []int // Track failed connections for removal // Send to all connections @@ -632,8 +655,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i). Msg("send: attempting to send to connection") - if err := conn.send(data); err != nil { + err := conn.send(data) + if err != nil { lastErr = err + failedConnections = append(failedConnections, i) mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)). Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i). @@ -695,7 +720,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { case <-time.After(50 * time.Millisecond): // Connection is likely stale - client isn't reading from channel // This catches the case where Docker containers are killed but channels remain open - return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id) + return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout) } } @@ -805,6 +830,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { Connected: connected, ActiveConnections: activeConnCount, } + return true }) @@ -819,6 +845,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { ActiveConnections: 0, } } + return true }) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 70d5e377..9e544633 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -35,6 +35,7 @@ type batcherTestCase struct { // that would normally be sent by poll.go in production. type testBatcherWrapper struct { Batcher + state *state.State } @@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe } // Finally remove from the real batcher - removed := t.Batcher.RemoveNode(id, c) - if !removed { - return false - } - - return true + return t.Batcher.RemoveNode(id, c) } // wrapBatcherForTest wraps a batcher with test-specific behavior. @@ -129,8 +125,6 @@ const ( SMALL_BUFFER_SIZE = 3 TINY_BUFFER_SIZE = 1 // For maximum contention LARGE_BUFFER_SIZE = 200 - - reservedResponseHeaderSize = 4 ) // TestData contains all test entities created for a test scenario. @@ -241,8 +235,8 @@ func setupBatcherWithTestData( } derpMap, err := derp.GetDERPMap(cfg.DERP) - assert.NoError(t, err) - assert.NotNil(t, derpMap) + require.NoError(t, err) + require.NotNil(t, derpMap) state.SetDERPMap(derpMap) @@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) { } // getStats returns a copy of the statistics for a node. +// +//nolint:unused func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats { ut.mu.RLock() defer ut.mu.RUnlock() @@ -386,16 +382,14 @@ type UpdateInfo struct { } // parseUpdateAndAnalyze parses an update and returns detailed information. -func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) { - info := UpdateInfo{ +func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo { + return UpdateInfo{ PeerCount: len(resp.Peers), PatchCount: len(resp.PeersChangedPatch), IsFull: len(resp.Peers) > 0, IsPatch: len(resp.PeersChangedPatch) > 0, IsDERP: resp.DERPMap != nil, } - - return info, nil } // start begins consuming updates from the node's channel and tracking stats. @@ -417,7 +411,8 @@ func (n *node) start() { atomic.AddInt64(&n.updateCount, 1) // Parse update and track detailed stats - if info, err := parseUpdateAndAnalyze(data); err == nil { + info := parseUpdateAndAnalyze(data) + { // Track update types if info.IsFull { atomic.AddInt64(&n.fullCount, 1) @@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) // Wait for connection to be established assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { for i := range allNodes { node := &allNodes[i] - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullUpdate()) @@ -676,6 +671,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { connectedCount := 0 + for i := range allNodes { node := &allNodes[i] @@ -693,6 +689,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { }, 5*time.Minute, 5*time.Second, "waiting for full connectivity") t.Logf("✅ All nodes achieved full connectivity!") + totalTime := time.Since(startTime) // Disconnect all nodes @@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) { defer cleanup() batcher := testData.Batcher - tn := testData.Nodes[0] - tn2 := testData.Nodes[1] + tn := &testData.Nodes[0] + tn2 := &testData.Nodes[1] // Test AddNode with real node ID - batcher.AddNode(tn.n.ID, tn.ch, 100) + _ = batcher.AddNode(tn.n.ID, tn.ch, 100) if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") @@ -842,10 +839,10 @@ func TestBatcherBasicOperations(t *testing.T) { } // Drain any initial messages from first node - drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) + drainChannelTimeout(tn.ch, 100*time.Millisecond) // Add the second node and verify update message - batcher.AddNode(tn2.n.ID, tn2.ch, 100) + _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. @@ -911,18 +908,14 @@ func TestBatcherBasicOperations(t *testing.T) { } } -func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { - count := 0 - +func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) { timer := time.NewTimer(timeout) defer timer.Stop() for { select { - case data := <-ch: - count++ - // Optional: add debug output if needed - _ = data + case <-ch: + // Drain message case <-timer.C: return } @@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1131,6 +1124,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // even when real node updates are being processed, ensuring no race conditions // occur during channel replacement with actual workload. func XTestBatcherChannelClosingRace(t *testing.T) { + t.Helper() + for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { // Create test environment with real database and nodes @@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { defer cleanup() batcher := testData.Batcher - testNode := testData.Nodes[0] + testNode := &testData.Nodes[0] var ( channelIssues int @@ -1154,7 +1149,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, 1) wg.Go(func() { - batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) }) // Add real work during connection chaos @@ -1167,7 +1162,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Go(func() { runtime.Gosched() // Yield to introduce timing variability - batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) + + _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }) // Remove second connection @@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { defer cleanup() batcher := testData.Batcher - testNode := testData.Nodes[0] + testNode := &testData.Nodes[0] var ( panics int @@ -1258,7 +1254,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) batcher.AddWork(change.DERPMap()) // Consumer goroutine to validate data and detect channel issues @@ -1308,6 +1304,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { for range i % 3 { runtime.Gosched() // Introduce timing variability } + batcher.RemoveNode(testNode.n.ID, ch) // Yield to allow workers to process and close channels @@ -1350,6 +1347,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // real node data. The test validates that stable clients continue to function // normally and receive proper updates despite the connection churn from other clients, // ensuring system stability under concurrent load. +// +//nolint:gocyclo // complex concurrent test scenario func TestBatcherConcurrentClients(t *testing.T) { if testing.Short() { t.Skip("Skipping concurrent client test in short mode") @@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) { stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) - for _, node := range stableNodes { + for i := range stableNodes { + node := &stableNodes[i] ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1391,6 +1391,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, reason := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1427,7 +1428,9 @@ func TestBatcherConcurrentClients(t *testing.T) { // Connection churn cycles - rapidly connect/disconnect to test concurrency safety for i := range numCycles { - for _, node := range churningNodes { + for j := range churningNodes { + node := &churningNodes[j] + wg.Add(2) // Connect churning node @@ -1448,10 +1451,12 @@ func TestBatcherConcurrentClients(t *testing.T) { ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() - batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { @@ -1462,6 +1467,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, _ := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1494,6 +1500,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for range i % 5 { runtime.Gosched() // Introduce timing variability } + churningChannelsMutex.Lock() ch, exists := churningChannels[nodeID] @@ -1519,7 +1526,7 @@ func TestBatcherConcurrentClients(t *testing.T) { if i%7 == 0 && len(allNodes) > 0 { // Node-specific changes using real nodes - node := allNodes[i%len(allNodes)] + node := &allNodes[i%len(allNodes)] // Use a valid expiry time for testing since test nodes don't have expiry set testExpiry := time.Now().Add(24 * time.Hour) batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry)) @@ -1567,7 +1574,8 @@ func TestBatcherConcurrentClients(t *testing.T) { t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls", expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork) - for _, node := range stableNodes { + for i := range stableNodes { + node := &stableNodes[i] if stats, exists := allStats[node.n.ID]; exists { stableUpdateCount += stats.TotalUpdates t.Logf("Stable node %d: %d updates", @@ -1580,7 +1588,8 @@ func TestBatcherConcurrentClients(t *testing.T) { } } - for _, node := range churningNodes { + for i := range churningNodes { + node := &churningNodes[i] if stats, exists := allStats[node.n.ID]; exists { churningUpdateCount += stats.TotalUpdates } @@ -1605,7 +1614,8 @@ func TestBatcherConcurrentClients(t *testing.T) { } // Verify all stable clients are still functional - for _, node := range stableNodes { + for i := range stableNodes { + node := &stableNodes[i] if !batcher.IsConnected(node.n.ID) { t.Errorf("Stable node %d lost connection during racing", node.n.ID) } @@ -1623,6 +1633,8 @@ func TestBatcherConcurrentClients(t *testing.T) { // It validates that the system remains stable with no deadlocks, panics, or // missed updates under sustained high load. The test uses real node data to // generate authentic update scenarios and tracks comprehensive statistics. +// +//nolint:gocyclo,thelper // complex scalability test scenario func XTestBatcherScalability(t *testing.T) { if testing.Short() { t.Skip("Skipping scalability test in short mode") @@ -1651,7 +1663,7 @@ func XTestBatcherScalability(t *testing.T) { description string } - var testCases []testCase + testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes)) // Generate all combinations of the test matrix for _, nodeCount := range nodes { @@ -1762,7 +1774,8 @@ func XTestBatcherScalability(t *testing.T) { for i := range testNodes { node := &testNodes[i] - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true @@ -1824,7 +1837,8 @@ func XTestBatcherScalability(t *testing.T) { } // Connection/disconnection cycles for subset of nodes - for i, node := range chaosNodes { + for i := range chaosNodes { + node := &chaosNodes[i] // Only add work if this is connection chaos or mixed if tc.chaosType == "connection" || tc.chaosType == "mixed" { wg.Add(2) @@ -1878,6 +1892,7 @@ func XTestBatcherScalability(t *testing.T) { channel, tailcfg.CapabilityVersion(100), ) + connectedNodesMutex.Lock() connectedNodes[nodeID] = true @@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) { t.Logf("Created %d nodes in database", len(allNodes)) // Connect nodes one at a time and wait for each to be connected - for i, node := range allNodes { - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + for i := range allNodes { + node := &allNodes[i] + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Wait for node to be connected @@ -2157,7 +2173,8 @@ func TestBatcherFullPeerUpdates(t *testing.T) { }, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect") // Check how many peers each node should see - for i, node := range allNodes { + for i := range allNodes { + node := &allNodes[i] peers := testData.State.ListPeers(node.n.ID) t.Logf("Node %d should see %d peers from state", i, peers.Len()) } @@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 1: Connect all nodes initially t.Logf("Phase 1: Connecting all nodes...") - for i, node := range allNodes { + + for i := range allNodes { + node := &allNodes[i] + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add node %d: %v", i, err) @@ -2302,16 +2322,21 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) t.Logf("Phase 2: Rapid disconnect all nodes...") - for i, node := range allNodes { + + for i := range allNodes { + node := &allNodes[i] removed := batcher.RemoveNode(node.n.ID, node.ch) t.Logf("Node %d RemoveNode result: %t", i, removed) } // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) t.Logf("Phase 3: Rapid reconnect with new channels...") + newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) - for i, node := range allNodes { + for i := range allNodes { + node := &allNodes[i] newChannels[i] = make(chan *tailcfg.MapResponse, 10) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) if err != nil { t.Errorf("Failed to reconnect node %d: %v", i, err) @@ -2334,7 +2359,8 @@ func TestBatcherRapidReconnection(t *testing.T) { debugInfo := debugBatcher.Debug() disconnectedCount := 0 - for i, node := range allNodes { + for i := range allNodes { + node := &allNodes[i] if info, exists := debugInfo[node.n.ID]; exists { t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info) @@ -2342,11 +2368,13 @@ func TestBatcherRapidReconnection(t *testing.T) { if infoMap, ok := info.(map[string]any); ok { if connected, ok := infoMap["connected"].(bool); ok && !connected { disconnectedCount++ + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) } } } else { disconnectedCount++ + t.Logf("Node %d missing from debug info entirely", i) } @@ -2381,6 +2409,7 @@ func TestBatcherRapidReconnection(t *testing.T) { case update := <-newChannels[i]: if update != nil { receivedCount++ + t.Logf("Node %d received update successfully", i) } case <-timeout: @@ -2399,6 +2428,7 @@ func TestBatcherRapidReconnection(t *testing.T) { } } +//nolint:gocyclo // complex multi-connection test scenario func TestBatcherMultiConnection(t *testing.T) { for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { @@ -2406,13 +2436,14 @@ func TestBatcherMultiConnection(t *testing.T) { defer cleanup() batcher := testData.Batcher - node1 := testData.Nodes[0] - node2 := testData.Nodes[1] + node1 := &testData.Nodes[0] + node2 := &testData.Nodes[1] t.Logf("=== MULTI-CONNECTION TEST ===") // Phase 1: Connect first node with initial connection t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add node1: %v", err) @@ -2432,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 2: Add second connection for node1 (multi-connection scenario) t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add second connection for node1: %v", err) @@ -2443,7 +2476,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 3: Add third connection for node1 t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add third connection for node1: %v", err) @@ -2454,6 +2489,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 4: Verify debug status shows correct connection count t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := batcher.(interface { Debug() map[types.NodeID]any }); ok { @@ -2461,6 +2497,7 @@ func TestBatcherMultiConnection(t *testing.T) { if info, exists := debugInfo[node1.n.ID]; exists { t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { if activeConnections, ok := infoMap["active_connections"].(int); ok { if activeConnections != 3 { @@ -2469,6 +2506,7 @@ func TestBatcherMultiConnection(t *testing.T) { t.Logf("SUCCESS: Node1 correctly shows 3 active connections") } } + if connected, ok := infoMap["connected"].(bool); ok && !connected { t.Errorf("Node1 should show as connected with 3 active connections") } diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index c666ff24..58848883 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -1,7 +1,6 @@ package mapper import ( - "errors" "net/netip" "sort" "time" @@ -36,6 +35,7 @@ const ( // NewMapResponseBuilder creates a new builder with basic fields set. func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() + return &MapResponseBuilder{ resp: &tailcfg.MapResponse{ KeepAlive: false, @@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { nv, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFoundMapper) return b } @@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, } + return b } @@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFoundMapper) return b } @@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFoundMapper) return b } @@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFoundMapper) return b } @@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFoundMapper) return b } @@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - return nil, errors.New("node not found") + return nil, ErrNodeNotFoundMapper } // Get unreduced matchers for peer relationship determination. @@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) // WithPeersRemoved adds removed peer IDs. func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { - var tailscaleIDs []tailcfg.NodeID + tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs)) for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } + b.resp.PeersRemoved = tailscaleIDs return b } -// Build finalizes the response and returns marshaled bytes +// Build finalizes the response and returns marshaled bytes. func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { if len(b.errs) > 0 { return nil, multierr.New(b.errs...) } + if debugDumpMapResponsePath != "" { writeDebugMapResponse(b.resp, b.debugType, b.nodeID) } diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go index 978b2c0e..3de60c97 100644 --- a/hscontrol/mapper/builder_test.go +++ b/hscontrol/mapper/builder_test.go @@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) { // Build should return a multierr data, err := result.Build() - assert.Nil(t, data) - assert.Error(t, err) + require.Nil(t, data) + require.Error(t, err) // The error should contain information about multiple errors assert.Contains(t, err.Error(), "multiple errors") diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 47d222b4..4505f765 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -24,7 +24,6 @@ import ( const ( nextDNSDoHPrefix = "https://dns.nextdns.io" - mapperIDLength = 8 debugMapResponsePerm = 0o755 ) @@ -50,6 +49,7 @@ type mapper struct { created time.Time } +//nolint:unused type patch struct { timestamp time.Time change *tailcfg.PeerChange @@ -60,7 +60,6 @@ func newMapper( state *state.State, ) *mapper { // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) - return &mapper{ state: state, cfg: cfg, @@ -76,6 +75,7 @@ func generateUserProfiles( ) []tailcfg.UserProfile { userMap := make(map[uint]*types.UserView) ids := make([]uint, 0, len(userMap)) + user := node.Owner() if !user.Valid() { log.Error(). @@ -84,14 +84,17 @@ func generateUserProfiles( return nil } + userID := user.Model().ID userMap[userID] = &user ids = append(ids, userID) + for _, peer := range peers.All() { peerUser := peer.Owner() if !peerUser.Valid() { continue } + peerUserID := peerUser.Model().ID userMap[peerUserID] = &peerUser ids = append(ids, peerUserID) @@ -99,7 +102,9 @@ func generateUserProfiles( slices.Sort(ids) ids = slices.Compact(ids) + var profiles []tailcfg.UserProfile + for _, id := range ids { if userMap[id] != nil { profiles = append(profiles, userMap[id].TailscaleUserProfile()) @@ -149,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { } // fullMapResponse returns a MapResponse for the given node. +// +//nolint:unused func (m *mapper) fullMapResponse( nodeID types.NodeID, capVer tailcfg.CapabilityVersion, @@ -316,6 +323,7 @@ func writeDebugMapResponse( perms := fs.FileMode(debugMapResponsePerm) mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID)) + err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -329,6 +337,7 @@ func writeDebugMapResponse( ) log.Trace().Msgf("writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) if err != nil { panic(err) @@ -337,7 +346,7 @@ func writeDebugMapResponse( func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { if debugDumpMapResponsePath == "" { - return nil, nil + return nil, nil //nolint:nilnil // intentional: no data when debug path not set } return ReadMapResponsesFromDirectory(debugDumpMapResponsePath) @@ -350,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe } result := make(map[types.NodeID][]tailcfg.MapResponse) + for _, node := range nodes { if !node.IsDir() { continue @@ -385,6 +395,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe } var resp tailcfg.MapResponse + err = json.Unmarshal(body, &resp) if err != nil { log.Error().Err(err).Msgf("unmarshalling file %s", file.Name()) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 1bafd135..368e1829 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -3,14 +3,10 @@ package mapper import ( "fmt" "net/netip" - "slices" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -81,90 +77,3 @@ func TestDNSConfigMapResponse(t *testing.T) { }) } } - -// mockState is a mock implementation that provides the required methods. -type mockState struct { - polMan policy.PolicyManager - derpMap *tailcfg.DERPMap - primary *routes.PrimaryRoutes - nodes types.Nodes - peers types.Nodes -} - -func (m *mockState) DERPMap() *tailcfg.DERPMap { - return m.derpMap -} - -func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { - if m.polMan == nil { - return tailcfg.FilterAllowAll, nil - } - return m.polMan.Filter() -} - -func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - if m.polMan == nil { - return nil, nil - } - return m.polMan.SSHPolicy(node) -} - -func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { - if m.polMan == nil { - return false - } - return m.polMan.NodeCanHaveTag(node, tag) -} - -func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { - if m.primary == nil { - return nil - } - return m.primary.PrimaryRoutes(nodeID) -} - -func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - if len(peerIDs) > 0 { - // Filter peers by the provided IDs - var filtered types.Nodes - for _, peer := range m.peers { - if slices.Contains(peerIDs, peer.ID) { - filtered = append(filtered, peer) - } - } - - return filtered, nil - } - // Return all peers except the node itself - var filtered types.Nodes - for _, peer := range m.peers { - if peer.ID != nodeID { - filtered = append(filtered, peer) - } - } - - return filtered, nil -} - -func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { - if len(nodeIDs) > 0 { - // Filter nodes by the provided IDs - var filtered types.Nodes - for _, node := range m.nodes { - if slices.Contains(nodeIDs, node.ID) { - filtered = append(filtered, node) - } - } - - return filtered, nil - } - - return m.nodes, nil -} - -func Test_fullMapResponse(t *testing.T) { - t.Skip("Test needs to be refactored for new state-based architecture") - // TODO: Refactor this test to work with the new state-based mapper - // The test architecture needs to be updated to work with the state interface - // instead of the old direct dependency injection pattern -} diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 5b7030de..70572f5a 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -19,6 +19,7 @@ import ( func TestTailNode(t *testing.T) { mustNK := func(str string) key.NodePublic { var k key.NodePublic + _ = k.UnmarshalText([]byte(str)) return k @@ -26,6 +27,7 @@ func TestTailNode(t *testing.T) { mustDK := func(str string) key.DiscoPublic { var k key.DiscoPublic + _ = k.UnmarshalText([]byte(str)) return k @@ -33,6 +35,7 @@ func TestTailNode(t *testing.T) { mustMK := func(str string) key.MachinePublic { var k key.MachinePublic + _ = k.UnmarshalText([]byte(str)) return k @@ -255,7 +258,7 @@ func TestNodeExpiry(t *testing.T) { }, { name: "localtime", - exp: tp(time.Time{}.Local()), + exp: tp(time.Time{}.Local()), //nolint:gosmopolitan wantTimeZero: true, }, } @@ -284,7 +287,9 @@ func TestNodeExpiry(t *testing.T) { if err != nil { t.Fatalf("nodeExpiry() error = %v", err) } + var deseri tailcfg.Node + err = json.Unmarshal(seri, &deseri) if err != nil { t.Fatalf("nodeExpiry() error = %v", err) diff --git a/hscontrol/metrics.go b/hscontrol/metrics.go index 749d651e..09cbc393 100644 --- a/hscontrol/metrics.go +++ b/hscontrol/metrics.go @@ -71,6 +71,7 @@ func prometheusMiddleware(next http.Handler) http.Handler { rw := &respWriterProm{ResponseWriter: w} timer := prometheus.NewTimer(httpDuration.WithLabelValues(path)) + next.ServeHTTP(rw, r) timer.ObserveDuration() httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc() @@ -79,6 +80,7 @@ func prometheusMiddleware(next http.Handler) http.Handler { type respWriterProm struct { http.ResponseWriter + status int written int64 wroteHeader bool @@ -94,6 +96,7 @@ func (r *respWriterProm) Write(b []byte) (int, error) { if !r.wroteHeader { r.WriteHeader(http.StatusOK) } + n, err := r.ResponseWriter.Write(b) r.written += int64(n) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 149891f5..1e974408 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -19,6 +19,9 @@ import ( "tailscale.com/types/key" ) +// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version. +var ErrUnsupportedClientVersion = errors.New("unsupported client version") + const ( // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. ts2021UpgradePath = "/ts2021" @@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler( } func unsupportedClientError(version tailcfg.CapabilityVersion) error { - return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version) + return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version) } func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { @@ -137,17 +140,20 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { // an HTTP/2 settings frame, which isn't of type 'T') var notH2Frame [5]byte copy(notH2Frame[:], earlyPayloadMagic) + var lenBuf [4]byte - binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) + binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) //nolint:gosec // JSON length is bounded // These writes are all buffered by caller, so fine to do them // separately: - if _, err := writer.Write(notH2Frame[:]); err != nil { + if _, err := writer.Write(notH2Frame[:]); err != nil { //nolint:noinlineerr return err } - if _, err := writer.Write(lenBuf[:]); err != nil { + + if _, err := writer.Write(lenBuf[:]); err != nil { //nolint:noinlineerr return err } - if _, err := writer.Write(earlyJSON); err != nil { + + if _, err := writer.Write(earlyJSON); err != nil { //nolint:noinlineerr return err } @@ -199,7 +205,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( body, _ := io.ReadAll(req.Body) var mapRequest tailcfg.MapRequest - if err := json.Unmarshal(body, &mapRequest); err != nil { + if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr httpError(writer, err) return } @@ -219,6 +225,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct()) sess.log.Trace().Caller().Msg("a node sending a MapRequest with Noise protocol") + if !sess.isStreaming() { sess.serve() } else { @@ -241,14 +248,16 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } - registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { + registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck var resp *tailcfg.RegisterResponse + body, err := io.ReadAll(req.Body) if err != nil { return &tailcfg.RegisterRequest{}, regErr(err) } + var regReq tailcfg.RegisterRequest - if err := json.Unmarshal(body, ®Req); err != nil { + if err := json.Unmarshal(body, ®Req); err != nil { //nolint:noinlineerr return ®Req, regErr(err) } @@ -261,6 +270,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( resp = &tailcfg.RegisterResponse{ Error: httpErr.Msg, } + return ®Req, resp } @@ -278,7 +288,8 @@ func (ns *noiseServer) NoiseRegistrationHandler( writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { + err := json.NewEncoder(writer).Encode(registerResponse) + if err != nil { log.Error().Caller().Err(err).Msg("noise registration handler: failed to encode RegisterResponse") return } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 772e84d6..9d284921 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -68,7 +68,7 @@ func NewAuthProviderOIDC( ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already - oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) + oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) //nolint:contextcheck if err != nil { return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err) } @@ -163,6 +163,7 @@ func (a *AuthProviderOIDC) RegisterHandler( for k, v := range a.cfg.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info @@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } stateCookieName := getCookieName("state", state) + cookieState, err := req.Cookie(stateCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) @@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( httpError(writer, err) return } + if idToken.Nonce == "" { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) return } nonceCookieName := getCookieName("nonce", idToken.Nonce) + nonce, err := req.Cookie(nonceCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } + if idToken.Nonce != nonce.Value { httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return @@ -231,7 +236,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( nodeExpiry := a.determineNodeExpiry(idToken.Expiry) var claims types.OIDCClaims - if err := idToken.Claims(&claims); err != nil { + if err := idToken.Claims(&claims); err != nil { //nolint:noinlineerr httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) return } @@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Fetch user information (email, groups, name, etc) from the userinfo endpoint // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo var userinfo *oidc.UserInfo + userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token)) if err != nil { util.LogErr(err, "could not get userinfo; only using claims from id token") @@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified) claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username) claims.Name = cmp.Or(userinfo2.Name, claims.Name) + claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL) if userinfo2.Groups != nil { claims.Groups = userinfo2.Groups @@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Msgf("could not create or update user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) + _, werr := writer.Write([]byte("Could not create or update user")) if werr != nil { log.Error(). @@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Register the node if it does not exist. if registrationId != nil { verb := "Reauthenticated" + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { @@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } + httpError(writer, err) + return } @@ -316,15 +327,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } // TODO(kradalby): replace with go-elem - content, err := renderOIDCCallbackTemplate(user, verb) - if err != nil { - httpError(writer, err) - return - } + content := renderOIDCCallbackTemplate(user, verb) writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - if _, err := writer.Write(content.Bytes()); err != nil { + + if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr util.LogErr(err, "Failed to write HTTP response") } @@ -370,6 +378,7 @@ func (a *AuthProviderOIDC) getOauth2Token( if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } + if regInfo.Verifier != nil { exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} } @@ -394,6 +403,7 @@ func (a *AuthProviderOIDC) extractIDToken( } verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID}) + idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("verifying ID token: %w", err)) @@ -516,6 +526,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( newUser bool c change.Change ) + user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err) @@ -589,9 +600,9 @@ func (a *AuthProviderOIDC) handleRegistration( func renderOIDCCallbackTemplate( user *types.User, verb string, -) (*bytes.Buffer, error) { +) *bytes.Buffer { html := templates.OIDCCallback(user.Display(), verb).Render() - return bytes.NewBufferString(html), nil + return bytes.NewBufferString(html) } // getCookieName generates a unique cookie name based on a cookie value. diff --git a/hscontrol/platform_config.go b/hscontrol/platform_config.go index 23c4d25d..74929ea9 100644 --- a/hscontrol/platform_config.go +++ b/hscontrol/platform_config.go @@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage( ) { writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) + _, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) } // AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it. @@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage( ) { writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) + _, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) } func (h *Headscale) ApplePlatformConfig( @@ -37,6 +37,7 @@ func (h *Headscale) ApplePlatformConfig( req *http.Request, ) { vars := mux.Vars(req) + platform, ok := vars["platform"] if !ok { httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil)) @@ -64,17 +65,20 @@ func (h *Headscale) ApplePlatformConfig( switch platform { case "macos-standalone": - if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil { + err := macosStandaloneTemplate.Execute(&payload, platformConfig) + if err != nil { httpError(writer, err) return } case "macos-app-store": - if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil { + err := macosAppStoreTemplate.Execute(&payload, platformConfig) + if err != nil { httpError(writer, err) return } case "ios": - if err := iosTemplate.Execute(&payload, platformConfig); err != nil { + err := iosTemplate.Execute(&payload, platformConfig) + if err != nil { httpError(writer, err) return } @@ -90,7 +94,7 @@ func (h *Headscale) ApplePlatformConfig( } var content bytes.Buffer - if err := commonTemplate.Execute(&content, config); err != nil { + if err := commonTemplate.Execute(&content, config); err != nil { //nolint:noinlineerr httpError(writer, err) return } @@ -98,7 +102,7 @@ func (h *Headscale) ApplePlatformConfig( writer.Header(). Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write(content.Bytes()) + _, _ = writer.Write(content.Bytes()) } type AppleMobileConfig struct { diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index 1e6312b8..b52cb4dc 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -16,15 +16,18 @@ type Match struct { dests *netipx.IPSet } -func (m Match) DebugString() string { +func (m *Match) DebugString() string { var sb strings.Builder sb.WriteString("Match:\n") sb.WriteString(" Sources:\n") + for _, prefix := range m.srcs.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } + sb.WriteString(" Destinations:\n") + for _, prefix := range m.dests.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } @@ -42,7 +45,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { } func MatchFromFilterRule(rule tailcfg.FilterRule) Match { - dests := []string{} + dests := make([]string, 0, len(rule.DstPorts)) for _, dest := range rule.DstPorts { dests = append(dests, dest.IP) } @@ -98,7 +101,7 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { // cased for exit nodes. // This checks if dests is a superset of TheInternet(), which handles // merged filter rules where TheInternet is combined with other destinations. -func (m Match) DestsIsTheInternet() bool { +func (m *Match) DestsIsTheInternet() bool { if m.dests.ContainsPrefix(tsaddr.AllIPv4()) || m.dests.ContainsPrefix(tsaddr.AllIPv6()) { return true diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index f4db88a4..6dfacd91 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,18 +19,18 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error) - SetPolicy([]byte) (bool, error) + SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SetPolicy(pol []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) // NodeCanHaveTag reports whether the given node can have the given tag. - NodeCanHaveTag(types.NodeView, string) bool + NodeCanHaveTag(node types.NodeView, tag string) bool // TagExists reports whether the given tag is defined in the policy. TagExists(tag string) bool // NodeCanApproveRoute reports whether the given node can approve the given route. - NodeCanApproveRoute(types.NodeView, netip.Prefix) bool + NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool Version() int DebugString() string @@ -38,8 +38,11 @@ type PolicyManager interface { // NewPolicyManager returns a new policy manager. func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) { - var polMan PolicyManager - var err error + var ( + polMan PolicyManager + err error + ) + polMan, err = policyv2.NewPolicyManager(pol, users, nodes) if err != nil { return nil, err @@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ if err != nil { return nil, err } + polMans = append(polMans, pm) } @@ -66,7 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ } func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) { - var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) + polmanFuncs := make([]func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error), 0, 1) polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) { return policyv2.NewPolicyManager(pol, u, n) diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 9d9545f8..e598349e 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -126,6 +126,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove if !slices.Equal(sortedCurrent, newApproved) { // Log what changed var added, kept []netip.Prefix + for _, route := range newApproved { if !slices.Contains(sortedCurrent, route) { added = append(added, route) diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 61c69067..a9b36f75 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -9,6 +9,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/types/key" @@ -76,7 +77,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { }` pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -313,11 +314,14 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { nodes := types.Nodes{&node} // Create policy manager or use nil if specified - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + if tt.name != "nil_policy_manager" { pm, err = pmf(users, nodes.ViewSlice()) - assert.NoError(t, err) + require.NoError(t, err) } else { pm = nil } diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index b62e94db..486fdec7 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -33,6 +33,7 @@ func TestReduceNodes(t *testing.T) { rules []tailcfg.FilterRule node *types.Node } + tests := []struct { name string args args @@ -783,9 +784,11 @@ func TestReduceNodes(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -796,7 +799,7 @@ func TestReduceNodes(t *testing.T) { func TestReduceNodesFromPolicy(t *testing.T) { n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node { - var routes []netip.Prefix + routes := make([]netip.Prefix, 0, len(routess)) for _, route := range routess { routes = append(routes, netip.MustParsePrefix(route)) } @@ -1034,8 +1037,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(nil, tt.nodes.ViewSlice()) require.NoError(t, err) @@ -1053,9 +1059,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -1233,7 +1241,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: `invalid SSH action "invalid", must be one of: accept, check`, + errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`, }, { name: "invalid-check-period", @@ -1280,7 +1288,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: "autogroup \"autogroup:invalid\" is not supported", + errorMessage: "autogroup not supported for SSH user", }, { name: "autogroup-nonroot-should-use-wildcard-with-root-excluded", @@ -1453,13 +1461,17 @@ func TestSSHPolicyRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) if tt.expectErr { require.Error(t, err) require.Contains(t, err.Error(), tt.errorMessage) + return } @@ -1482,6 +1494,7 @@ func TestReduceRoutes(t *testing.T) { routes []netip.Prefix rules []tailcfg.FilterRule } + tests := []struct { name string args args @@ -2103,6 +2116,7 @@ func TestReduceRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( tt.args.node.View(), tt.args.routes, diff --git a/hscontrol/policy/policyutil/reduce.go b/hscontrol/policy/policyutil/reduce.go index e4549c10..6d95a297 100644 --- a/hscontrol/policy/policyutil/reduce.go +++ b/hscontrol/policy/policyutil/reduce.go @@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf for _, rule := range rules { // record if the rule is actually relevant for the given node. var dests []tailcfg.NetPortRange + DEST_LOOP: for _, dest := range rule.DstPorts { expanded, err := util.ParseIPSet(dest.IP, nil) diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go index 0851e303..ced422c4 100644 --- a/hscontrol/policy/policyutil/reduce_test.go +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -798,10 +798,14 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm policy.PolicyManager - var err error + var ( + pm policy.PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) require.NoError(t, err) + got, _ := pm.Filter() t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) got = policyutil.ReduceFilterRules(tt.node.View(), got) diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 39b15cee..7393b3b2 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -830,6 +830,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if tt.name == "empty policy" { // We expect this one to have a valid but empty policy require.NoError(t, err) + if err != nil { return } @@ -844,6 +845,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if diff := cmp.Diff(tt.canApprove, result); diff != "" { t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) } + assert.Equal(t, tt.canApprove, result, "Unexpected route approval result") }) } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 7a7c6629..9c2c5f17 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -17,7 +17,10 @@ import ( "tailscale.com/types/views" ) -var ErrInvalidAction = errors.New("invalid action") +var ( + ErrInvalidAction = errors.New("invalid action") + errSelfInSources = errors.New("autogroup:self cannot be used in sources") +) // compileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. @@ -45,9 +48,10 @@ func (pol *Policy) compileFilterRules( continue } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { // Check if destination is a wildcard - use "*" directly instead of expanding if _, isWildcard := dest.Alias.(Asterix); isWildcard { @@ -142,14 +146,18 @@ func (pol *Policy) compileFilterRulesForNode( // It returns a slice of filter rules because when an ACL has both autogroup:self // and other destinations, they need to be split into separate rules with different // source filtering logic. +// +//nolint:gocyclo // complex ACL compilation logic func (pol *Policy) compileACLWithAutogroupSelf( acl ACL, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], ) ([]*tailcfg.FilterRule, error) { - var autogroupSelfDests []AliasWithPorts - var otherDests []AliasWithPorts + var ( + autogroupSelfDests []AliasWithPorts + otherDests []AliasWithPorts + ) for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -159,14 +167,15 @@ func (pol *Policy) compileACLWithAutogroupSelf( } } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() + var rules []*tailcfg.FilterRule var resolvedSrcIPs []*netipx.IPSet for _, src := range acl.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - return nil, fmt.Errorf("autogroup:self cannot be used in sources") + return nil, errSelfInSources } ips, err := src.Resolve(pol, users, nodes) @@ -188,6 +197,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(autogroupSelfDests) > 0 && !node.IsTagged() { // Pre-filter to same-user untagged devices once - reuse for both sources and destinations sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if !n.IsTagged() && n.User().ID() == node.User().ID() { sameUserNodes = append(sameUserNodes, n) @@ -197,6 +207,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(sameUserNodes) > 0 { // Filter sources to only same-user untagged devices var srcIPs netipx.IPSetBuilder + for _, ips := range resolvedSrcIPs { for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set @@ -213,6 +224,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if srcSet != nil && len(srcSet.Prefixes()) > 0 { var destPorts []tailcfg.NetPortRange + for _, dest := range autogroupSelfDests { for _, n := range sameUserNodes { for _, port := range dest.Ports { @@ -318,13 +330,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { } } +//nolint:gocyclo // complex SSH policy compilation logic func (pol *Policy) compileSSHPolicy( users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], ) (*tailcfg.SSHPolicy, error) { if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { - return nil, nil + return nil, nil //nolint:nilnil // intentional: no SSH policy when none configured } log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname()) @@ -335,8 +348,10 @@ func (pol *Policy) compileSSHPolicy( // Separate destinations into autogroup:self and others // This is needed because autogroup:self requires filtering sources to same-user only, // while other destinations should use all resolved sources - var autogroupSelfDests []Alias - var otherDests []Alias + var ( + autogroupSelfDests []Alias + otherDests []Alias + ) for _, dst := range rule.Destinations { if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -359,6 +374,7 @@ func (pol *Policy) compileSSHPolicy( } var action tailcfg.SSHAction + switch rule.Action { case SSHActionAccept: action = sshAction(true, 0) @@ -374,9 +390,11 @@ func (pol *Policy) compileSSHPolicy( // by default, we do not allow root unless explicitly stated userMap["root"] = "" } + if rule.Users.ContainsRoot() { userMap["root"] = "root" } + for _, u := range rule.Users.NormalUsers() { userMap[u.String()] = u.String() } @@ -386,6 +404,7 @@ func (pol *Policy) compileSSHPolicy( if len(autogroupSelfDests) > 0 && !node.IsTagged() { // Build destination set for autogroup:self (same-user untagged devices only) var dest netipx.IPSetBuilder + for _, n := range nodes.All() { if !n.IsTagged() && n.User().ID() == node.User().ID() { n.AppendToIPSet(&dest) @@ -402,6 +421,7 @@ func (pol *Policy) compileSSHPolicy( // Filter sources to only same-user untagged devices // Pre-filter to same-user untagged devices for efficiency sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if !n.IsTagged() && n.User().ID() == node.User().ID() { sameUserNodes = append(sameUserNodes, n) @@ -409,6 +429,7 @@ func (pol *Policy) compileSSHPolicy( } var filteredSrcIPs netipx.IPSetBuilder + for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { @@ -444,11 +465,13 @@ func (pol *Policy) compileSSHPolicy( if len(otherDests) > 0 { // Build destination set for other destinations var dest netipx.IPSetBuilder + for _, dst := range otherDests { ips, err := dst.Resolve(pol, users, nodes) if err != nil { log.Trace().Caller().Err(err).Msgf("resolving destination ips") } + if ips != nil { dest.AddSet(ips) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 0d9b44a3..44490f29 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -623,7 +623,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { if sshPolicy == nil { return // Expected empty result } + assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match") + return } @@ -709,7 +711,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } // TestSSHIntegrationReproduction reproduces the exact scenario from the integration test -// TestSSHOneUserToAll that was failing with empty sshUsers +// TestSSHOneUserToAll that was failing with empty sshUsers. func TestSSHIntegrationReproduction(t *testing.T) { // Create users matching the integration test users := types.Users{ @@ -775,7 +777,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { } // TestSSHJSONSerialization verifies that the SSH policy can be properly serialized -// to JSON and that the sshUsers field is not empty +// to JSON and that the sshUsers field is not empty. func TestSSHJSONSerialization(t *testing.T) { users := types.Users{ {Name: "user1", Model: gorm.Model{ID: 1}}, @@ -815,6 +817,7 @@ func TestSSHJSONSerialization(t *testing.T) { // Parse back to verify structure var parsed tailcfg.SSHPolicy + err = json.Unmarshal(jsonData, &parsed) require.NoError(t, err) @@ -899,6 +902,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if len(rules) != 1 { t.Fatalf("expected 1 rule, got %d", len(rules)) } @@ -915,6 +919,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { found := false addr := netip.MustParseAddr(expectedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -932,6 +937,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"} for _, excludedIP := range excludedSourceIPs { addr := netip.MustParseAddr(excludedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -1144,7 +1150,8 @@ func TestAutogroupTagged(t *testing.T) { require.NoError(t, err) // Verify autogroup:tagged includes all tagged nodes - taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice()) + ag := AutoGroupTagged + taggedIPs, err := ag.Resolve(policy, users, nodes.ViewSlice()) require.NoError(t, err) require.NotNil(t, taggedIPs) @@ -1366,14 +1373,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) { assert.Empty(t, rules3, "user3 should have no rules") } -// Helper function to create IP addresses for testing +// Helper function to create IP addresses for testing. func createAddr(ip string) *netip.Addr { addr, _ := netip.ParseAddr(ip) return &addr } // TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly -// with autogroup:self in destinations +// with autogroup:self in destinations. func TestSSHWithAutogroupSelfInDestination(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1421,6 +1428,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // Test for user2's first node @@ -1439,12 +1447,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule2.Principals { principalIPs2[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2) // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy3 != nil { assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self") } @@ -1452,7 +1462,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user // is in the source and autogroup:self in destination, only that user's devices -// can SSH (and only if they match the target user) +// can SSH (and only if they match the target user). func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1494,18 +1504,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1") } } -// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations +// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations. func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1552,19 +1564,21 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)") } } // TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices -// are excluded from both sources and destinations when autogroup:self is used +// are excluded from both sources and destinations when autogroup:self is used. func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1609,6 +1623,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs, "should only include untagged devices") @@ -1616,6 +1631,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self") } @@ -1664,10 +1680,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Verify autogroup:self rule has filtered sources (only same-user devices) selfRule := sshPolicy1.Rules[0] require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices") + selfPrincipals := make([]string, len(selfRule.Principals)) for i, p := range selfRule.Principals { selfPrincipals[i] = p.NodeIP } + require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals, "autogroup:self rule should only include same-user untagged devices") @@ -1679,10 +1697,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") routerRule := sshPolicyRouter.Rules[0] + routerPrincipals := make([]string, len(routerRule.Principals)) for i, p := range routerRule.Principals { routerPrincipals[i] = p.NodeIP } + require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 6472658a..74b7ba6a 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Filter: filter, Policy: pm.pol, }) + filterChanged := filterHash != pm.filterHash if filterChanged { log.Debug(). @@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("filter.rules.new", len(filter)). Msg("Policy filter hash changed") } + pm.filter = filter + pm.filterHash = filterHash if filterChanged { pm.matchers = matcher.MatchesFromFilterRules(pm.filter) @@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } tagOwnerMapHash := deephash.Hash(&tagMap) + tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash if tagOwnerChanged { log.Debug(). @@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("tagOwners.new", len(tagMap)). Msg("Tag owner hash changed") } + pm.tagOwnerMap = tagMap pm.tagOwnerMapHash = tagOwnerMapHash @@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } autoApproveMapHash := deephash.Hash(&autoMap) + autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash if autoApproveChanged { log.Debug(). @@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("autoApprovers.new", len(autoMap)). Msg("Auto-approvers hash changed") } + pm.autoApproveMap = autoMap pm.autoApproveMapHash = autoApproveMapHash exitSetHash := deephash.Hash(&exitSet) + exitSetChanged := exitSetHash != pm.exitSetHash if exitSetChanged { log.Debug(). @@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Str("exitSet.hash.new", exitSetHash.String()[:8]). Msg("Exit node set hash changed") } + pm.exitSet = exitSet pm.exitSetHash = exitSetHash @@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { if !needsUpdate { log.Trace(). Msg("Policy evaluation detected no changes - all hashes match") + return false, nil } @@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } + pm.sshPolicyMap[node.ID()] = sshPol return sshPol, nil @@ -403,6 +414,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil reducedFilter := policyutil.ReduceFilterRules(node, pm.filter) pm.filterRulesMap[node.ID()] = reducedFilter + return reducedFilter, nil } @@ -447,7 +459,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul // This is different from FilterForNode which returns REDUCED rules for packet filtering. // // For global policies: returns the global matchers (same for all nodes) -// For autogroup:self: returns node-specific matchers from unreduced compiled rules +// For autogroup:self: returns node-specific matchers from unreduced compiled rules. func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { if pm == nil { return nil, nil @@ -479,6 +491,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() + pm.users = users // Clear SSH policy map when users change to force SSH policy recomputation @@ -690,6 +703,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr if pm.exitSet == nil { return false } + if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) { return true } @@ -753,8 +767,10 @@ func (pm *PolicyManager) DebugString() string { } fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap)) + for prefix, approveAddrs := range pm.autoApproveMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range approveAddrs.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } @@ -763,14 +779,17 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap)) + for prefix, tagOwners := range pm.tagOwnerMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range tagOwners.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } } sb.WriteString("\n\n") + if pm.filter != nil { filter, err := json.MarshalIndent(pm.filter, "", " ") if err == nil { @@ -783,6 +802,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Matchers:\n") sb.WriteString("an internal structure used to filter nodes and routes\n") + for _, match := range pm.matchers { sb.WriteString(match.DebugString()) sb.WriteString("\n") @@ -790,6 +810,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Nodes:\n") + for _, node := range pm.nodes.All() { sb.WriteString(node.String()) sb.WriteString("\n") @@ -867,6 +888,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // Check if IPs changed (simple check - could be more sophisticated) oldIPs := oldNode.IPs() + newIPs := newNode.IPs() if len(oldIPs) != len(newIPs) { affectedUsers[newNode.User().ID()] = struct{}{} @@ -888,6 +910,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S for nodeID := range pm.filterRulesMap { // Find the user for this cached node var nodeUserID uint + found := false // Check in new nodes first @@ -899,8 +922,10 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S found = true break } + nodeUserID = node.User().ID() found = true + break } } @@ -913,8 +938,10 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S found = true break } + nodeUserID = node.User().ID() found = true + break } } diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 11f63cf7..2b2258a2 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -14,7 +14,7 @@ import ( "tailscale.com/types/ptr" ) -func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { +func node(name, ipv4, ipv6 string, user types.User) *types.Node { return &types.Node{ ID: 0, Hostname: name, @@ -22,7 +22,6 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) IPv6: ap(ipv6), User: ptr.To(user), UserID: ptr.To(user.ID), - Hostinfo: hostinfo, } } @@ -57,6 +56,7 @@ func TestPolicyManager(t *testing.T) { if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff( tt.wantMatchers, matchers, @@ -77,6 +77,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { {Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"}, } + //nolint:goconst // test-specific inline policy for clarity policy := `{ "acls": [ { @@ -88,14 +89,14 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { }` initialNodes := types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), } for i, n := range initialNodes { - n.ID = types.NodeID(i + 1) + n.ID = types.NodeID(i + 1) //nolint:gosec // safe conversion in test } pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice()) @@ -107,7 +108,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { require.NoError(t, err) } - require.Equal(t, len(initialNodes), len(pm.filterRulesMap)) + require.Len(t, pm.filterRulesMap, len(initialNodes)) tests := []struct { name string @@ -118,10 +119,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "no_changes", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 0, description: "No changes should clear no cache entries", @@ -129,11 +130,11 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "node_added", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0], nil), // New node - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0]), // New node + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 2, // user1's existing nodes should be cleared description: "Adding a node should clear cache for that user's existing nodes", @@ -141,10 +142,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "node_removed", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), // user1-node2 removed - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 2, // user1's remaining node + removed node should be cleared description: "Removing a node should clear cache for that user's remaining nodes", @@ -152,10 +153,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "user_changed", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2], nil), // Changed to user3 - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0]), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2]), // Changed to user3 + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 3, // user1's node + user2's node + user3's nodes should be cleared description: "Changing a node's user should clear cache for both old and new users", @@ -163,10 +164,10 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { { name: "ip_changed", newNodes: types.Nodes{ - node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0], nil), // IP changed - node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), - node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), - node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0]), // IP changed + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0]), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1]), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2]), }, expectedCleared: 2, // user1's nodes should be cleared description: "Changing a node's IP should clear cache for that user's nodes", @@ -177,15 +178,18 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { t.Run(tt.name, func(t *testing.T) { for i, n := range tt.newNodes { found := false + for _, origNode := range initialNodes { if n.Hostname == origNode.Hostname { n.ID = origNode.ID found = true + break } } + if !found { - n.ID = types.NodeID(len(initialNodes) + i + 1) + n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec // safe conversion in test } } @@ -370,16 +374,16 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) { // TestAutogroupSelfReducedVsUnreducedRules verifies that: // 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships -// 2. FilterForNode returns reduced compiled rules for packet filters +// 2. FilterForNode returns reduced compiled rules for packet filters. func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"} user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"} users := types.Users{user1, user2} // Create two nodes - node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil) + node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1) node1.ID = 1 - node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil) + node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2) node2.ID = 2 nodes := types.Nodes{node1, node2} @@ -410,6 +414,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations // For node1, destinations should only be node1's IPs node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"} + for _, rule := range filterNode1 { for _, dst := range rule.DstPorts { require.Contains(t, node1IPs, dst.IP, @@ -419,6 +424,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // For node2, destinations should only be node2's IPs node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"} + for _, rule := range filterNode2 { for _, dst := range rule.DstPorts { require.Contains(t, node2IPs, dst.IP, diff --git a/hscontrol/policy/v2/tailscale_compat_test.go b/hscontrol/policy/v2/tailscale_compat_test.go index 7124a1af..ac72cae2 100644 --- a/hscontrol/policy/v2/tailscale_compat_test.go +++ b/hscontrol/policy/v2/tailscale_compat_test.go @@ -9655,7 +9655,7 @@ func TestTailscaleCompatErrorCases(t *testing.T) { {"action": "accept", "src": ["tag:nonexistent"], "dst": ["tag:server:22"]} ] }`, - wantErr: `Tag "tag:nonexistent" is not defined in the Policy`, + wantErr: `tag not defined in policy: "tag:nonexistent"`, reference: "Test 6.4: tag:nonexistent → tag:server:22", }, @@ -9674,7 +9674,7 @@ func TestTailscaleCompatErrorCases(t *testing.T) { {"action": "accept", "src": ["autogroup:self"], "dst": ["tag:server:22"]} ] }`, - wantErr: `"autogroup:self" used in source, it can only be used in ACL destinations`, + wantErr: `autogroup:self can only be used in ACL destinations`, reference: "Test 13.41: autogroup:self as SOURCE", }, diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 3dcbb14e..c9d6b7d5 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -22,7 +22,7 @@ import ( "tailscale.com/util/slicesx" ) -// Global JSON options for consistent parsing across all struct unmarshaling +// Global JSON options for consistent parsing across all struct unmarshaling. var policyJSONOpts = []json.Options{ json.DefaultOptionsV2(), json.MatchCaseInsensitiveNames(true), @@ -51,6 +51,55 @@ var ( ErrACLAutogroupSelfInvalidSource = errors.New("autogroup:self destination requires sources to be users, groups, or autogroup:member only") ) +// Policy validation errors. +var ( + ErrUnknownAliasType = errors.New("unknown alias type") + ErrUnknownAutoApprover = errors.New("unknown auto approver type") + ErrUnknownOwnerType = errors.New("unknown owner type") + ErrInvalidUsername = errors.New("username must contain @") + ErrUserNotFound = errors.New("user not found") + ErrMultipleUsersFound = errors.New("multiple users found") + ErrInvalidGroupFormat = errors.New("group must start with 'group:'") + ErrInvalidTagFormat = errors.New("tag must start with 'tag:'") + ErrInvalidHostname = errors.New("invalid hostname") + ErrHostResolve = errors.New("error resolving host") + ErrInvalidPrefix = errors.New("invalid prefix") + ErrInvalidAutogroup = errors.New("invalid autogroup") + ErrUnknownAutogroup = errors.New("unknown autogroup") + ErrHostportMissingColon = errors.New("hostport must contain a colon") + ErrTypeNotSupported = errors.New("type not supported") + ErrInvalidAlias = errors.New("invalid alias format") + ErrInvalidAutoApprover = errors.New("invalid auto approver format") + ErrInvalidOwner = errors.New("invalid owner format") + ErrGroupNotDefined = errors.New("group not defined in policy") + ErrInvalidGroupMember = errors.New("invalid group member type") + ErrGroupValueNotArray = errors.New("group value must be an array of users") + ErrNestedGroups = errors.New("nested groups are not allowed") + ErrInvalidHostIP = errors.New("hostname contains invalid IP address") + ErrTagNotDefined = errors.New("tag not defined in policy") + ErrAutoApproverNotAlias = errors.New("auto approver is not an alias") + ErrInvalidACLAction = errors.New("invalid ACL action") + ErrInvalidSSHAction = errors.New("invalid SSH action") + ErrInvalidProtocolNumber = errors.New("invalid protocol number") + ErrProtocolLeadingZero = errors.New("leading 0 not permitted in protocol number") + ErrProtocolOutOfRange = errors.New("protocol number out of range (0-255)") + ErrAutogroupNotSupported = errors.New("autogroup not supported in headscale") + ErrAutogroupInternetSrc = errors.New("autogroup:internet can only be used in ACL destinations") + ErrAutogroupSelfSrc = errors.New("autogroup:self can only be used in ACL destinations") + ErrAutogroupNotSupportedACLSrc = errors.New("autogroup not supported for ACL sources") + ErrAutogroupNotSupportedACLDst = errors.New("autogroup not supported for ACL destinations") + ErrAutogroupNotSupportedSSHSrc = errors.New("autogroup not supported for SSH sources") + ErrAutogroupNotSupportedSSHDst = errors.New("autogroup not supported for SSH destinations") + ErrAutogroupNotSupportedSSHUsr = errors.New("autogroup not supported for SSH user") + ErrHostNotDefined = errors.New("host not defined in policy") + ErrSSHSourceAliasNotSupported = errors.New("alias not supported for SSH source") + ErrSSHDestAliasNotSupported = errors.New("alias not supported for SSH destination") + ErrUnknownSSHDestAlias = errors.New("unknown SSH destination alias type") + ErrUnknownSSHSrcAlias = errors.New("unknown SSH source alias type") + ErrUnknownField = errors.New("unknown field") + ErrProtocolNoSpecificPorts = errors.New("protocol does not support specific ports") +) + type Asterix int func (a Asterix) Validate() error { @@ -73,6 +122,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { } var alias string + switch v := a.Alias.(type) { case *Username: alias = string(*v) @@ -89,7 +139,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { case Asterix: alias = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } // If no ports are specified @@ -104,6 +154,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { // Otherwise, format as "alias:ports" var ports []string + for _, port := range a.Ports { if port.First == port.Last { ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) @@ -134,11 +185,12 @@ func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeV // Username is a string that represents a username, it must contain an @. type Username string -func (u Username) Validate() error { - if isUser(string(u)) { +func (u *Username) Validate() error { + if isUser(string(*u)) { return nil } - return fmt.Errorf("username must contain @, got: %q", u) + + return fmt.Errorf("%w, got: %q", ErrInvalidUsername, *u) } func (u *Username) String() string { @@ -146,29 +198,31 @@ func (u *Username) String() string { } // MarshalJSON marshals the Username to JSON. -func (u Username) MarshalJSON() ([]byte, error) { - return json.Marshal(string(u)) +func (u *Username) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*u)) } // MarshalJSON marshals the Prefix to JSON. -func (p Prefix) MarshalJSON() ([]byte, error) { +func (p *Prefix) MarshalJSON() ([]byte, error) { return json.Marshal(p.String()) } func (u *Username) UnmarshalJSON(b []byte) error { *u = Username(strings.Trim(string(b), `"`)) - if err := u.Validate(); err != nil { + + err := u.Validate() + if err != nil { return err } return nil } -func (u Username) CanBeTagOwner() bool { +func (u *Username) CanBeTagOwner() bool { return true } -func (u Username) CanBeAutoApprover() bool { +func (u *Username) CanBeAutoApprover() bool { return true } @@ -177,7 +231,7 @@ func (u Username) CanBeAutoApprover() bool { // If no matching user is found, it returns an error indicating no user matching. // If multiple matching users are found, it returns an error indicating multiple users matching. // It returns the matched types.User and a nil error if exactly one match is found. -func (u Username) resolveUser(users types.Users) (types.User, error) { +func (u *Username) resolveUser(users types.Users) (types.User, error) { var potentialUsers types.Users // At parsetime, we require all usernames to contain an "@" character, if the @@ -198,19 +252,21 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { } if len(potentialUsers) == 0 { - return types.User{}, fmt.Errorf("user with token %q not found", u.String()) + return types.User{}, fmt.Errorf("%w: token %q", ErrUserNotFound, u.String()) } if len(potentialUsers) > 1 { - return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String()) + return types.User{}, fmt.Errorf("%w: token %q found: %s", ErrMultipleUsersFound, u.String(), potentialUsers.String()) } return potentialUsers[0], nil } -func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error +func (u *Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ( + ips netipx.IPSetBuilder + errs []error + ) user, err := u.resolveUser(users) if err != nil { @@ -239,54 +295,59 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. // Group is a special string which is always prefixed with `group:`. type Group string -func (g Group) Validate() error { - if isGroup(string(g)) { +func (g *Group) Validate() error { + if isGroup(string(*g)) { return nil } - return fmt.Errorf(`group must start with "group:", got: %q`, g) + + return fmt.Errorf("%w, got: %q", ErrInvalidGroupFormat, *g) } func (g *Group) UnmarshalJSON(b []byte) error { *g = Group(strings.Trim(string(b), `"`)) - if err := g.Validate(); err != nil { + + err := g.Validate() + if err != nil { return err } return nil } -func (g Group) CanBeTagOwner() bool { +func (g *Group) CanBeTagOwner() bool { return true } -func (g Group) CanBeAutoApprover() bool { +func (g *Group) CanBeAutoApprover() bool { return true } // String returns the string representation of the Group. -func (g Group) String() string { - return string(g) +func (g *Group) String() string { + return string(*g) } -func (h Host) String() string { - return string(h) +func (h *Host) String() string { + return string(*h) } // MarshalJSON marshals the Host to JSON. -func (h Host) MarshalJSON() ([]byte, error) { - return json.Marshal(string(h)) +func (h *Host) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*h)) } // MarshalJSON marshals the Group to JSON. -func (g Group) MarshalJSON() ([]byte, error) { - return json.Marshal(string(g)) +func (g *Group) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*g)) } -func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error +func (g *Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ( + ips netipx.IPSetBuilder + errs []error + ) - for _, user := range p.Groups[g] { + for _, user := range p.Groups[*g] { uips, err := user.Resolve(nil, users, nodes) if err != nil { errs = append(errs, err) @@ -301,28 +362,31 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod // Tag is a special string which is always prefixed with `tag:`. type Tag string -func (t Tag) Validate() error { - if isTag(string(t)) { +func (t *Tag) Validate() error { + if isTag(string(*t)) { return nil } - return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) + + return fmt.Errorf("%w, got: %q", ErrInvalidTagFormat, *t) } func (t *Tag) UnmarshalJSON(b []byte) error { *t = Tag(strings.Trim(string(b), `"`)) - if err := t.Validate(); err != nil { + + err := t.Validate() + if err != nil { return err } return nil } -func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (t *Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var ips netipx.IPSetBuilder for _, node := range nodes.All() { // Check if node has this tag - if node.HasTag(string(t)) { + if node.HasTag(string(*t)) { node.AppendToIPSet(&ips) } } @@ -330,50 +394,56 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV return ips.IPSet() } -func (t Tag) CanBeAutoApprover() bool { +func (t *Tag) CanBeAutoApprover() bool { return true } -func (t Tag) CanBeTagOwner() bool { +func (t *Tag) CanBeTagOwner() bool { return true } -func (t Tag) String() string { - return string(t) +func (t *Tag) String() string { + return string(*t) } // MarshalJSON marshals the Tag to JSON. -func (t Tag) MarshalJSON() ([]byte, error) { - return json.Marshal(string(t)) +func (t *Tag) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*t)) } // Host is a string that represents a hostname. type Host string -func (h Host) Validate() error { - if isHost(string(h)) { +func (h *Host) Validate() error { + if isHost(string(*h)) { return nil } - return fmt.Errorf("hostname %q is invalid", h) + + return fmt.Errorf("%w: %q", ErrInvalidHostname, *h) } func (h *Host) UnmarshalJSON(b []byte) error { *h = Host(strings.Trim(string(b), `"`)) - if err := h.Validate(); err != nil { + + err := h.Validate() + if err != nil { return err } return nil } -func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error +func (h *Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ( + ips netipx.IPSetBuilder + errs []error + ) - pref, ok := p.Hosts[h] + pref, ok := p.Hosts[*h] if !ok { - return nil, fmt.Errorf("resolving host: %q", h) + return nil, fmt.Errorf("%w: %q", ErrHostResolve, *h) } + err := pref.Validate() if err != nil { errs = append(errs, err) @@ -391,6 +461,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView if err != nil { errs = append(errs, err) } + for _, node := range nodes.All() { if node.InIPSet(ipsTemp) { node.AppendToIPSet(&ips) @@ -402,15 +473,16 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView type Prefix netip.Prefix -func (p Prefix) Validate() error { - if netip.Prefix(p).IsValid() { +func (p *Prefix) Validate() error { + if netip.Prefix(*p).IsValid() { return nil } - return fmt.Errorf("prefix %q is invalid", p) + + return fmt.Errorf("%w: %s", ErrInvalidPrefix, p.String()) } -func (p Prefix) String() string { - return netip.Prefix(p).String() +func (p *Prefix) String() string { + return netip.Prefix(*p).String() } func (p *Prefix) parseString(addr string) error { @@ -419,6 +491,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + addrPref, err := addr.Prefix(addr.BitLen()) if err != nil { return err @@ -433,6 +506,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + *p = Prefix(pref) return nil @@ -443,7 +517,8 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { if err != nil { return err } - if err := p.Validate(); err != nil { + + if err := p.Validate(); err != nil { //nolint:noinlineerr return err } @@ -455,14 +530,16 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // of the Prefix and the Policy, Users, and Nodes. // // See [Policy], [types.Users], and [types.Nodes] for more details. -func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error +func (p *Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ( + ips netipx.IPSetBuilder + errs []error + ) - ips.AddPrefix(netip.Prefix(p)) + ips.AddPrefix(netip.Prefix(*p)) // If the IP is a single host, look for a node to ensure we add all the IPs of // the node to the IPSet. - appendIfNodeHasIP(nodes, &ips, netip.Prefix(p)) + appendIfNodeHasIP(nodes, &ips, netip.Prefix(*p)) return buildIPSetMultiErr(&ips, errs) } @@ -500,36 +577,38 @@ var autogroups = []AutoGroup{ AutoGroupSelf, } -func (ag AutoGroup) Validate() error { - if slices.Contains(autogroups, ag) { +func (ag *AutoGroup) Validate() error { + if slices.Contains(autogroups, *ag) { return nil } - return fmt.Errorf("autogroup is invalid, got: %q, must be one of %v", ag, autogroups) + return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutogroup, *ag, autogroups) } func (ag *AutoGroup) UnmarshalJSON(b []byte) error { *ag = AutoGroup(strings.Trim(string(b), `"`)) - if err := ag.Validate(); err != nil { + + err := ag.Validate() + if err != nil { return err } return nil } -func (ag AutoGroup) String() string { - return string(ag) +func (ag *AutoGroup) String() string { + return string(*ag) } // MarshalJSON marshals the AutoGroup to JSON. -func (ag AutoGroup) MarshalJSON() ([]byte, error) { - return json.Marshal(string(ag)) +func (ag *AutoGroup) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*ag)) } -func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { +func (ag *AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - switch ag { + switch *ag { case AutoGroupInternet: return util.TheInternet(), nil @@ -564,8 +643,13 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type // specially during policy compilation per-node for security. return nil, ErrAutogroupSelfRequiresPerNodeResolution + case AutoGroupNonRoot: + // autogroup:nonroot represents non-root users on multi-user devices. + // This is not supported in headscale and requires OS-level user detection. + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, *ag) + default: - return nil, fmt.Errorf("unknown autogroup %q", ag) + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, *ag) } } @@ -579,31 +663,36 @@ func (ag *AutoGroup) Is(c AutoGroup) bool { type Alias interface { Validate() error - UnmarshalJSON([]byte) error + UnmarshalJSON(b []byte) error // Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP // addresses that the Alias represents within Headscale. It is the product // of the Alias and the Policy, Users and Nodes. // This is an interface definition and the implementation is independent of // the Alias type. - Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error) + Resolve(pol *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) } type AliasWithPorts struct { Alias + Ports []tailcfg.PortRange } func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { var v any - if err := json.Unmarshal(b, &v); err != nil { + + err := json.Unmarshal(b, &v) + if err != nil { return err } switch vs := v.(type) { case string: - var portsPart string - var err error + var ( + portsPart string + err error + ) if strings.Contains(vs, ":") { vs, portsPart, err = splitDestinationAndPort(vs) @@ -615,21 +704,23 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Ports = ports } else { - return errors.New(`hostport must contain a colon (":")`) + return ErrHostportMissingColon } ve.Alias, err = parseAlias(vs) if err != nil { return err } - if err := ve.Validate(); err != nil { + + if err := ve.Validate(); err != nil { //nolint:noinlineerr return err } default: - return fmt.Errorf("type %T not supported", vs) + return fmt.Errorf("%w: %T", ErrTypeNotSupported, vs) } return nil @@ -661,6 +752,7 @@ func isHost(str string) bool { func parseAlias(vs string) (Alias, error) { var pref Prefix + err := pref.parseString(vs) if err == nil { return &pref, nil @@ -683,15 +775,7 @@ func parseAlias(vs string) (Alias, error) { return ptr.To(Host(vs)), nil } - return nil, fmt.Errorf(`Invalid alias %q. An alias must be one of the following types: -- wildcard (*) -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") -- autogroup (starting with "autogroup:") -- host - -Please check the format and try again.`, vs) + return nil, fmt.Errorf("%w: %q", ErrInvalidAlias, vs) } // AliasEnc is used to deserialize a Alias. @@ -705,6 +789,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Alias = ptr return nil @@ -714,6 +799,7 @@ type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -728,13 +814,13 @@ func (a *Aliases) UnmarshalJSON(b []byte) error { } // MarshalJSON marshals the Aliases to JSON. -func (a Aliases) MarshalJSON() ([]byte, error) { - if a == nil { +func (a *Aliases) MarshalJSON() ([]byte, error) { + if *a == nil { return []byte("[]"), nil } - aliases := make([]string, len(a)) - for i, alias := range a { + aliases := make([]string, len(*a)) + for i, alias := range *a { switch v := alias.(type) { case *Username: aliases[i] = string(*v) @@ -751,18 +837,20 @@ func (a Aliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } } return json.Marshal(aliases) } -func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error +func (a *Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ( + ips netipx.IPSetBuilder + errs []error + ) - for _, alias := range a { + for _, alias := range *a { aips, err := alias.Resolve(p, users, nodes) if err != nil { errs = append(errs, err) @@ -785,6 +873,7 @@ func unmarshalPointer[T any]( parseFunc func(string) (T, error), ) (T, error) { var s string + err := json.Unmarshal(b, &s) if err != nil { var t T @@ -796,7 +885,7 @@ func unmarshalPointer[T any]( type AutoApprover interface { CanBeAutoApprover() bool - UnmarshalJSON([]byte) error + UnmarshalJSON(b []byte) error String() string } @@ -804,6 +893,7 @@ type AutoApprovers []AutoApprover func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { var autoApprovers []AutoApproverEnc + err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...) if err != nil { return err @@ -833,7 +923,7 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) { case *Group: approvers[i] = string(*v) default: - return nil, fmt.Errorf("unknown auto approver type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAutoApprover, v) } } @@ -850,12 +940,7 @@ func parseAutoApprover(s string) (AutoApprover, error) { return ptr.To(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoApprover, s) } // AutoApproverEnc is used to deserialize a AutoApprover. @@ -869,6 +954,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.AutoApprover = ptr return nil @@ -876,7 +962,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { type Owner interface { CanBeTagOwner() bool - UnmarshalJSON([]byte) error + UnmarshalJSON(b []byte) error String() string } @@ -891,6 +977,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Owner = ptr return nil @@ -900,6 +987,7 @@ type Owners []Owner func (o *Owners) UnmarshalJSON(b []byte) error { var owners []OwnerEnc + err := json.Unmarshal(b, &owners, policyJSONOpts...) if err != nil { return err @@ -929,7 +1017,7 @@ func (o Owners) MarshalJSON() ([]byte, error) { case *Tag: owners[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -946,12 +1034,7 @@ func parseOwner(s string) (Owner, error) { return ptr.To(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidOwner, s) } type Usernames []Username @@ -959,18 +1042,18 @@ type Usernames []Username // Groups are a map of Group to a list of Username. type Groups map[Group]Usernames -func (g Groups) Contains(group *Group) error { +func (g *Groups) Contains(group *Group) error { if group == nil { return nil } - for defined := range map[Group]Usernames(g) { + for defined := range map[Group]Usernames(*g) { if defined == *group { return nil } } - return fmt.Errorf(`Group %q is not defined in the Policy, please define or remove the reference to it`, group) + return fmt.Errorf("%w: %q", ErrGroupNotDefined, group) } // UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure @@ -980,41 +1063,49 @@ func (g Groups) Contains(group *Group) error { func (g *Groups) UnmarshalJSON(b []byte) error { // First unmarshal as a generic map to validate group names first var rawMap map[string]any - if err := json.Unmarshal(b, &rawMap); err != nil { + + err := json.Unmarshal(b, &rawMap) + if err != nil { return err } // Validate group names first before checking data types for key := range rawMap { group := Group(key) - if err := group.Validate(); err != nil { + + err := group.Validate() + if err != nil { return err } } // Then validate each field can be converted to []string rawGroups := make(map[string][]string) + for key, value := range rawMap { switch v := value.(type) { case []any: // Convert []interface{} to []string var stringSlice []string + for _, item := range v { if str, ok := item.(string); ok { stringSlice = append(stringSlice, str) } else { - return fmt.Errorf(`group "%s" contains invalid member type, expected string but got %T`, key, item) + return fmt.Errorf("%w: group %q expected string but got %T", ErrInvalidGroupMember, key, item) } } + rawGroups[key] = stringSlice case string: - return fmt.Errorf(`group "%s" value must be an array of users, got string: "%s"`, key, v) + return fmt.Errorf("%w: group %q got string: %q", ErrGroupValueNotArray, key, v) default: - return fmt.Errorf(`group "%s" value must be an array of users, got %T`, key, v) + return fmt.Errorf("%w: group %q got %T", ErrGroupValueNotArray, key, v) } } *g = make(Groups) + for key, value := range rawGroups { group := Group(key) // Group name already validated above @@ -1022,13 +1113,16 @@ func (g *Groups) UnmarshalJSON(b []byte) error { for _, u := range value { username := Username(u) - if err := username.Validate(); err != nil { + + err := username.Validate() + if err != nil { if isGroup(u) { - return fmt.Errorf("nested groups are not allowed, found %q inside %q", u, group) + return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group) } return err } + usernames = append(usernames, username) } @@ -1043,20 +1137,27 @@ type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { var rawHosts map[string]string - if err := json.Unmarshal(b, &rawHosts, policyJSONOpts...); err != nil { + + err := json.Unmarshal(b, &rawHosts, policyJSONOpts...) + if err != nil { return err } *h = make(Hosts) + for key, value := range rawHosts { host := Host(key) - if err := host.Validate(); err != nil { + + err := host.Validate() + if err != nil { return err } var prefix Prefix - if err := prefix.parseString(value); err != nil { - return fmt.Errorf(`hostname "%s" contains an invalid IP address: "%s"`, key, value) + + err = prefix.parseString(value) + if err != nil { + return fmt.Errorf("%w: hostname %q address %q", ErrInvalidHostIP, key, value) } (*h)[host] = prefix @@ -1066,21 +1167,21 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { } // MarshalJSON marshals the Hosts to JSON. -func (h Hosts) MarshalJSON() ([]byte, error) { - if h == nil { +func (h *Hosts) MarshalJSON() ([]byte, error) { + if *h == nil { return []byte("{}"), nil } rawHosts := make(map[string]string) - for host, prefix := range h { + for host, prefix := range *h { rawHosts[string(host)] = prefix.String() } return json.Marshal(rawHosts) } -func (h Hosts) exist(name Host) bool { - _, ok := h[name] +func (h *Hosts) exist(name Host) bool { + _, ok := (*h)[name] return ok } @@ -1091,6 +1192,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { } rawTagOwners := make(map[string][]string) + for tag, owners := range to { tagStr := string(tag) ownerStrs := make([]string, len(owners)) @@ -1104,7 +1206,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { case *Tag: ownerStrs[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -1128,7 +1230,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error { } } - return fmt.Errorf(`tag %q is not defined in the policy, please define or remove the reference to it`, tagOwner) + return fmt.Errorf("%w: %q", ErrTagNotDefined, tagOwner) } type AutoApproverPolicy struct { @@ -1167,6 +1269,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if p == nil { return nil, nil, nil } + var err error routes := make(map[netip.Prefix]*netipx.IPSetBuilder) @@ -1175,11 +1278,12 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if _, ok := routes[prefix]; !ok { routes[prefix] = new(netipx.IPSetBuilder) } + for _, autoApprover := range autoApprovers { aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1188,12 +1292,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } var exitNodeSetBuilder netipx.IPSetBuilder + if len(p.AutoApprovers.ExitNode) > 0 { for _, autoApprover := range p.AutoApprovers.ExitNode { aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1202,11 +1307,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } ret := make(map[netip.Prefix]*netipx.IPSet) + for prefix, builder := range routes { ipSet, err := builder.IPSet() if err != nil { return nil, nil, err } + ret[prefix] = ipSet } @@ -1237,8 +1344,8 @@ const ( ) // String returns the string representation of the Action. -func (a Action) String() string { - return string(a) +func (a *Action) String() string { + return string(*a) } // UnmarshalJSON implements JSON unmarshaling for Action. @@ -1248,19 +1355,20 @@ func (a *Action) UnmarshalJSON(b []byte) error { case "accept": *a = ActionAccept default: - return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + return fmt.Errorf("%w: %q, must be %q", ErrInvalidACLAction, str, ActionAccept) } + return nil } // MarshalJSON implements JSON marshaling for Action. -func (a Action) MarshalJSON() ([]byte, error) { - return json.Marshal(string(a)) +func (a *Action) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*a)) } // String returns the string representation of the SSHAction. -func (a SSHAction) String() string { - return string(a) +func (a *SSHAction) String() string { + return string(*a) } // UnmarshalJSON implements JSON unmarshaling for SSHAction. @@ -1272,14 +1380,15 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { case "check": *a = SSHActionCheck default: - return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + return fmt.Errorf("%w: %q, must be one of: accept, check", ErrInvalidSSHAction, str) } + return nil } // MarshalJSON implements JSON marshaling for SSHAction. -func (a SSHAction) MarshalJSON() ([]byte, error) { - return json.Marshal(string(a)) +func (a *SSHAction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*a)) } // Protocol represents a network protocol with its IANA number and descriptions. @@ -1304,13 +1413,13 @@ const ( ) // String returns the string representation of the Protocol. -func (p Protocol) String() string { - return string(p) +func (p *Protocol) String() string { + return string(*p) } // Description returns the human-readable description of the Protocol. -func (p Protocol) Description() string { - switch p { +func (p *Protocol) Description() string { + switch *p { case ProtocolNameICMP: return "Internet Control Message Protocol" case ProtocolNameIGMP: @@ -1337,6 +1446,8 @@ func (p Protocol) Description() string { return "Stream Control Transmission Protocol" case ProtocolNameFC: return "Fibre Channel" + case ProtocolNameIPInIP: + return "IP-in-IP Encapsulation" case ProtocolNameWildcard: return "Wildcard (not supported - use specific protocol)" default: @@ -1344,51 +1455,49 @@ func (p Protocol) Description() string { } } -// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. -func (p Protocol) parseProtocol() ([]int, bool) { - switch p { +func (p *Protocol) parseProtocol() []int { + switch *p { case "": // Empty protocol applies to TCP, UDP, ICMP, and ICMPv6 traffic // This matches Tailscale's behavior for protocol defaults - return []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP}, false + return []int{ProtocolTCP, ProtocolUDP, ProtocolICMP, ProtocolIPv6ICMP} case ProtocolNameWildcard: // Wildcard protocol - defensive handling (should not reach here due to validation) - return nil, false + return nil case ProtocolNameIGMP: - return []int{ProtocolIGMP}, true + return []int{ProtocolIGMP} case ProtocolNameIPv4, ProtocolNameIPInIP: - return []int{ProtocolIPv4}, true + return []int{ProtocolIPv4} case ProtocolNameTCP: - return []int{ProtocolTCP}, false + return []int{ProtocolTCP} case ProtocolNameEGP: - return []int{ProtocolEGP}, true + return []int{ProtocolEGP} case ProtocolNameIGP: - return []int{ProtocolIGP}, true + return []int{ProtocolIGP} case ProtocolNameUDP: - return []int{ProtocolUDP}, false + return []int{ProtocolUDP} case ProtocolNameGRE: - return []int{ProtocolGRE}, true + return []int{ProtocolGRE} case ProtocolNameESP: - return []int{ProtocolESP}, true + return []int{ProtocolESP} case ProtocolNameAH: - return []int{ProtocolAH}, true + return []int{ProtocolAH} case ProtocolNameSCTP: - return []int{ProtocolSCTP}, false + return []int{ProtocolSCTP} case ProtocolNameICMP: // ICMP only - use "ipv6-icmp" or protocol number 58 for ICMPv6 - return []int{ProtocolICMP}, true + return []int{ProtocolICMP} + case ProtocolNameIPv6ICMP: + return []int{ProtocolIPv6ICMP} + case ProtocolNameFC: + return []int{ProtocolFC} default: // Try to parse as a numeric protocol number // This should not fail since validation happened during unmarshaling - protocolNumber, _ := strconv.Atoi(string(p)) - - // Determine if wildcard is needed based on protocol number - needsWildcard := protocolNumber != ProtocolTCP && - protocolNumber != ProtocolUDP && - protocolNumber != ProtocolSCTP - - return []int{protocolNumber}, needsWildcard + protocolNumber, _ := strconv.Atoi(string(*p)) + return []int{protocolNumber} } } @@ -1400,7 +1509,8 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { *p = Protocol(strings.ToLower(str)) // Validate the protocol - if err := p.validate(); err != nil { + err := p.validate() + if err != nil { return err } @@ -1408,31 +1518,31 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { } // validate checks if the Protocol is valid. -func (p Protocol) validate() error { - switch p { +func (p *Protocol) validate() error { + switch *p { case "", ProtocolNameICMP, ProtocolNameIGMP, ProtocolNameIPv4, ProtocolNameIPInIP, ProtocolNameTCP, ProtocolNameEGP, ProtocolNameIGP, ProtocolNameUDP, ProtocolNameGRE, - ProtocolNameESP, ProtocolNameAH, ProtocolNameSCTP: + ProtocolNameESP, ProtocolNameAH, ProtocolNameSCTP, ProtocolNameIPv6ICMP, ProtocolNameFC: return nil case ProtocolNameWildcard: // Wildcard "*" is not allowed - Tailscale rejects it - return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + return errUnknownProtocolWildcard default: // Try to parse as a numeric protocol number - str := string(p) + str := string(*p) // Check for leading zeros (not allowed by Tailscale) if str == "0" || (len(str) > 1 && str[0] == '0') { - return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + return fmt.Errorf("%w: %q", ErrProtocolLeadingZero, str) } protocolNumber, err := strconv.Atoi(str) if err != nil { - return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocolNumber, *p) } if protocolNumber < 0 || protocolNumber > 255 { - return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + return fmt.Errorf("%w: %d", ErrProtocolOutOfRange, protocolNumber) } return nil @@ -1440,11 +1550,11 @@ func (p Protocol) validate() error { } // MarshalJSON implements JSON marshaling for Protocol. -func (p Protocol) MarshalJSON() ([]byte, error) { - return json.Marshal(string(p)) +func (p *Protocol) MarshalJSON() ([]byte, error) { + return json.Marshal(string(*p)) } -// Protocol constants matching the IANA numbers +// Protocol constants matching the IANA numbers. const ( ProtocolICMP = 1 // Internet Control Message ProtocolIGMP = 2 // Internet Group Management @@ -1475,12 +1585,13 @@ type ACL struct { func (a *ACL) UnmarshalJSON(b []byte) error { // First unmarshal into a map to filter out comment fields var raw map[string]any - if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil { + if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil { //nolint:noinlineerr return err } // Remove any fields that start with '#' filtered := make(map[string]any) + for key, value := range raw { if !strings.HasPrefix(key, "#") { filtered[key] = value @@ -1495,15 +1606,17 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Create a type alias to avoid infinite recursion type aclAlias ACL + var temp aclAlias // Unmarshal into the temporary struct using the v2 JSON options - if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil { + if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil { //nolint:noinlineerr return err } // Copy the result back to the original struct *a = ACL(temp) + return nil } @@ -1539,6 +1652,8 @@ var ( autogroupForSSHDst = []AutoGroup{AutoGroupMember, AutoGroupTagged, AutoGroupSelf} autogroupForSSHUser = []AutoGroup{AutoGroupNonRoot} autogroupNotSupported = []AutoGroup{} + + errUnknownProtocolWildcard = errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") ) func validateAutogroupSupported(ag *AutoGroup) error { @@ -1547,7 +1662,7 @@ func validateAutogroupSupported(ag *AutoGroup) error { } if slices.Contains(autogroupNotSupported, *ag) { - return fmt.Errorf("autogroup %q is not supported in headscale", *ag) + return fmt.Errorf("%w: %q", ErrAutogroupNotSupported, *ag) } return nil @@ -1559,15 +1674,15 @@ func validateAutogroupForSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + return ErrAutogroupInternetSrc } if src.Is(AutoGroupSelf) { - return errors.New(`"autogroup:self" used in source, it can only be used in ACL destinations`) + return ErrAutogroupSelfSrc } if !slices.Contains(autogroupForSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedACLSrc, *src, autogroupForSrc) } return nil @@ -1579,7 +1694,7 @@ func validateAutogroupForDst(dst *AutoGroup) error { } if !slices.Contains(autogroupForDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedACLDst, *dst, autogroupForDst) } return nil @@ -1591,11 +1706,11 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + return ErrAutogroupInternetSrc } if !slices.Contains(autogroupForSSHSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSSHSrc, *src, autogroupForSSHSrc) } return nil @@ -1607,11 +1722,11 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error { } if dst.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + return ErrAutogroupInternetSrc } if !slices.Contains(autogroupForSSHDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *dst, autogroupForSSHDst) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSSHDst, *dst, autogroupForSSHDst) } return nil @@ -1623,7 +1738,7 @@ func validateAutogroupForSSHUser(user *AutoGroup) error { } if !slices.Contains(autogroupForSSHUser, *user) { - return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSSHUsr, *user, autogroupForSSHUser) } return nil @@ -1735,6 +1850,8 @@ func validateACLSrcDstCombination(sources Aliases, destinations []AliasWithPorts // the unmarshaling process. // It runs through all rules and checks if there are any inconsistencies // in the policy that needs to be addressed before it can be used. +// +//nolint:gocyclo // comprehensive policy validation func (p *Policy) validate() error { if p == nil { panic("passed nil policy") @@ -1750,67 +1867,72 @@ func (p *Policy) validate() error { case *Host: h := src if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`host %q is not defined in the policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q", ErrHostNotDefined, *h)) } case *AutoGroup: ag := src - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSrc(ag); err != nil { + err = validateAutogroupForSrc(ag) + if err != nil { errs = append(errs, err) continue } case *Group: g := src - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } } for _, dst := range acl.Destinations { - switch dst.Alias.(type) { + switch h := dst.Alias.(type) { case *Host: - h := dst.Alias.(*Host) if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`host %q is not defined in the policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q", ErrHostNotDefined, *h)) } case *AutoGroup: - ag := dst.Alias.(*AutoGroup) - - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(h) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForDst(ag); err != nil { + err = validateAutogroupForDst(h) + if err != nil { errs = append(errs, err) continue } case *Group: - g := dst.Alias.(*Group) - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(h) + if err != nil { errs = append(errs, err) } case *Tag: - tagOwner := dst.Alias.(*Tag) - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(h) + if err != nil { errs = append(errs, err) } } } // Validate protocol-port compatibility - if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil { + if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil { //nolint:noinlineerr errs = append(errs, err) } @@ -1825,7 +1947,9 @@ func (p *Policy) validate() error { for _, user := range ssh.Users { if strings.HasPrefix(string(user), "autogroup:") { maybeAuto := AutoGroup(user) - if err := validateAutogroupForSSHUser(&maybeAuto); err != nil { + + err := validateAutogroupForSSHUser(&maybeAuto) + if err != nil { errs = append(errs, err) continue } @@ -1837,43 +1961,55 @@ func (p *Policy) validate() error { case *AutoGroup: ag := src - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSSHSrc(ag); err != nil { + err = validateAutogroupForSSHSrc(ag) + if err != nil { errs = append(errs, err) continue } case *Group: g := src - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } } + for _, dst := range ssh.Destinations { switch dst := dst.(type) { case *AutoGroup: ag := dst - if err := validateAutogroupSupported(ag); err != nil { + + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSSHDst(ag); err != nil { + err = validateAutogroupForSSHDst(ag) + if err != nil { errs = append(errs, err) continue } case *Tag: tagOwner := dst - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1891,7 +2027,9 @@ func (p *Policy) validate() error { switch tagOwner := tagOwner.(type) { case *Group: g := tagOwner - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: @@ -1916,12 +2054,16 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1932,12 +2074,16 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1966,17 +2112,18 @@ type SSH struct { type SSHSrcAliases []Alias // MarshalJSON marshals the Groups to JSON. -func (g Groups) MarshalJSON() ([]byte, error) { - if g == nil { +func (g *Groups) MarshalJSON() ([]byte, error) { + if *g == nil { return []byte("{}"), nil } raw := make(map[string][]string) - for group, usernames := range g { + for group, usernames := range *g { users := make([]string, len(usernames)) for i, username := range usernames { users[i] = string(username) } + raw[string(group)] = users } @@ -1985,6 +2132,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -1996,10 +2144,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { case *Username, *Group, *Tag, *AutoGroup: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH source", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrSSHSourceAliasNotSupported, alias.Alias) } } @@ -2008,6 +2153,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -2023,10 +2169,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { "'autogroup:tagged' for tagged devices, or specific tags/users", ErrSSHWildcardDestination) default: - return fmt.Errorf( - "alias %T is not supported for SSH destination", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrSSHDestAliasNotSupported, alias.Alias) } } @@ -2055,7 +2198,7 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { // with a proper error message explaining alternatives aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH destination alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHDestAlias, v) } } @@ -2063,13 +2206,13 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { } // MarshalJSON marshals the SSHSrcAliases to JSON. -func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { - if a == nil { +func (a *SSHSrcAliases) MarshalJSON() ([]byte, error) { + if a == nil || *a == nil { return []byte("[]"), nil } - aliases := make([]string, len(a)) - for i, alias := range a { + aliases := make([]string, len(*a)) + for i, alias := range *a { switch v := alias.(type) { case *Username: aliases[i] = string(*v) @@ -2082,18 +2225,20 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH source alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHSrcAlias, v) } } return json.Marshal(aliases) } -func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error +func (a *SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ( + ips netipx.IPSetBuilder + errs []error + ) - for _, alias := range a { + for _, alias := range *a { aips, err := alias.Resolve(p, users, nodes) if err != nil { errs = append(errs, err) @@ -2141,27 +2286,31 @@ func (u SSHUser) MarshalJSON() ([]byte, error) { // This is the only entrypoint of reading a policy from a file or other source. func unmarshalPolicy(b []byte) (*Policy, error) { if len(b) == 0 { - return nil, nil + return nil, nil //nolint:nilnil // intentional: no policy when empty input } var policy Policy + ast, err := hujson.Parse(b) if err != nil { return nil, fmt.Errorf("parsing HuJSON: %w", err) } ast.Standardize() - if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { + + if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { //nolint:noinlineerr var serr *json.SemanticError - if errors.As(err, &serr) && serr.Err == json.ErrUnknownName { + if errors.As(err, &serr) && errors.Is(serr.Err, json.ErrUnknownName) { ptr := serr.JSONPointer name := ptr.LastToken() - return nil, fmt.Errorf("unknown field %q", name) + + return nil, fmt.Errorf("%w: %q", ErrUnknownField, name) } + return nil, fmt.Errorf("parsing policy from bytes: %w", err) } - if err := policy.validate(); err != nil { + if err := policy.validate(); err != nil { //nolint:noinlineerr return nil, err } @@ -2182,8 +2331,8 @@ func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWi for _, dst := range destinations { for _, portRange := range dst.Ports { // Check if it's not a wildcard port (0-65535) - if !(portRange.First == 0 && portRange.Last == 65535) { - return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) + if portRange.First != 0 || portRange.Last != 65535 { + return fmt.Errorf("%w: %q, only \"*\" is allowed", ErrProtocolNoSpecificPorts, protocol) } } } @@ -2204,6 +2353,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true @@ -2218,6 +2368,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range ssh.Destinations { if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 3830650f..16105ecb 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -82,6 +82,7 @@ func TestMarshalJSON(t *testing.T) { // Unmarshal back to verify round trip var roundTripped Policy + err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) @@ -366,7 +367,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: "alias v2.Asterix is not supported for SSH source", + wantErr: "alias not supported for SSH source: v2.Asterix", }, { name: "invalid-username", @@ -393,7 +394,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `group must start with "group:", got: "grou:example"`, + wantErr: `group must start with 'group:', got: "grou:example"`, }, { name: "group-in-group", @@ -408,7 +409,7 @@ func TestUnmarshalPolicy(t *testing.T) { } `, // wantErr: `username must contain @, got: "group:inner"`, - wantErr: `nested groups are not allowed, found "group:inner" inside "group:example"`, + wantErr: `nested groups are not allowed: found "group:inner" inside "group:example"`, }, { name: "invalid-addr", @@ -419,7 +420,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `hostname "derp" contains an invalid IP address: "10.0"`, + wantErr: `hostname contains invalid IP address: hostname "derp" address "10.0"`, }, { name: "invalid-prefix", @@ -430,7 +431,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `hostname "derp" contains an invalid IP address: "10.0/42"`, + wantErr: `hostname contains invalid IP address: hostname "derp" address "10.0/42"`, }, // TODO(kradalby): Figure out why this doesn't work. // { @@ -459,7 +460,7 @@ func TestUnmarshalPolicy(t *testing.T) { ], } `, - wantErr: `autogroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, + wantErr: `invalid autogroup: got "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, }, { name: "undefined-hostname-errors-2490", @@ -478,7 +479,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `host "user1" is not defined in the policy, please define or remove the reference to it`, + wantErr: `host not defined in policy: "user1"`, }, { name: "defined-hostname-does-not-err-2490", @@ -571,7 +572,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`, + wantErr: `autogroup:internet can only be used in ACL destinations`, }, { name: "autogroup:internet-in-ssh-src-not-allowed", @@ -590,7 +591,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`, + wantErr: `tag not defined in policy: "tag:test"`, }, { name: "autogroup:internet-in-ssh-dst-not-allowed", @@ -609,7 +610,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`, + wantErr: `autogroup:internet can only be used in ACL destinations`, }, { name: "ssh-basic", @@ -762,7 +763,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-dst", @@ -781,7 +782,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-ssh-src", @@ -800,7 +801,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `user destination requires source to contain only that same user "user@"`, }, { name: "group-must-be-defined-acl-tagOwner", @@ -811,7 +812,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-route", @@ -824,7 +825,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-exitnode", @@ -835,7 +836,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "tag-must-be-defined-acl-src", @@ -854,7 +855,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-dst", @@ -873,7 +874,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-src", @@ -892,7 +893,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-dst", @@ -914,7 +915,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-route", @@ -927,7 +928,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-exitnode", @@ -938,7 +939,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `tag "tag:notdefined" is not defined in the policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "missing-dst-port-is-err", @@ -957,7 +958,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `hostport must contain a colon (":")`, + wantErr: `hostport must contain a colon`, }, { name: "dst-port-zero-is-err", @@ -987,7 +988,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "rules"`, + wantErr: `unknown field: "rules"`, }, { name: "disallow-unsupported-fields-nested", @@ -1010,7 +1011,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `group must start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:', got: "INVALID_GROUP_FIELD"`, }, { name: "invalid-group-datatype", @@ -1022,7 +1023,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `group "group:invalid" value must be an array of users, got string: "should fail"`, + wantErr: `group value must be an array of users: group "group:invalid" got string: "should fail"`, }, { name: "invalid-group-name-and-datatype-fails-on-name-first", @@ -1034,7 +1035,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `group must start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:', got: "INVALID_GROUP_FIELD"`, }, { name: "disallow-unsupported-fields-hosts-level", @@ -1046,7 +1047,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`, + wantErr: `hostname contains invalid IP address: hostname "INVALID_HOST_FIELD" address "should fail"`, }, { name: "disallow-unsupported-fields-tagowners-level", @@ -1058,7 +1059,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`, + wantErr: `tag must start with 'tag:', got: "INVALID_TAG_FIELD"`, }, { name: "disallow-unsupported-fields-acls-level", @@ -1075,7 +1076,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_ACL_FIELD"`, + wantErr: `unknown field: "INVALID_ACL_FIELD"`, }, { name: "disallow-unsupported-fields-ssh-level", @@ -1092,7 +1093,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_SSH_FIELD"`, + wantErr: `unknown field: "INVALID_SSH_FIELD"`, }, { name: "disallow-unsupported-fields-policy-level", @@ -1109,7 +1110,7 @@ func TestUnmarshalPolicy(t *testing.T) { "INVALID_POLICY_FIELD": "should fail at policy level" } `, - wantErr: `unknown field "INVALID_POLICY_FIELD"`, + wantErr: `unknown field: "INVALID_POLICY_FIELD"`, }, { name: "disallow-unsupported-fields-autoapprovers-level", @@ -1124,7 +1125,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`, + wantErr: `unknown field: "INVALID_AUTO_APPROVER_FIELD"`, }, // headscale-admin uses # in some field names to add metadata, so we will ignore // those to ensure it doesnt break. @@ -1183,7 +1184,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "proto"`, + wantErr: `unknown field: "proto"`, }, { name: "protocol-wildcard-not-allowed", @@ -1279,7 +1280,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `leading 0 not permitted in protocol number "0"`, + wantErr: `leading 0 not permitted in protocol number: "0"`, }, { name: "protocol-empty-applies-to-tcp-udp-only", @@ -1326,7 +1327,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "icmp", only "*" is allowed`, }, { name: "protocol-icmp-with-wildcard-port-allowed", @@ -1374,7 +1375,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "gre", only "*" is allowed`, }, { name: "protocol-tcp-with-specific-port-allowed", @@ -2081,7 +2082,7 @@ func TestResolvePolicy(t *testing.T) { IPv4: ap("100.100.101.103"), }, }, - wantErr: `user with token "invaliduser@" not found`, + wantErr: `user not found: token "invaliduser@"`, }, { name: "invalid-tag", @@ -2105,7 +2106,7 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-member-comprehensive", - toResolve: ptr.To(AutoGroup(AutoGroupMember)), + toResolve: ptr.To(AutoGroupMember), nodes: types.Nodes{ // Node with no tags (should be included - is a member) { @@ -2155,7 +2156,7 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-tagged", - toResolve: ptr.To(AutoGroup(AutoGroupTagged)), + toResolve: ptr.To(AutoGroupTagged), nodes: types.Nodes{ // Node with no tags (should be excluded - not tagged) { @@ -2266,6 +2267,7 @@ func TestResolvePolicy(t *testing.T) { } var prefs []netip.Prefix + if ips != nil { if p := ips.Prefixes(); len(p) > 0 { prefs = p @@ -2437,9 +2439,11 @@ func TestResolveAutoApprovers(t *testing.T) { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff) } + if tt.wantAllIPRoutes != nil { if gotAllIPRoutes == nil { t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil") @@ -2586,6 +2590,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet { for _, p := range prefixes { builder.AddPrefix(mp(p)) } + ipSet, _ := builder.IPSet() return ipSet @@ -2595,6 +2600,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool { if x == nil || y == nil { return x == y } + return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...) } @@ -2823,6 +2829,7 @@ func TestResolveTagOwners(t *testing.T) { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff) } @@ -3098,6 +3105,7 @@ func TestNodeCanHaveTag(t *testing.T) { require.ErrorContains(t, err, tt.wantErr) return } + require.NoError(t, err) got := pm.NodeCanHaveTag(tt.node.View(), tt.tag) @@ -3358,6 +3366,7 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var acl ACL + err := json.Unmarshal([]byte(tt.input), &acl) if tt.wantErr { @@ -3368,8 +3377,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.expected.Action, acl.Action) assert.Equal(t, tt.expected.Protocol, acl.Protocol) - assert.Equal(t, len(tt.expected.Sources), len(acl.Sources)) - assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations)) + assert.Len(t, acl.Sources, len(tt.expected.Sources)) + assert.Len(t, acl.Destinations, len(tt.expected.Destinations)) // Compare sources for i, expectedSrc := range tt.expected.Sources { @@ -3409,14 +3418,15 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { // Unmarshal back var unmarshaled ACL + err = json.Unmarshal(jsonBytes, &unmarshaled) require.NoError(t, err) // Should be equal assert.Equal(t, original.Action, unmarshaled.Action) assert.Equal(t, original.Protocol, unmarshaled.Protocol) - assert.Equal(t, len(original.Sources), len(unmarshaled.Sources)) - assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations)) + assert.Len(t, unmarshaled.Sources, len(original.Sources)) + assert.Len(t, unmarshaled.Destinations, len(original.Destinations)) } func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { @@ -3484,15 +3494,16 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { _, err := unmarshalPolicy([]byte(policyJSON)) require.Error(t, err) - assert.Contains(t, err.Error(), `invalid action "deny"`) + assert.Contains(t, err.Error(), `invalid ACL action: "deny"`) } -// Helper function to parse aliases for testing +// Helper function to parse aliases for testing. func mustParseAlias(s string) Alias { alias, err := parseAlias(s) if err != nil { panic(err) } + return alias } diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index a4367775..ddf41f8e 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -9,6 +9,18 @@ import ( "tailscale.com/tailcfg" ) +// Port parsing errors. +var ( + ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port") + ErrInputStartsWithColon = errors.New("input cannot start with a colon character") + ErrInputEndsWithColon = errors.New("input cannot end with a colon character") + ErrInvalidPortRangeFormat = errors.New("invalid port range format") + ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port") + ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard") + ErrInvalidPortNumber = errors.New("invalid port number") + ErrPortNumberOutOfRange = errors.New("port number out of range") +) + // splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid. func splitDestinationAndPort(input string) (string, string, error) { // Find the last occurrence of the colon character @@ -16,13 +28,15 @@ func splitDestinationAndPort(input string) (string, string, error) { // Check if the colon character is present and not at the beginning or end of the string if lastColonIndex == -1 { - return "", "", errors.New("input must contain a colon character separating destination and port") + return "", "", ErrInputMissingColon } + if lastColonIndex == 0 { - return "", "", errors.New("input cannot start with a colon character") + return "", "", ErrInputStartsWithColon } + if lastColonIndex == len(input)-1 { - return "", "", errors.New("input cannot end with a colon character") + return "", "", ErrInputEndsWithColon } // Split the string into destination and port based on the last colon @@ -45,11 +59,12 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { for part := range parts { if strings.Contains(part, "-") { rangeParts := strings.Split(part, "-") + rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { return e == "" }) if len(rangeParts) != 2 { - return nil, errors.New("invalid port range format") + return nil, ErrInvalidPortRangeFormat } first, err := parsePort(rangeParts[0]) @@ -63,7 +78,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if first > last { - return nil, errors.New("invalid port range: first port is greater than last port") + return nil, ErrPortRangeInverted } portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last}) @@ -74,7 +89,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if port < 1 { - return nil, errors.New("first port must be >0, or use '*' for wildcard") + return nil, ErrPortMustBePositive } portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port}) @@ -88,11 +103,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { func parsePort(portStr string) (uint16, error) { port, err := strconv.Atoi(portStr) if err != nil { - return 0, errors.New("invalid port number") + return 0, ErrInvalidPortNumber } if port < 0 || port > 65535 { - return 0, errors.New("port number out of range") + return 0, ErrPortNumberOutOfRange } return uint16(port), nil diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index 2084b22f..496f4618 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -1,7 +1,6 @@ package v2 import ( - "errors" "testing" "github.com/google/go-cmp/cmp" @@ -24,9 +23,9 @@ func TestParseDestinationAndPort(t *testing.T) { {"tag:api-server:443", "tag:api-server", "443", nil}, {"example-host-1:*", "example-host-1", "*", nil}, {"hostname:80-90", "hostname", "80-90", nil}, - {"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")}, - {":invalid", "", "", errors.New("input cannot start with a colon character")}, - {"invalid:", "", "", errors.New("input cannot end with a colon character")}, + {"invalidinput", "", "", ErrInputMissingColon}, + {":invalid", "", "", ErrInputStartsWithColon}, + {"invalid:", "", "", ErrInputEndsWithColon}, } for _, testCase := range testCases { @@ -58,9 +57,11 @@ func TestParsePort(t *testing.T) { if err != nil && err.Error() != test.err { t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) } + if err == nil && test.err != "" { t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) } + if result != test.expected { t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) } @@ -92,9 +93,11 @@ func TestParsePortRange(t *testing.T) { if err != nil && err.Error() != test.err { t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) } + if err == nil && test.err != "" { t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err) } + if diff := cmp.Diff(result, test.expected); diff != "" { t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 8d729df5..ded86068 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -30,7 +30,7 @@ const nodeNameContextKey = contextKey("nodeName") type mapSession struct { h *Headscale req tailcfg.MapRequest - ctx context.Context + ctx context.Context //nolint:containedctx capVer tailcfg.CapabilityVersion cancelChMu deadlock.Mutex @@ -54,7 +54,7 @@ func (h *Headscale) newMapSession( w http.ResponseWriter, node *types.Node, ) *mapSession { - ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) + ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) //nolint:gosec // weak random is fine for jitter return &mapSession{ h: h, @@ -162,6 +162,7 @@ func (m *mapSession) serveLongPoll() { // This is not my favourite solution, but it kind of works in our eventually consistent world. ticker := time.NewTicker(time.Second) defer ticker.Stop() + disconnected := true // Wait up to 10 seconds for the node to reconnect. // 10 seconds was arbitrary chosen as a reasonable time to reconnect. @@ -170,6 +171,7 @@ func (m *mapSession) serveLongPoll() { disconnected = false break } + <-ticker.C } @@ -222,7 +224,7 @@ func (m *mapSession) serveLongPoll() { // adding this before connecting it to the state ensure that // it does not miss any updates that might be sent in the split // time between the node connecting and the batcher being ready. - if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { + if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { //nolint:noinlineerr m.log.Error().Caller().Err(err).Msg("failed to add node to batcher") return } @@ -240,22 +242,26 @@ func (m *mapSession) serveLongPoll() { case <-m.cancelCh: m.log.Trace().Caller().Msg("poll cancelled received") mapResponseEnded.WithLabelValues("cancelled").Inc() + return case <-ctx.Done(): m.log.Trace().Caller().Str(zf.Chan, fmt.Sprintf("%p", m.ch)).Msg("poll context done") mapResponseEnded.WithLabelValues("done").Inc() + return // Consume updates sent to node case update, ok := <-m.ch: m.log.Trace().Caller().Bool(zf.OK, ok).Msg("received update from channel") + if !ok { m.log.Trace().Caller().Msg("update channel closed, streaming session is likely being replaced") return } - if err := m.writeMap(update); err != nil { + err := m.writeMap(update) + if err != nil { m.log.Error().Caller().Err(err).Msg("cannot write update to client") return } @@ -264,7 +270,8 @@ func (m *mapSession) serveLongPoll() { m.resetKeepAlive() case <-m.keepAliveTicker.C: - if err := m.writeMap(&keepAlive); err != nil { + err := m.writeMap(&keepAlive) + if err != nil { m.log.Error().Caller().Err(err).Msg("cannot write keep alive") return } @@ -272,6 +279,7 @@ func (m *mapSession) serveLongPoll() { if debugHighCardinalityMetrics { mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) } + mapResponseSent.WithLabelValues("ok", "keepalive").Inc() m.resetKeepAlive() } @@ -292,7 +300,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression) } - data := make([]byte, reservedResponseHeaderSize) + data := make([]byte, reservedResponseHeaderSize, reservedResponseHeaderSize+len(jsonBody)) //nolint:gosec // G115: JSON response size will not exceed uint32 max binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) data = append(data, jsonBody...) diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 52b1d75c..3a1db3dd 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -109,9 +109,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Msg("current primary no longer available") } } + if len(nodes) >= 1 { pr.primaries[prefix] = nodes[0] changed = true + log.Debug(). Caller(). Str(zf.Prefix, prefix.String()). @@ -128,6 +130,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Str(zf.Prefix, prefix.String()). Msg("cleaning up primary route that no longer has available nodes") delete(pr.primaries, prefix) + changed = true } } @@ -164,14 +167,17 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) // If no routes are being set, remove the node from the routes map. if len(prefixes) == 0 { wasPresent := false + if _, ok := pr.routes[node]; ok { delete(pr.routes, node) + wasPresent = true nlog.Debug(). Caller(). Msg("removed node from primary routes (no prefixes)") } + changed := pr.updatePrimaryLocked() nlog.Debug(). Caller(). @@ -253,12 +259,14 @@ func (pr *PrimaryRoutes) stringLocked() string { ids := types.NodeIDs(xmaps.Keys(pr.routes)) sort.Sort(ids) + for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) } fmt.Fprintln(&sb, "\n\nCurrent primary routes:") + for route, nodeID := range pr.primaries { fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID) } diff --git a/hscontrol/routes/primary_test.go b/hscontrol/routes/primary_test.go index 7a9767b2..b03c8f81 100644 --- a/hscontrol/routes/primary_test.go +++ b/hscontrol/routes/primary_test.go @@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) pr.SetRoutes(2, mp("192.168.2.0/24")) pr.SetRoutes(1) // Deregister by setting no routes + return pr.SetRoutes(1, mp("192.168.3.0/24")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) { { name: "multiple-nodes-register-same-route", operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - pr.SetRoutes(2, mp("192.168.1.0/24")) // true + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true + return pr.SetRoutes(3, mp("192.168.1.0/24")) // false }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) // false pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary - return pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(1) // true, 2 primary }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ 2: { @@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) + return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { var wg sync.WaitGroup wg.Add(2) + var change1, change2 bool + go func() { defer wg.Done() + change1 = pr.SetRoutes(1, mp("192.168.1.0/24")) }() go func() { defer wg.Done() + change2 = pr.SetRoutes(2, mp("192.168.2.0/24")) }() + wg.Wait() return change1 || change2 @@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pr := New() + change := tt.operations(pr) if change != tt.expectedChange { t.Errorf("change = %v, want %v", change, tt.expectedChange) } + comps := append(util.Comparers, cmpopts.EquateEmpty()) if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { t.Errorf("routes mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { t.Errorf("primaries mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 3ed1d79f..abb34eb0 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -77,6 +77,7 @@ func (s *State) DebugOverview() string { ephemeralCount := 0 now := time.Now() + for _, node := range allNodes.All() { if node.Valid() { userName := node.Owner().Name() @@ -103,17 +104,21 @@ func (s *State) DebugOverview() string { // User statistics sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users))) + for userName, nodeCount := range userNodeCounts { sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount)) } + sb.WriteString("\n") // Policy information sb.WriteString("Policy:\n") sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode)) + if s.cfg.Policy.Mode == types.PolicyModeFile { sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path)) } + sb.WriteString("\n") // DERP information @@ -123,6 +128,7 @@ func (s *State) DebugOverview() string { } else { sb.WriteString("DERP: not configured\n") } + sb.WriteString("\n") // Route information @@ -130,6 +136,7 @@ func (s *State) DebugOverview() string { if s.primaryRoutes.String() == "" { routeCount = 0 } + sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount)) sb.WriteString("\n") @@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string { for _, node := range region.Nodes { sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n", node.Name, node.HostName, node.DERPPort)) + if node.STUNPort != 0 { sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort)) } } + sb.WriteString("\n") } @@ -236,7 +245,7 @@ func (s *State) DebugPolicy() (string, error) { return string(pol), nil default: - return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode) + return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, s.cfg.Policy.Mode) } } @@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { if s.primaryRoutes.String() == "" { routeCount = 0 } + info.PrimaryRoutes = routeCount return info diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go index 632af13c..65e7738c 100644 --- a/hscontrol/state/ephemeral_test.go +++ b/hscontrol/state/ephemeral_test.go @@ -21,6 +21,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // Create NodeStore store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -44,20 +45,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // 6. If DELETE came after UPDATE, the returned node should be invalid done := make(chan bool, 2) - var updatedNode types.NodeView - var updateOk bool + + var ( + updatedNode types.NodeView + updateOk bool + ) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { n.LastSeen = ptr.To(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(node.ID) + done <- true }() @@ -91,6 +98,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -148,6 +156,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { node := createTestNode(3, 1, "test-user", "test-node-3") store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -204,6 +213,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -214,8 +224,11 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 1. UpdateNode (from UpdateNodeFromMapRequest during polling) // 2. DeleteNode (from handleLogout when client sends logout request) - var updatedNode types.NodeView - var updateOk bool + var ( + updatedNode types.NodeView + updateOk bool + ) + done := make(chan bool, 2) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) @@ -223,12 +236,14 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { n.LastSeen = ptr.To(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(ephemeralNode.ID) + done <- true }() @@ -267,7 +282,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 5. UpdateNode and DeleteNode batch together // 6. UpdateNode returns a valid node (from before delete in batch) // 7. persistNodeToDB is called with the stale valid node -// 8. Node gets re-inserted into database instead of staying deleted +// 8. Node gets re-inserted into database instead of staying deleted. func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -279,6 +294,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -349,6 +365,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -399,7 +416,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // 3. UpdateNode and DeleteNode batch together // 4. UpdateNode returns a valid node (from before delete in batch) // 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node -// 6. persistNodeToDB must detect the node is deleted and refuse to persist +// 6. persistNodeToDB must detect the node is deleted and refuse to persist. func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -409,6 +426,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { } store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go index e7dfc11c..d8cddaa1 100644 --- a/hscontrol/state/maprequest.go +++ b/hscontrol/state/maprequest.go @@ -29,6 +29,7 @@ func netInfoFromMapRequest( Uint64("node.id", nodeID.Uint64()). Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP). Msg("using NetInfo from previous Hostinfo in MapRequest") + return currentHostinfo.NetInfo } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 99f781d4..8a842e49 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -1,15 +1,12 @@ package state import ( - "net/netip" "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -136,26 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node") }) } - -// Simple helper function for tests -func createTestNodeSimple(id types.NodeID) *types.Node { - user := types.User{ - Name: "test-user", - } - - machineKey := key.NewMachine() - nodeKey := key.NewNode() - - node := &types.Node{ - ID: id, - Hostname: "test-node", - UserID: ptr.To(uint(id)), - User: &user, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPv4: &netip.Addr{}, - IPv6: &netip.Addr{}, - } - - return node -} diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 6327b46b..1c921d6d 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -55,8 +55,8 @@ var ( }) nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: prometheusNamespace, - Name: "nodestore_nodes_total", - Help: "Total number of nodes in the NodeStore", + Name: "nodestore_nodes", + Help: "Number of nodes in the NodeStore", }) nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ Namespace: prometheusNamespace, @@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc for _, n := range allNodes { nodes[n.ID] = *n } + snap := snapshotFromNodes(nodes, peersFunc) store := &NodeStore{ @@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("put").Inc() return resultNode @@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("update").Inc() // Return the node and whether it exists (is valid) @@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() @@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() { if len(batch) != 0 { s.applyBatch(batch) } + return } + batch = append(batch, w) if len(batch) >= s.batchSize { s.applyBatch(batch) @@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) { w.updateFn(&n) nodes[w.nodeID] = n } + if w.nodeResult != nil { nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) } @@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) { nodeView := node.View() for _, w := range workItems { w.nodeResult <- nodeView + close(w.nodeResult) } } else { // Node was deleted or doesn't exist for _, w := range workItems { w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) } } @@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S peersByNode: func() map[types.NodeID][]types.NodeView { peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) defer peersTimer.ObserveDuration() + return peersFunc(allNodes) }(), nodesByUser: make(map[types.UserID][]types.NodeView), @@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S if newSnap.nodesByMachineKey[n.MachineKey] == nil { newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView) } + newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView } @@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string { // User distribution (shows internal UserID tracking, not display owner) sb.WriteString("Nodes by Internal User ID:\n") + for userID, nodes := range snapshot.nodesByUser { if len(nodes) > 0 { userName := "unknown" taggedCount := 0 + if len(nodes) > 0 && nodes[0].Valid() { userName = nodes[0].User().Name() // Count tagged nodes (which have UserID set but are owned by "tagged-devices") @@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string { } } } + sb.WriteString("\n") // Peer relationships summary sb.WriteString("Peer Relationships:\n") + totalPeers := 0 + for nodeID, peers := range snapshot.peersByNode { peerCount := len(peers) + totalPeers += peerCount if node, exists := snapshot.nodesByID[nodeID]; exists { sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n", nodeID, node.Hostname, peerCount)) } } + if len(snapshot.peersByNode) > 0 { avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode)) sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers)) } + sb.WriteString("\n") // Node key index @@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() { } s.writeQueue <- w + <-result } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 3d6184ba..736c3cfa 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -32,7 +32,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, peersFunc }, - validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper assert.Empty(t, snapshot.nodesByID) assert.Empty(t, snapshot.allNodes) assert.Empty(t, snapshot.peersByNode) @@ -45,9 +45,10 @@ func TestSnapshotFromNodes(t *testing.T) { nodes := map[types.NodeID]types.Node{ 1: createTestNode(1, 1, "user1", "node1"), } + return nodes, allowAllPeersFunc }, - validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper assert.Len(t, snapshot.nodesByID, 1) assert.Len(t, snapshot.allNodes, 1) assert.Len(t, snapshot.peersByNode, 1) @@ -70,7 +71,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, - validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper assert.Len(t, snapshot.nodesByID, 2) assert.Len(t, snapshot.allNodes, 2) assert.Len(t, snapshot.peersByNode, 2) @@ -95,7 +96,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, - validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper assert.Len(t, snapshot.nodesByID, 3) assert.Len(t, snapshot.allNodes, 3) assert.Len(t, snapshot.peersByNode, 3) @@ -124,7 +125,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, peersFunc }, - validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { //nolint:thelper assert.Len(t, snapshot.nodesByID, 4) assert.Len(t, snapshot.allNodes, 4) assert.Len(t, snapshot.peersByNode, 4) @@ -193,11 +194,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + for _, n := range nodes { if n.ID() != node.ID() { peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -208,6 +211,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 for _, n := range nodes { @@ -222,6 +226,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -236,7 +241,7 @@ func TestNodeStoreOperations(t *testing.T) { }{ { name: "create empty store and add single node", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) }, steps: []testStep{ @@ -274,7 +279,7 @@ func TestNodeStoreOperations(t *testing.T) { }, { name: "create store with initial node and add more", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper node1 := createTestNode(1, 1, "user1", "node1") initialNodes := types.Nodes{&node1} @@ -342,7 +347,7 @@ func TestNodeStoreOperations(t *testing.T) { }, { name: "test node deletion", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper node1 := createTestNode(1, 1, "user1", "node1") node2 := createTestNode(2, 1, "user1", "node2") node3 := createTestNode(3, 2, "user2", "node3") @@ -403,7 +408,7 @@ func TestNodeStoreOperations(t *testing.T) { }, { name: "test node updates", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper node1 := createTestNode(1, 1, "user1", "node1") node2 := createTestNode(2, 1, "user1", "node2") initialNodes := types.Nodes{&node1, &node2} @@ -445,7 +450,7 @@ func TestNodeStoreOperations(t *testing.T) { }, { name: "test with odd-even peers filtering", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper return NewNodeStore(nil, oddEvenPeersFunc, TestBatchSize, TestBatchTimeout) }, steps: []testStep{ @@ -455,10 +460,13 @@ func TestNodeStoreOperations(t *testing.T) { // Add nodes in sequence n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) assert.True(t, n4.Valid()) @@ -501,7 +509,7 @@ func TestNodeStoreOperations(t *testing.T) { }, { name: "test batch modifications return correct node state", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper node1 := createTestNode(1, 1, "user1", "node1") node2 := createTestNode(2, 1, "user1", "node2") initialNodes := types.Nodes{&node1, &node2} @@ -526,16 +534,20 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2 types.NodeView - var newNode3 types.NodeView - var ok1, ok2 bool + var ( + resultNode1, resultNode2 types.NodeView + newNode3 types.NodeView + ok1, ok2 bool + ) // These should all be processed in the same batch + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "batch-updated-node1" n.GivenName = "batch-given-1" }) + close(done1) }() @@ -544,12 +556,14 @@ func TestNodeStoreOperations(t *testing.T) { n.Hostname = "batch-updated-node2" n.GivenName = "batch-given-2" }) + close(done2) }() go func() { node3 := createTestNode(3, 1, "user1", "node3") newNode3 = store.PutNode(node3) + close(done3) }() @@ -602,20 +616,23 @@ func TestNodeStoreOperations(t *testing.T) { // This test verifies that when multiple updates to the same node // are batched together, each returned node reflects ALL changes // in the batch, not just the individual update's changes. - done1 := make(chan struct{}) done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2, resultNode3 types.NodeView - var ok1, ok2, ok3 bool + var ( + resultNode1, resultNode2, resultNode3 types.NodeView + ok1, ok2, ok3 bool + ) // These updates all modify node 1 and should be batched together // The final state should have all three modifications applied + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "multi-update-hostname" }) + close(done1) }() @@ -623,6 +640,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "multi-update-givenname" }) + close(done2) }() @@ -630,6 +648,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"tag1", "tag2"} }) + close(done3) }() @@ -673,7 +692,7 @@ func TestNodeStoreOperations(t *testing.T) { }, { name: "test UpdateNode result is immutable for database save", - setupFunc: func(t *testing.T) *NodeStore { + setupFunc: func(t *testing.T) *NodeStore { //nolint:thelper node1 := createTestNode(1, 1, "user1", "node1") node2 := createTestNode(2, 1, "user1", "node2") initialNodes := types.Nodes{&node1, &node2} @@ -723,14 +742,18 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var result1, result2, result3 types.NodeView - var ok1, ok2, ok3 bool + var ( + result1, result2, result3 types.NodeView + ok1, ok2, ok3 bool + ) // Start concurrent updates + go func() { result1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "concurrent-db-hostname" }) + close(done1) }() @@ -738,6 +761,7 @@ func TestNodeStoreOperations(t *testing.T) { result2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "concurrent-db-given" }) + close(done2) }() @@ -745,6 +769,7 @@ func TestNodeStoreOperations(t *testing.T) { result3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"concurrent-tag"} }) + close(done3) }() @@ -828,6 +853,7 @@ func TestNodeStoreOperations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := tt.setupFunc(t) + store.Start() defer store.Stop() @@ -847,10 +873,11 @@ type testStep struct { // --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests --- -// Helper for concurrent test nodes +// Helper for concurrent test nodes. func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { machineKey := key.NewMachine() nodeKey := key.NewNode() + return types.Node{ ID: id, Hostname: hostname, @@ -863,72 +890,88 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { } } -// --- Concurrency: concurrent PutNode operations --- +// --- Concurrency: concurrent PutNode operations ---. func TestNodeStoreConcurrentPutNode(t *testing.T) { const concurrentOps = 20 store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, concurrentOps) for i := range concurrentOps { wg.Add(1) + go func(nodeID int) { defer wg.Done() - node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") + + node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") //nolint:gosec // safe conversion in test + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed") } -// --- Batching: concurrent ops fit in one batch --- +// --- Batching: concurrent ops fit in one batch ---. func TestNodeStoreBatchingEfficiency(t *testing.T) { - const batchSize = 10 const ops = 15 // more than batchSize store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, ops) for i := range ops { wg.Add(1) + go func(nodeID int) { defer wg.Done() - node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") + + node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") //nolint:gosec // test code with small integers + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, ops, successCount, "All batch PutNode operations should succeed") } -// --- Race conditions: many goroutines on same node --- +// --- Race conditions: many goroutines on same node ---. func TestNodeStoreRaceConditions(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -937,13 +980,18 @@ func TestNodeStoreRaceConditions(t *testing.T) { resultNode := store.PutNode(node) require.True(t, resultNode.Valid()) - const numGoroutines = 30 - const opsPerGoroutine = 10 + const ( + numGoroutines = 30 + opsPerGoroutine = 10 + ) + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*opsPerGoroutine) for i := range numGoroutines { wg.Add(1) + go func(gid int) { defer wg.Done() @@ -954,40 +1002,46 @@ func TestNodeStoreRaceConditions(t *testing.T) { n.Hostname = "race-updated" }) if !resultNode.Valid() { - errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) //nolint:err113 } case 1: retrieved, found := store.GetNode(nodeID) if !found || !retrieved.Valid() { - errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) //nolint:err113 } case 2: newNode := createConcurrentTestNode(nodeID, "race-put") + resultNode := store.PutNode(newNode) if !resultNode.Valid() { - errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) //nolint:err113 } } } }(i) } + wg.Wait() close(errors) errorCount := 0 + for err := range errors { t.Error(err) + errorCount++ } + if errorCount > 0 { t.Fatalf("Race condition test failed with %d errors", errorCount) } } -// --- Resource cleanup: goroutine leak detection --- +// --- Resource cleanup: goroutine leak detection ---. func TestNodeStoreResourceCleanup(t *testing.T) { // initialGoroutines := runtime.NumGoroutine() store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1001,7 +1055,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) { const ops = 100 for i := range ops { - nodeID := types.NodeID(i + 1) + nodeID := types.NodeID(i + 1) //nolint:gosec // test code with small integers node := createConcurrentTestNode(nodeID, "cleanup-node") resultNode := store.PutNode(node) assert.True(t, resultNode.Valid()) @@ -1010,10 +1064,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }) retrieved, found := store.GetNode(nodeID) assert.True(t, found && retrieved.Valid()) + if i%10 == 9 { store.DeleteNode(nodeID) } } + runtime.GC() // Wait for goroutines to settle and check for leaks @@ -1024,9 +1080,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }, time.Second, 10*time.Millisecond, "goroutines should not leak") } -// --- Timeout/deadlock: operations complete within reasonable time --- +// --- Timeout/deadlock: operations complete within reasonable time ---. func TestNodeStoreOperationTimeout(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1034,36 +1091,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) { defer cancel() const ops = 30 + var wg sync.WaitGroup + putResults := make([]error, ops) updateResults := make([]error, ops) // Launch all PutNode operations concurrently for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec // test code with small integers + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id) node := createConcurrentTestNode(id, "timeout-node") resultNode := store.PutNode(node) endPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) + if !resultNode.Valid() { - putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) + putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) //nolint:err113 } }(i, nodeID) } + wg.Wait() // Launch all UpdateNode operations concurrently wg = sync.WaitGroup{} + for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec // test code with small integers + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id) resultNode, ok := store.UpdateNode(id, func(n *types.Node) { @@ -1071,31 +1139,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) { }) endUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) + if !ok || !resultNode.Valid() { - updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) + updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) //nolint:err113 } }(i, nodeID) } + done := make(chan struct{}) + go func() { wg.Wait() close(done) }() + select { case <-done: errorCount := 0 + for _, err := range putResults { if err != nil { t.Error(err) + errorCount++ } } + for _, err := range updateResults { if err != nil { t.Error(err) + errorCount++ } } + if errorCount == 0 { t.Log("All concurrent operations completed successfully within timeout") } else { @@ -1107,13 +1184,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) { } } -// --- Edge case: update non-existent node --- +// --- Edge case: update non-existent node ---. func TestNodeStoreUpdateNonExistentNode(t *testing.T) { for i := range 10 { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) store.Start() - nonExistentID := types.NodeID(999 + i) + + nonExistentID := types.NodeID(999 + i) //nolint:gosec // test code with small integers updateCallCount := 0 + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) { updateCallCount++ @@ -1127,20 +1206,22 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) { } } -// --- Allocation benchmark --- +// --- Allocation benchmark ---. func BenchmarkNodeStoreAllocations(b *testing.B) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() for i := 0; b.Loop(); i++ { - nodeID := types.NodeID(i + 1) + nodeID := types.NodeID(i + 1) //nolint:gosec // benchmark code with small integers node := createConcurrentTestNode(nodeID, "bench-node") store.PutNode(node) store.UpdateNode(nodeID, func(n *types.Node) { n.Hostname = "bench-updated" }) store.GetNode(nodeID) + if i%10 == 9 { store.DeleteNode(nodeID) } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index a3827599..16ea06d3 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -230,6 +230,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // propagate correctly when switching between policy types. s.nodeStore.RebuildPeerMaps() + //nolint:prealloc // cs starts with one element and may grow cs := []change.Change{change.PolicyChange()} // Always call autoApproveNodes during policy reload, regardless of whether @@ -260,7 +261,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // CreateUser creates a new user and updates the policy manager. // Returns the created user, change set, and any error. func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) { - if err := s.db.DB.Save(&user).Error; err != nil { + if err := s.db.DB.Save(&user).Error; err != nil { //nolint:noinlineerr return nil, change.Change{}, fmt.Errorf("creating user: %w", err) } @@ -294,7 +295,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return nil, err } - if err := updateFn(user); err != nil { + if err := updateFn(user); err != nil { //nolint:noinlineerr return nil, err } @@ -512,7 +513,7 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { }) if !ok { - return nil, fmt.Errorf("node not found: %d", id) + return nil, fmt.Errorf("%w: %d", ErrNodeNotFound, id) } log.Info().EmbedObject(node).Msg("node disconnected") @@ -765,7 +766,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, // Check name uniqueness against NodeStore allNodes := s.nodeStore.ListNodes() - for i := 0; i < allNodes.Len(); i++ { + for i := range allNodes.Len() { node := allNodes.At(i) if node.ID() != nodeID && node.AsStruct().GivenName == newName { return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName) @@ -832,7 +833,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha var updates []change.Change - for _, node := range s.nodeStore.ListNodes().All() { + for _, node := range s.nodeStore.ListNodes().All() { //nolint:unqueryvet // NodeStore.ListNodes not a SQL query if !node.Valid() { continue } @@ -1850,7 +1851,7 @@ func (s *State) HandleNodeFromPreAuthKey( } } - return nil, nil + return nil, nil //nolint:nilnil // intentional: transaction success }) if err != nil { return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err) diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 1a949173..d6ef380e 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -13,6 +13,9 @@ import ( "tailscale.com/types/logger" ) +// ErrNoCertDomains is returned when no cert domains are available for HTTPS. +var ErrNoCertDomains = errors.New("no cert domains available for HTTPS") + func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error { opts := tailsql.Options{ Hostname: "tailsql-headscale", @@ -41,15 +44,17 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s defer tsNode.Close() logf("Starting tailscale (hostname=%q)", opts.Hostname) + lc, err := tsNode.LocalClient() if err != nil { return fmt.Errorf("connect local client: %w", err) } + opts.LocalClient = lc // for authentication // Make sure the Tailscale node starts up. It might not, if it is a new node // and the user did not provide an auth key. - if st, err := tsNode.Up(ctx); err != nil { + if st, err := tsNode.Up(ctx); err != nil { //nolint:noinlineerr return fmt.Errorf("starting tailscale: %w", err) } else { logf("tailscale started, node state %q", st.BackendState) @@ -71,28 +76,38 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. certDomains := tsNode.CertDomains() if len(certDomains) == 0 { - return errors.New("no cert domains available for HTTPS") + return ErrNoCertDomains } + base := "https://" + certDomains[0] - go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - target := base + r.RequestURI - http.Redirect(w, r, target, http.StatusPermanentRedirect) - })) + + go func() { + _ = http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { //nolint:gosec + target := base + r.RequestURI + http.Redirect(w, r, target, http.StatusPermanentRedirect) + })) + }() // log.Printf("Redirecting HTTP to HTTPS at %q", base) // For the real service, start a separate listener. // Note: Replaces the port 80 listener. var err error + lst, err = tsNode.ListenTLS("tcp", ":443") if err != nil { return fmt.Errorf("listen TLS: %w", err) } + logf("enabled serving via HTTPS") } mux := tsql.NewMux() tsweb.Debugger(mux) - go http.Serve(lst, mux) + + go func() { + _ = http.Serve(lst, mux) //nolint:gosec + }() + logf("TailSQL started") <-ctx.Done() logf("TailSQL shutting down...") diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index f4814519..d852753e 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -20,7 +20,11 @@ const ( DatabaseSqlite = "sqlite3" ) -var ErrCannotParsePrefix = errors.New("cannot parse prefix") +// Common errors. +var ( + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidRegistrationIDLength = errors.New("registration ID has invalid length") +) type StateUpdateType int @@ -100,6 +104,10 @@ func (su *StateUpdate) Empty() bool { return len(su.ChangePatches) == 0 case StatePeerRemoved: return len(su.Removed) == 0 + case StateFullUpdate, StateSelfUpdate, StateDERPUpdated: + // These update types don't have associated data to check, + // so they are never considered empty. + return false } return false @@ -175,8 +183,9 @@ func MustRegistrationID() RegistrationID { func RegistrationIDFromString(str string) (RegistrationID, error) { if len(str) != RegistrationIDLength { - return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) + return "", fmt.Errorf("%w: expected %d, got %d", ErrInvalidRegistrationIDLength, RegistrationIDLength, len(str)) } + return RegistrationID(str), nil } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4b0cd240..b030c384 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -33,10 +33,12 @@ const ( ) var ( - errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") - errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") - errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") - errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") + errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + ErrNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + ErrInvalidAllocationStrategy = errors.New("invalid prefix allocation strategy") ) type IPAllocationStrategy string @@ -301,6 +303,7 @@ func validatePKCEMethod(method string) error { if method != PKCEMethodPlain && method != PKCEMethodS256 { return errInvalidPKCEMethod } + return nil } @@ -326,6 +329,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetConfigFile(path) } else { viper.SetConfigName("config") + if path == "" { viper.AddConfigPath("/etc/headscale/") viper.AddConfigPath("$HOME/.headscale") @@ -401,8 +405,10 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) - if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { + err := viper.ReadInConfig() + if err != nil { + var configFileNotFoundError viper.ConfigFileNotFoundError + if errors.As(err, &configFileNotFoundError) { log.Warn().Msg("no config file found, using defaults") return nil } @@ -442,7 +448,8 @@ func validateServerConfig() error { depr.fatal("oidc.map_legacy_users") if viper.GetBool("oidc.enabled") { - if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil { + err := validatePKCEMethod(viper.GetString("oidc.pkce.method")) + if err != nil { return err } } @@ -556,6 +563,7 @@ func derpConfig() DERPConfig { automaticallyAddEmbeddedDerpRegion := viper.GetBool( "derp.server.automatically_add_embedded_derp_region", ) + if serverEnabled && stunAddr == "" { log.Fatal(). Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") @@ -625,13 +633,16 @@ func policyConfig() PolicyConfig { func logConfig() LogConfig { logLevelStr := viper.GetString("log.level") + logLevel, err := zerolog.ParseLevel(logLevelStr) if err != nil { logLevel = zerolog.DebugLevel } logFormatOpt := viper.GetString("log.format") + var logFormat string + switch logFormatOpt { case JSONLogFormat: logFormat = JSONLogFormat @@ -658,7 +669,7 @@ func databaseConfig() DatabaseConfig { type_ := viper.GetString("database.type") skipErrRecordNotFound := viper.GetBool("database.gorm.skip_err_record_not_found") - slowThreshold := viper.GetDuration("database.gorm.slow_threshold") * time.Millisecond + slowThreshold := time.Duration(viper.GetInt64("database.gorm.slow_threshold")) * time.Millisecond parameterizedQueries := viper.GetBool("database.gorm.parameterized_queries") prepareStmt := viper.GetBool("database.gorm.prepare_stmt") @@ -730,6 +741,7 @@ func dns() (DNSConfig, error) { if err != nil { return DNSConfig{}, fmt.Errorf("unmarshalling dns extra records: %w", err) } + dns.ExtraRecords = extraRecords } @@ -745,30 +757,23 @@ func (d *DNSConfig) globalResolvers() []*dnstype.Resolver { var resolvers []*dnstype.Resolver for _, nsStr := range d.Nameservers.Global { - warn := "" - if _, err := netip.ParseAddr(nsStr); err == nil { + if _, err := netip.ParseAddr(nsStr); err == nil { //nolint:noinlineerr resolvers = append(resolvers, &dnstype.Resolver{ Addr: nsStr, }) continue - } else { - warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err) } - if _, err := url.Parse(nsStr); err == nil { + if _, err := url.Parse(nsStr); err == nil { //nolint:noinlineerr resolvers = append(resolvers, &dnstype.Resolver{ Addr: nsStr, }) continue - } else { - warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err) } - if warn != "" { - log.Warn().Msg(warn) - } + log.Warn().Str("nameserver", nsStr).Msg("invalid global nameserver, ignoring") } return resolvers @@ -780,34 +785,30 @@ func (d *DNSConfig) globalResolvers() []*dnstype.Resolver { // If a nameserver is neither a valid URL nor a valid IP, it will be ignored. func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver { routes := make(map[string][]*dnstype.Resolver) + for domain, nameservers := range d.Nameservers.Split { var resolvers []*dnstype.Resolver + for _, nsStr := range nameservers { - warn := "" - if _, err := netip.ParseAddr(nsStr); err == nil { + if _, err := netip.ParseAddr(nsStr); err == nil { //nolint:noinlineerr resolvers = append(resolvers, &dnstype.Resolver{ Addr: nsStr, }) continue - } else { - warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err) } - if _, err := url.Parse(nsStr); err == nil { + if _, err := url.Parse(nsStr); err == nil { //nolint:noinlineerr resolvers = append(resolvers, &dnstype.Resolver{ Addr: nsStr, }) continue - } else { - warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err) } - if warn != "" { - log.Warn().Msg(warn) - } + log.Warn().Str("nameserver", nsStr).Str("domain", domain).Msg("invalid split dns nameserver, ignoring") } + routes[domain] = resolvers } @@ -822,6 +823,7 @@ func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { } cfg.Proxied = dns.MagicDNS + cfg.ExtraRecords = dns.ExtraRecords if dns.OverrideLocalDNS { cfg.Resolvers = dns.globalResolvers() @@ -830,10 +832,12 @@ func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { } routes := dns.splitResolvers() + cfg.Routes = routes if dns.BaseDomain != "" { cfg.Domains = []string{dns.BaseDomain} } + cfg.Domains = append(cfg.Domains, dns.SearchDomains...) return &cfg @@ -843,7 +847,7 @@ func prefixV4() (*netip.Prefix, error) { prefixV4Str := viper.GetString("prefixes.v4") if prefixV4Str == "" { - return nil, nil + return nil, nil //nolint:nilnil // empty prefix is valid, not an error } prefixV4, err := netip.ParsePrefix(prefixV4Str) @@ -853,6 +857,7 @@ func prefixV4() (*netip.Prefix, error) { builder := netipx.IPSetBuilder{} builder.AddPrefix(tsaddr.CGNATRange()) + ipSet, _ := builder.IPSet() if !ipSet.ContainsPrefix(prefixV4) { log.Warn(). @@ -867,7 +872,7 @@ func prefixV6() (*netip.Prefix, error) { prefixV6Str := viper.GetString("prefixes.v6") if prefixV6Str == "" { - return nil, nil + return nil, nil //nolint:nilnil // empty prefix is valid, not an error } prefixV6, err := netip.ParsePrefix(prefixV6Str) @@ -910,7 +915,7 @@ func LoadCLIConfig() (*Config, error) { // LoadServerConfig returns the full Headscale configuration to // host a Headscale server. This is called as part of `headscale serve`. func LoadServerConfig() (*Config, error) { - if err := validateServerConfig(); err != nil { + if err := validateServerConfig(); err != nil { //nolint:noinlineerr return nil, err } @@ -928,11 +933,13 @@ func LoadServerConfig() (*Config, error) { } if prefix4 == nil && prefix6 == nil { - return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + return nil, ErrNoPrefixConfigured } allocStr := viper.GetString("prefixes.allocation") + var alloc IPAllocationStrategy + switch allocStr { case string(IPAllocationStrategySequential): alloc = IPAllocationStrategySequential @@ -940,7 +947,8 @@ func LoadServerConfig() (*Config, error) { alloc = IPAllocationStrategyRandom default: return nil, fmt.Errorf( - "config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", + "%w: %q, allowed options: %s, %s", + ErrInvalidAllocationStrategy, allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom, @@ -957,15 +965,18 @@ func LoadServerConfig() (*Config, error) { randomizeClientPort := viper.GetBool("randomize_client_port") oidcClientSecret := viper.GetString("oidc.client_secret") + oidcClientSecretPath := viper.GetString("oidc.client_secret_path") if oidcClientSecretPath != "" && oidcClientSecret != "" { return nil, errOidcMutuallyExclusive } + if oidcClientSecretPath != "" { secretBytes, err := os.ReadFile(os.ExpandEnv(oidcClientSecretPath)) if err != nil { return nil, err } + oidcClientSecret = strings.TrimSpace(string(secretBytes)) } @@ -979,7 +990,8 @@ func LoadServerConfig() (*Config, error) { // - Control plane runs on login.tailscale.com/controlplane.tailscale.com // - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net) if dnsConfig.BaseDomain != "" { - if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil { + err := isSafeServerURL(serverURL, dnsConfig.BaseDomain) + if err != nil { return nil, err } } @@ -994,7 +1006,7 @@ func LoadServerConfig() (*Config, error) { PrefixV4: prefix4, PrefixV6: prefix6, - IPAllocation: IPAllocationStrategy(alloc), + IPAllocation: alloc, NoisePrivateKeyPath: util.AbsolutePathFromConfigPath( viper.GetString("noise.private_key_path"), @@ -1082,6 +1094,7 @@ func LoadServerConfig() (*Config, error) { if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 { return workers } + return DefaultBatcherWorkers() }(), RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), @@ -1117,6 +1130,7 @@ func isSafeServerURL(serverURL, baseDomain string) error { } s := len(serverDomainParts) + b := len(baseDomainParts) for i := range baseDomainParts { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { @@ -1134,9 +1148,12 @@ type deprecator struct { // warnWithAlias will register an alias between the newKey and the oldKey, // and log a deprecation warning if the oldKey is set. +// +//nolint:unused func (d *deprecator) warnWithAlias(newKey, oldKey string) { // NOTE: RegisterAlias is called with NEW KEY -> OLD KEY viper.RegisterAlias(newKey, oldKey) + if viper.IsSet(oldKey) { d.warns.Add( fmt.Sprintf( @@ -1179,6 +1196,8 @@ func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) { } // warn deprecates and adds an option to log a warning if the oldKey is set. +// +//nolint:unused func (d *deprecator) warnNoAlias(newKey, oldKey string) { if viper.IsSet(oldKey) { d.warns.Add( @@ -1193,6 +1212,8 @@ func (d *deprecator) warnNoAlias(newKey, oldKey string) { } // warn deprecates and adds an entry to the warn list of options if the oldKey is set. +// +//nolint:unused func (d *deprecator) warn(oldKey string) { if viper.IsSet(oldKey) { d.warns.Add( diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 6b9fc2ef..b281cb9d 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -26,7 +26,7 @@ func TestReadConfig(t *testing.T) { { name: "unmarshal-dns-full-config", configPath: "testdata/dns_full.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper dns, err := dns() if err != nil { return nil, err @@ -61,7 +61,7 @@ func TestReadConfig(t *testing.T) { { name: "dns-to-tailcfg.DNSConfig", configPath: "testdata/dns_full.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper dns, err := dns() if err != nil { return nil, err @@ -92,7 +92,7 @@ func TestReadConfig(t *testing.T) { { name: "unmarshal-dns-full-no-magic", configPath: "testdata/dns_full_no_magic.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper dns, err := dns() if err != nil { return nil, err @@ -127,7 +127,7 @@ func TestReadConfig(t *testing.T) { { name: "dns-to-tailcfg.DNSConfig", configPath: "testdata/dns_full_no_magic.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper dns, err := dns() if err != nil { return nil, err @@ -158,7 +158,7 @@ func TestReadConfig(t *testing.T) { { name: "base-domain-in-server-url-err", configPath: "testdata/base-domain-in-server-url.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper return LoadServerConfig() }, want: nil, @@ -167,7 +167,7 @@ func TestReadConfig(t *testing.T) { { name: "base-domain-not-in-server-url", configPath: "testdata/base-domain-not-in-server-url.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper cfg, err := LoadServerConfig() if err != nil { return nil, err @@ -187,7 +187,7 @@ func TestReadConfig(t *testing.T) { { name: "dns-override-true-errors", configPath: "testdata/dns-override-true-error.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper return LoadServerConfig() }, wantErr: "Fatal config error: dns.nameservers.global must be set when dns.override_local_dns is true", @@ -195,7 +195,7 @@ func TestReadConfig(t *testing.T) { { name: "dns-override-true", configPath: "testdata/dns-override-true.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper _, err := LoadServerConfig() if err != nil { return nil, err @@ -221,7 +221,7 @@ func TestReadConfig(t *testing.T) { { name: "policy-path-is-loaded", configPath: "testdata/policy-path-is-loaded.yaml", - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper // inline test closure cfg, err := LoadServerConfig() if err != nil { return nil, err @@ -242,6 +242,7 @@ func TestReadConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { viper.Reset() + err := LoadConfig(tt.configPath, true) require.NoError(t, err) @@ -276,14 +277,14 @@ func TestReadConfigFromEnv(t *testing.T) { "HEADSCALE_DATABASE_SQLITE_WRITE_AHEAD_LOG": "false", "HEADSCALE_PREFIXES_V4": "100.64.0.0/10", }, - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper // inline test closure t.Logf("all settings: %#v", viper.AllSettings()) assert.Equal(t, "trace", viper.GetString("log.level")) assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4")) assert.False(t, viper.GetBool("database.sqlite.write_ahead_log")) - return nil, nil + return nil, nil //nolint:nilnil // test setup returns nil to indicate no expected value }, want: nil, }, @@ -300,7 +301,7 @@ func TestReadConfigFromEnv(t *testing.T) { // "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`, // "HEADSCALE_DNS_EXTRA_RECORDS": `[{ name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" }]`, }, - setup: func(t *testing.T) (any, error) { + setup: func(t *testing.T) (any, error) { //nolint:thelper // inline test closure t.Logf("all settings: %#v", viper.AllSettings()) dns, err := dns() @@ -335,6 +336,7 @@ func TestReadConfigFromEnv(t *testing.T) { } viper.Reset() + err := LoadConfig("testdata/minimal.yaml", true) require.NoError(t, err) @@ -349,11 +351,10 @@ func TestReadConfigFromEnv(t *testing.T) { } func TestTLSConfigValidation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - t.Fatal(err) - } - // defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() + + var err error + configYaml := []byte(`--- tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: "" @@ -363,6 +364,7 @@ noise: // Populate a custom config file configFilePath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) @@ -398,10 +400,12 @@ server_url: http://127.0.0.1:8080 tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: TLS-ALPN-01 `) + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) } + err = LoadConfig(tmpDir, false) require.NoError(t, err) } @@ -463,6 +467,7 @@ func TestSafeServerURL(t *testing.T) { return } + assert.NoError(t, err) }) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 4625a298..c699d6df 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -53,7 +53,7 @@ func (id NodeID) StableID() tailcfg.StableNodeID { } func (id NodeID) NodeID() tailcfg.NodeID { - return tailcfg.NodeID(id) + return tailcfg.NodeID(id) //nolint:gosec // NodeID is bounded } func (id NodeID) Uint64() uint64 { @@ -162,11 +162,12 @@ func (node *Node) GivenNameHasBeenChanged() bool { // Strip invalid DNS characters for givenName comparison normalised := strings.ToLower(node.Hostname) normalised = invalidDNSRegex.ReplaceAllString(normalised, "") + return node.GivenName == normalised } // IsExpired returns whether the node registration has expired. -func (node Node) IsExpired() bool { +func (node *Node) IsExpired() bool { // If Expiry is not set, the client has not indicated that // it wants an expiry time, it is therefore considered // to mean "not expired" @@ -245,8 +246,14 @@ func (node *Node) RequestTags() []string { } func (node *Node) Prefixes() []netip.Prefix { - var addrs []netip.Prefix - for _, nodeAddress := range node.IPs() { + ips := node.IPs() + if len(ips) == 0 { + return nil + } + + addrs := make([]netip.Prefix, 0, len(ips)) + + for _, nodeAddress := range ips { ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) addrs = append(addrs, ip) } @@ -274,9 +281,14 @@ func (node *Node) IsExitNode() bool { } func (node *Node) IPsAsString() []string { - var ret []string + ips := node.IPs() + if len(ips) == 0 { + return nil + } - for _, ip := range node.IPs() { + ret := make([]string, 0, len(ips)) + + for _, ip := range ips { ret = append(ret, ip.String()) } @@ -480,7 +492,7 @@ func (node *Node) IsSubnetRouter() bool { return len(node.SubnetRoutes()) > 0 } -// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes +// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes. func (node *Node) AllApprovedRoutes() []netip.Prefix { return append(node.SubnetRoutes(), node.ExitRoutes()...) } @@ -527,7 +539,7 @@ func (node *Node) MarshalZerologObject(e *zerolog.Event) { // - logTracePeerChange in poll.go. func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { ret := tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), + NodeID: tailcfg.NodeID(node.ID), //nolint:gosec // NodeID is bounded } if node.NodeKey.String() != req.NodeKey.String() { @@ -553,11 +565,9 @@ func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP } else if node.Hostinfo.NetInfo == nil { ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP - } else { + } else if node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP { // If there is a PreferredDERP check if it has changed. - if node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP { - ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP - } + ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP } } @@ -618,13 +628,16 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } newHostname := strings.ToLower(hostInfo.Hostname) - if err := util.ValidateHostname(newHostname); err != nil { + + err := util.ValidateHostname(newHostname) + if err != nil { log.Warn(). Str("node.id", node.ID.String()). Str("current_hostname", node.Hostname). Str("rejected_hostname", hostInfo.Hostname). Err(err). Msg("Rejecting invalid hostname update from hostinfo") + return } @@ -716,6 +729,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node { func (nodes Nodes) DebugString() string { var sb strings.Builder sb.WriteString("Nodes:\n") + for _, node := range nodes { sb.WriteString(node.DebugString()) sb.WriteString("\n") @@ -724,7 +738,7 @@ func (nodes Nodes) DebugString() string { return sb.String() } -func (node Node) DebugString() string { +func (node *Node) DebugString() string { var sb strings.Builder fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID) @@ -897,7 +911,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer // GetFQDN returns the fully qualified domain name for the node. func (nv NodeView) GetFQDN(baseDomain string) (string, error) { if !nv.Valid() { - return "", errors.New("creating valid FQDN: node view is invalid") + return "", fmt.Errorf("creating valid FQDN: %w", ErrInvalidNodeView) } return nv.ж.GetFQDN(baseDomain) diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 2de2efc9..40634525 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -407,7 +407,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data }, want: Node{ GivenName: "valid-hostname", @@ -491,7 +491,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "server-北京-01", + Hostname: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data }, want: Node{ GivenName: "valid-hostname", @@ -505,7 +505,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan // intentional i18n test data }, want: Node{ GivenName: "valid-hostname", @@ -533,7 +533,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "测试💻机器", + Hostname: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data }, want: Node{ GivenName: "valid-hostname", diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 39a94222..d7d8d741 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -116,7 +116,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { return &protoKey } -// canUsePreAuthKey checks if a pre auth key can be used. +// Validate checks if a pre auth key can be used. func (pak *PreAuthKey) Validate() error { if pak == nil { return PAKError("invalid authkey") diff --git a/hscontrol/types/preauth_key_test.go b/hscontrol/types/preauth_key_test.go index 4ab1c717..7bd1d552 100644 --- a/hscontrol/types/preauth_key_test.go +++ b/hscontrol/types/preauth_key_test.go @@ -111,6 +111,7 @@ func TestCanUsePreAuthKey(t *testing.T) { t.Errorf("expected error but got none") } else { var httpErr PAKError + ok := errors.As(err, &httpErr) if !ok { t.Errorf("expected HTTPError but got %T", err) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index b46e1162..2593bda0 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -4,6 +4,7 @@ import ( "cmp" "database/sql" "encoding/json" + "errors" "fmt" "net/mail" "net/url" @@ -20,6 +21,9 @@ import ( "tailscale.com/tailcfg" ) +// ErrCannotParseBoolean is returned when a value cannot be parsed as boolean. +var ErrCannotParseBoolean = errors.New("cannot parse value as boolean") + type UserID uint64 type Users []User @@ -42,9 +46,11 @@ var TaggedDevices = User{ func (u Users) String() string { var sb strings.Builder sb.WriteString("[ ") + for _, user := range u { fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) } + sb.WriteString(" ]") return sb.String() @@ -55,7 +61,8 @@ func (u Users) String() string { // At the end of the day, users in Tailscale are some kind of 'bubbles' or users // that contain our machines. type User struct { - gorm.Model + gorm.Model //nolint:embeddedstructfieldcheck + // The index `idx_name_provider_identifier` is to enforce uniqueness // between Name and ProviderIdentifier. This ensures that // you can have multiple users with the same name in OIDC, @@ -91,6 +98,7 @@ func (u *User) StringID() string { if u == nil { return "" } + return strconv.FormatUint(uint64(u.ID), 10) } @@ -130,7 +138,7 @@ func (u *User) profilePicURL() string { func (u *User) TailscaleUser() tailcfg.User { return tailcfg.User{ - ID: tailcfg.UserID(u.ID), + ID: tailcfg.UserID(u.ID), //nolint:gosec // UserID is bounded DisplayName: u.Display(), ProfilePicURL: u.profilePicURL(), Created: u.CreatedAt, @@ -150,7 +158,7 @@ func (u UserView) ID() uint { func (u *User) TailscaleLogin() tailcfg.Login { return tailcfg.Login{ - ID: tailcfg.LoginID(u.ID), + ID: tailcfg.LoginID(u.ID), //nolint:gosec // safe conversion for user ID Provider: u.Provider, LoginName: u.Username(), DisplayName: u.Display(), @@ -164,7 +172,7 @@ func (u UserView) TailscaleLogin() tailcfg.Login { func (u *User) TailscaleUserProfile() tailcfg.UserProfile { return tailcfg.UserProfile{ - ID: tailcfg.UserID(u.ID), + ID: tailcfg.UserID(u.ID), //nolint:gosec // UserID is bounded LoginName: u.Username(), DisplayName: u.Display(), ProfilePicURL: u.profilePicURL(), @@ -184,6 +192,7 @@ func (u *User) Proto() *v1.User { if name == "" { name = u.Username() } + return &v1.User{ Id: uint64(u.ID), Name: name, @@ -220,7 +229,7 @@ func (u UserView) MarshalZerologObject(e *zerolog.Event) { u.ж.MarshalZerologObject(e) } -// JumpCloud returns a JSON where email_verified is returned as a +// FlexibleBoolean handles JumpCloud's JSON where email_verified is returned as a // string "true" or "false" instead of a boolean. // This maps bool to a specific type with a custom unmarshaler to // ensure we can decode it from a string. @@ -229,6 +238,7 @@ type FlexibleBoolean bool func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { var val any + err := json.Unmarshal(data, &val) if err != nil { return fmt.Errorf("unmarshalling data: %w", err) @@ -242,10 +252,11 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { if err != nil { return fmt.Errorf("parsing %s as boolean: %w", v, err) } + *bit = FlexibleBoolean(pv) default: - return fmt.Errorf("parsing %v as boolean", v) + return fmt.Errorf("%w: %v", ErrCannotParseBoolean, v) } return nil @@ -279,9 +290,11 @@ func (c *OIDCClaims) Identifier() string { if c.Iss == "" && c.Sub == "" { return "" } + if c.Iss == "" { return CleanIdentifier(c.Sub) } + if c.Sub == "" { return CleanIdentifier(c.Iss) } @@ -292,9 +305,9 @@ func (c *OIDCClaims) Identifier() string { var result string // Try to parse as URL to handle URL joining correctly - if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { + if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { //nolint:noinlineerr // For URLs, use proper URL path joining - if joined, err := url.JoinPath(issuer, subject); err == nil { + if joined, err := url.JoinPath(issuer, subject); err == nil { //nolint:noinlineerr result = joined } } @@ -366,6 +379,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, trimmed) } } + if len(cleanParts) == 0 { return "" } @@ -408,6 +422,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) { if claims.Iss == "" && !strings.HasPrefix(identifier, "/") { identifier = "/" + identifier } + u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true} u.DisplayName = claims.Name u.ProfilePicURL = claims.ProfilePictureURL diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index 15386553..064388eb 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -66,10 +66,13 @@ func TestUnmarshallOIDCClaims(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims - if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + + err := json.Unmarshal([]byte(tt.jsonstr), &got) + if err != nil { t.Errorf("UnmarshallOIDCClaims() error = %v", err) return } + if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff) } @@ -190,6 +193,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) { } result := claims.Identifier() assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("Identifier() mismatch (-want +got):\n%s", diff) } @@ -282,6 +286,7 @@ func TestCleanIdentifier(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := CleanIdentifier(tt.identifier) assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff) } @@ -479,7 +484,9 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims - if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + + err := json.Unmarshal([]byte(tt.jsonstr), &got) + if err != nil { t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err) return } @@ -487,6 +494,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { var user User user.FromClaim(&got, tt.emailVerifiedRequired) + if diff := cmp.Diff(user, tt.want); diff != "" { t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/types/version.go b/hscontrol/types/version.go index 6676c92f..96dc58a6 100644 --- a/hscontrol/types/version.go +++ b/hscontrol/types/version.go @@ -38,9 +38,7 @@ func (v *VersionInfo) String() string { return sb.String() } -var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) { - return debug.ReadBuildInfo() -}) +var buildInfo = sync.OnceValues(debug.ReadBuildInfo) var GetVersionInfo = sync.OnceValue(func() *VersionInfo { info := &VersionInfo{ diff --git a/hscontrol/util/addr.go b/hscontrol/util/addr.go index c91ef0ba..782f15e6 100644 --- a/hscontrol/util/addr.go +++ b/hscontrol/util/addr.go @@ -91,6 +91,7 @@ func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) { func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) { var network, broadcast netip.Addr + ipRange := netipx.RangeOfPrefix(na) network = ipRange.From() broadcast = ipRange.To() diff --git a/hscontrol/util/addr_test.go b/hscontrol/util/addr_test.go index 0e08d707..7c6ae31a 100644 --- a/hscontrol/util/addr_test.go +++ b/hscontrol/util/addr_test.go @@ -29,6 +29,7 @@ func Test_parseIPSet(t *testing.T) { arg string bits *int } + tests := []struct { name string args args @@ -111,6 +112,7 @@ func Test_parseIPSet(t *testing.T) { return } + if diff := cmp.Diff(tt.want, got); diff != "" { t.Errorf("parseIPSet() = (-want +got):\n%s", diff) } diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index dcd58528..1b5a3806 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -18,13 +18,27 @@ const ( ipv4AddressLength = 32 ipv6AddressLength = 128 + // LabelHostnameLength is the maximum length for a DNS label, // value related to RFC 1123 and 952. LabelHostnameLength = 63 ) var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") -var ErrInvalidHostName = errors.New("invalid hostname") +// DNS validation errors. +var ( + ErrInvalidHostName = errors.New("invalid hostname") + ErrUsernameTooShort = errors.New("username must be at least 2 characters long") + ErrUsernameMustStartLetter = errors.New("username must start with a letter") + ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'") + ErrUsernameInvalidChar = errors.New("username contains invalid character") + ErrHostnameTooShort = errors.New("hostname is too short, must be at least 2 characters") + ErrHostnameTooLong = errors.New("hostname is too long, must not exceed 63 characters") + ErrHostnameMustBeLowercase = errors.New("hostname must be lowercase") + ErrHostnameHyphenBoundary = errors.New("hostname cannot start or end with a hyphen") + ErrHostnameDotBoundary = errors.New("hostname cannot start or end with a dot") + ErrHostnameInvalidChars = errors.New("hostname contains invalid characters") +) // ValidateUsername checks if a username is valid. // It must be at least 2 characters long, start with a letter, and contain @@ -34,12 +48,12 @@ var ErrInvalidHostName = errors.New("invalid hostname") func ValidateUsername(username string) error { // Ensure the username meets the minimum length requirement if len(username) < 2 { - return errors.New("username must be at least 2 characters long") + return ErrUsernameTooShort } // Ensure the username starts with a letter if !unicode.IsLetter(rune(username[0])) { - return errors.New("username must start with a letter") + return ErrUsernameMustStartLetter } atCount := 0 @@ -55,10 +69,10 @@ func ValidateUsername(username string) error { case char == '@': atCount++ if atCount > 1 { - return errors.New("username cannot contain more than one '@'") + return ErrUsernameTooManyAt } default: - return fmt.Errorf("username contains invalid character: '%c'", char) + return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char) } } @@ -70,44 +84,27 @@ func ValidateUsername(username string) error { // The hostname must already be lowercase and contain only valid characters. func ValidateHostname(name string) error { if len(name) < 2 { - return fmt.Errorf( - "hostname %q is too short, must be at least 2 characters", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameTooShort, name) } + if len(name) > LabelHostnameLength { - return fmt.Errorf( - "hostname %q is too long, must not exceed 63 characters", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameTooLong, name) } + if strings.ToLower(name) != name { - return fmt.Errorf( - "hostname %q must be lowercase (try %q)", - name, - strings.ToLower(name), - ) + return fmt.Errorf("%w: %q (try %q)", ErrHostnameMustBeLowercase, name, strings.ToLower(name)) } if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") { - return fmt.Errorf( - "hostname %q cannot start or end with a hyphen", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameHyphenBoundary, name) } if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { - return fmt.Errorf( - "hostname %q cannot start or end with a dot", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameDotBoundary, name) } if invalidDNSRegex.MatchString(name) { - return fmt.Errorf( - "hostname %q contains invalid characters, only lowercase letters, numbers, hyphens and dots are allowed", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameInvalidChars, name) } return nil @@ -170,6 +167,7 @@ func NormaliseHostname(name string) (string, error) { // and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next // class block only. +// GenerateIPv4DNSRootDomain generates the IPv4 reverse DNS root domains. // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // This allows us to then calculate the subnets included in the subsequent class block and generate the entries. func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { @@ -183,25 +181,27 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // wildcardBits is the number of bits not under the mask in the lastOctet wildcardBits := ByteSize - maskBits%ByteSize - // min is the value in the lastOctet byte of the IP - // max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1 - min := uint(netRange.IP[lastOctet]) - max := (min + 1<= 0; i-- { rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10)) } + rdnsSlice = append(rdnsSlice, "in-addr.arpa.") rdnsBase := strings.Join(rdnsSlice, ".") - fqdns := make([]dnsname.FQDN, 0, max-min+1) - for i := min; i <= max; i++ { + fqdns := make([]dnsname.FQDN, 0, maxVal-minVal+1) + for i := minVal; i <= maxVal; i++ { fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase)) if err != nil { continue } + fqdns = append(fqdns, fqdn) } @@ -226,6 +226,7 @@ func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next // class block only. +// GenerateIPv6DNSRootDomain generates the IPv6 reverse DNS root domains. // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // This allows us to then calculate the subnets included in the subsequent class block and generate the entries. func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { @@ -259,18 +260,22 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { } var fqdns []dnsname.FQDN + if maskBits%4 == 0 { dom, _ := makeDomain() fqdns = append(fqdns, dom) } else { domCount := 1 << (maskBits % nibbleLen) + fqdns = make([]dnsname.FQDN, 0, domCount) for i := range domCount { varNibble := fmt.Sprintf("%x", i) + dom, err := makeDomain(varNibble) if err != nil { continue } + fqdns = append(fqdns, dom) } } diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index b492e4d6..54306136 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -14,6 +14,7 @@ func TestNormaliseHostname(t *testing.T) { type args struct { name string } + tests := []struct { name string args args @@ -90,6 +91,7 @@ func TestNormaliseHostname(t *testing.T) { t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if !tt.wantErr && got != tt.want { t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want) } @@ -172,6 +174,7 @@ func TestValidateHostname(t *testing.T) { t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr && tt.errorContains != "" { if err == nil || !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains) diff --git a/hscontrol/util/file.go b/hscontrol/util/file.go index 86af636c..f6b09838 100644 --- a/hscontrol/util/file.go +++ b/hscontrol/util/file.go @@ -21,6 +21,9 @@ const ( PermissionFallback = 0o700 ) +// ErrDirectoryPermission is returned when creating a directory fails due to permission issues. +var ErrDirectoryPermission = errors.New("creating directory failed with permission error") + func AbsolutePathFromConfigPath(path string) string { // If a relative path is provided, prefix it with the directory where // the config file was found. @@ -42,18 +45,15 @@ func GetFileMode(key string) fs.FileMode { return PermissionFallback } - return fs.FileMode(mode) + return fs.FileMode(mode) //nolint:gosec // file mode is bounded by ParseUint } func EnsureDir(dir string) error { - if _, err := os.Stat(dir); os.IsNotExist(err) { + if _, err := os.Stat(dir); os.IsNotExist(err) { //nolint:noinlineerr err := os.MkdirAll(dir, PermissionFallback) if err != nil { if errors.Is(err, os.ErrPermission) { - return fmt.Errorf( - "creating directory %s, failed with permission error, is it located somewhere Headscale can write?", - dir, - ) + return fmt.Errorf("%w: %s", ErrDirectoryPermission, dir) } return fmt.Errorf("creating directory %s: %w", dir, err) diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go index f28cd4a3..03de9c34 100644 --- a/hscontrol/util/log.go +++ b/hscontrol/util/log.go @@ -87,5 +87,6 @@ func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...a if l.ParameterizedQueries { return sql, nil } + return sql, params } diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go index 098f1979..5f0adede 100644 --- a/hscontrol/util/prompt.go +++ b/hscontrol/util/prompt.go @@ -14,11 +14,14 @@ func YesNo(msg string) bool { fmt.Fprint(os.Stderr, msg+" [y/n] ") var resp string - fmt.Scanln(&resp) + + _, _ = fmt.Scanln(&resp) + resp = strings.ToLower(resp) switch resp { case "y", "yes", "sure": return true } + return false } diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go index d726ec60..ac405f8c 100644 --- a/hscontrol/util/prompt_test.go +++ b/hscontrol/util/prompt_test.go @@ -86,7 +86,8 @@ func TestYesNo(t *testing.T) { // Write test input go func() { defer w.Close() - w.WriteString(tt.input) + + _, _ = w.WriteString(tt.input) }() // Call the function @@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check the result @@ -104,10 +106,12 @@ func TestYesNo(t *testing.T) { // Check that the prompt was written to stderr var stderrBuf bytes.Buffer - io.Copy(&stderrBuf, stderrR) + + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() expectedPrompt := "Test question [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -130,7 +134,8 @@ func TestYesNoPromptMessage(t *testing.T) { // Write test input go func() { defer w.Close() - w.WriteString("n\n") + + _, _ = w.WriteString("n\n") }() // Call the function with a custom message @@ -140,14 +145,17 @@ func TestYesNoPromptMessage(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check that the custom message was included in the prompt var stderrBuf bytes.Buffer - io.Copy(&stderrBuf, stderrR) + + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() expectedPrompt := customMessage + " [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -186,7 +194,8 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Write test input go func() { defer w.Close() - w.WriteString(tc.input) + + _, _ = w.WriteString(tc.input) }() // Call the function @@ -195,10 +204,11 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Drain stderr - io.Copy(io.Discard, stderrR) + _, _ = io.Copy(io.Discard, stderrR) stderrR.Close() if result != tc.expected { diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index d1d7ece7..4390888d 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -17,7 +17,7 @@ func GenerateRandomBytes(n int) ([]byte, error) { bytes := make([]byte, n) // Note that err == nil only if we read len(b) bytes. - if _, err := rand.Read(bytes); err != nil { + if _, err := rand.Read(bytes); err != nil { //nolint:noinlineerr return nil, err } @@ -33,6 +33,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) { b, err := GenerateRandomBytes(n) uenc := base64.RawURLEncoding.EncodeToString(b) + return uenc[:n], err } @@ -42,13 +43,17 @@ func GenerateRandomStringURLSafe(n int) (string, error) { // number generator fails to function correctly, in which // case the caller should not continue. func GenerateRandomStringDNSSafe(size int) (string, error) { - var str string - var err error + var ( + str string + err error + ) + for len(str) < size { str, err = GenerateRandomStringURLSafe(size) if err != nil { return "", err } + str = strings.ToLower( strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""), ) @@ -99,6 +104,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string { DstIPs: %v } `, rule.SrcIPs, rule.DstPorts)) + if index < len(rules)-1 { sb.WriteString(", ") } diff --git a/hscontrol/util/test.go b/hscontrol/util/test.go index d93ae1f2..b7be5825 100644 --- a/hscontrol/util/test.go +++ b/hscontrol/util/test.go @@ -33,7 +33,7 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { return x.String() == y.String() }) -var ViewSliceIPProtoComparer = cmp.Comparer(func(a, b views.Slice[ipproto.Proto]) bool { return views.SliceEqual(a, b) }) +var ViewSliceIPProtoComparer = cmp.Comparer(views.SliceEqual[ipproto.Proto]) var Comparers []cmp.Option = []cmp.Option{ IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, ViewSliceIPProtoComparer, diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index b8109217..cbce663b 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -16,6 +16,15 @@ import ( "tailscale.com/util/cmpver" ) +// URL parsing errors. +var ( + ErrMultipleURLsFound = errors.New("multiple URLs found") + ErrNoURLFound = errors.New("no URL found") + ErrEmptyTracerouteOutput = errors.New("empty traceroute output") + ErrTracerouteHeaderParse = errors.New("parsing traceroute header") + ErrTracerouteDidNotReach = errors.New("traceroute did not reach target") +) + func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { if cmpver.Compare(minimum, toCheck) <= 0 || toCheck == "unstable" || @@ -30,20 +39,22 @@ func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { // It returns an error if not exactly one URL is found. func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { lines := strings.Split(output, "\n") + var urlStr string for _, line := range lines { line = strings.TrimSpace(line) if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") { if urlStr != "" { - return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) + return nil, fmt.Errorf("%w: %s and %s", ErrMultipleURLsFound, urlStr, line) } + urlStr = line } } if urlStr == "" { - return nil, errors.New("no URL found") + return nil, ErrNoURLFound } loginURL, err := url.Parse(urlStr) @@ -89,14 +100,15 @@ type Traceroute struct { func ParseTraceroute(output string) (Traceroute, error) { lines := strings.Split(strings.TrimSpace(output), "\n") if len(lines) < 1 { - return Traceroute{}, errors.New("empty traceroute output") + return Traceroute{}, ErrEmptyTracerouteOutput } // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`) + headerMatches := headerRegex.FindStringSubmatch(lines[0]) if len(headerMatches) < 2 { - return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) + return Traceroute{}, fmt.Errorf("%w: %s", ErrTracerouteHeaderParse, lines[0]) } hostname := headerMatches[1] @@ -105,6 +117,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if ipStr == "" { ipStr = headerMatches[3] } + ip, err := netip.ParseAddr(ipStr) if err != nil { return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) @@ -144,19 +157,23 @@ func ParseTraceroute(output string) (Traceroute, error) { } remainder := strings.TrimSpace(matches[2]) - var hopHostname string - var hopIP netip.Addr - var latencies []time.Duration + + var ( + hopHostname string + hopIP netip.Addr + latencies []time.Duration + ) // Check for Windows tracert format which has latencies before hostname // Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]" latencyFirst := false + if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") { // Check if latencies appear before any hostname/IP firstSpace := strings.Index(remainder, " ") if firstSpace > 0 { firstPart := remainder[:firstSpace] - if _, err := strconv.ParseFloat(strings.TrimPrefix(firstPart, "<"), 64); err == nil { + if _, err := strconv.ParseFloat(strings.TrimPrefix(firstPart, "<"), 64); err == nil { //nolint:noinlineerr latencyFirst = true } } @@ -171,12 +188,14 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Extract and remove the latency from the beginning latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues duration := time.Duration(ms * float64(time.Millisecond)) latencies = append(latencies, duration.Round(time.Microsecond)) } + remainder = strings.TrimSpace(remainder[latMatch[1]:]) } } @@ -202,9 +221,10 @@ func ParseTraceroute(output string) (Traceroute, error) { parts := strings.Fields(remainder) if len(parts) > 0 { hopHostname = parts[0] - if ip, err := netip.ParseAddr(parts[0]); err == nil { + if ip, err := netip.ParseAddr(parts[0]); err == nil { //nolint:noinlineerr hopIP = ip } + remainder = strings.TrimSpace(strings.Join(parts[1:], " ")) } } @@ -216,6 +236,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if len(match) > 1 { // Remove '<' prefix if present (e.g., "<1 ms") latStr := strings.TrimPrefix(match[1], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues @@ -243,7 +264,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // If we didn't reach the target, it's unsuccessful if !result.Success { - result.Err = errors.New("traceroute did not reach target") + result.Err = ErrTracerouteDidNotReach } return result, nil @@ -261,11 +282,11 @@ func IsCI() bool { return false } -// SafeHostname extracts a hostname from Hostinfo, providing sensible defaults +// EnsureHostname guarantees a valid hostname for node registration. +// It extracts a hostname from Hostinfo, providing sensible defaults // if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences // and ensures nodes always have a valid hostname. // The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123). -// EnsureHostname guarantees a valid hostname for node registration. // This function never fails - it always returns a valid hostname. // // Strategy: @@ -280,15 +301,19 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri if key == "" { return "unknown-node" } + keyPrefix := key if len(key) > 8 { keyPrefix = key[:8] } - return fmt.Sprintf("node-%s", keyPrefix) + + return "node-" + keyPrefix } lowercased := strings.ToLower(hostinfo.Hostname) - if err := ValidateHostname(lowercased); err == nil { + + err := ValidateHostname(lowercased) + if err == nil { return lowercased } diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 33f27b7a..2dafc921 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1,7 +1,6 @@ package util import ( - "errors" "net/netip" "strings" "testing" @@ -11,11 +10,14 @@ import ( "tailscale.com/tailcfg" ) +const testUnknownNode = "unknown-node" + func TestTailscaleVersionNewerOrEqual(t *testing.T) { type args struct { minimum string toCheck string } + tests := []struct { name string args args @@ -180,6 +182,7 @@ Success.`, if err != nil { t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr %v", err, tt.wantErr) } + if gotURL.String() != tt.wantURL { t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL) } @@ -321,7 +324,7 @@ func TestParseTraceroute(t *testing.T) { }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteDidNotReach, }, wantErr: false, }, @@ -489,7 +492,7 @@ over a maximum of 30 hops: }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteDidNotReach, }, wantErr: false, }, @@ -834,7 +837,7 @@ func TestEnsureHostname(t *testing.T) { hostinfo: nil, machineKey: "", nodeKey: "", - want: "unknown-node", + want: testUnknownNode, }, { name: "empty_hostname_with_machine_key", @@ -861,7 +864,7 @@ func TestEnsureHostname(t *testing.T) { }, machineKey: "", nodeKey: "", - want: "unknown-node", + want: testUnknownNode, }, { name: "hostname_exactly_63_chars", @@ -902,7 +905,7 @@ func TestEnsureHostname(t *testing.T) { { name: "hostname_with_unicode", hostinfo: &tailcfg.Hostinfo{ - Hostname: "node-ñoño-测试", + Hostname: "node-ñoño-测试", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -983,7 +986,7 @@ func TestEnsureHostname(t *testing.T) { { name: "chinese_chars_with_dash_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "server-北京-01", + Hostname: "server-北京-01", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -992,7 +995,7 @@ func TestEnsureHostname(t *testing.T) { { name: "chinese_only_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -1010,7 +1013,7 @@ func TestEnsureHostname(t *testing.T) { { name: "mixed_chinese_emoji_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "测试💻机器", + Hostname: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -1066,6 +1069,7 @@ func TestEnsureHostname(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { @@ -1099,13 +1103,15 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { machineKey: "mkey12345678", nodeKey: "nkey12345678", wantHostname: "test-node", - checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { //nolint:thelper if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if hi.Hostname != "test-node" { t.Errorf("hostname = %v, want test-node", hi.Hostname) } + if hi.OS != "linux" { t.Errorf("OS = %v, want linux", hi.OS) } @@ -1143,10 +1149,11 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { machineKey: "", nodeKey: "nkey12345678", wantHostname: "node-nkey1234", - checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { //nolint:thelper if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if hi.Hostname != "node-nkey1234" { t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname) } @@ -1157,12 +1164,13 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { hostinfo: nil, machineKey: "", nodeKey: "", - wantHostname: "unknown-node", - checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + wantHostname: testUnknownNode, + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { //nolint:thelper if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } - if hi.Hostname != "unknown-node" { + + if hi.Hostname != testUnknownNode { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } }, @@ -1174,12 +1182,13 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { }, machineKey: "", nodeKey: "", - wantHostname: "unknown-node", - checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + wantHostname: testUnknownNode, + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { //nolint:thelper if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } - if hi.Hostname != "unknown-node" { + + if hi.Hostname != testUnknownNode { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } }, @@ -1196,22 +1205,27 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { machineKey: "mkey12345678", nodeKey: "nkey12345678", wantHostname: "test", - checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { //nolint:thelper if hi == nil { t.Error("hostinfo should not be nil") } - if hi.Hostname != "test" { + + if hi.Hostname != "test" { //nolint:staticcheck // SA5011: nil check is above t.Errorf("hostname = %v, want test", hi.Hostname) } + if hi.OS != "windows" { t.Errorf("OS = %v, want windows", hi.OS) } + if hi.OSVersion != "10.0.19044" { t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion) } + if hi.DeviceModel != "test-device" { t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel) } + if hi.BackendLogID != "log123" { t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID) } @@ -1225,11 +1239,12 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { machineKey: "mkey12345678", nodeKey: "nkey12345678", wantHostname: "123456789012345678901234567890123456789012345678901234567890123", - checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { //nolint:thelper if hi == nil { t.Error("hostinfo should not be nil") } - if len(hi.Hostname) != 63 { + + if len(hi.Hostname) != 63 { //nolint:staticcheck // SA5011: nil check is above t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) } }, @@ -1239,6 +1254,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { @@ -1264,7 +1280,10 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { for i, hostname := range testCases { t.Run(cmp.Diff("", ""), func(t *testing.T) { + t.Parallel() + hostinfo := &tailcfg.Hostinfo{Hostname: hostname} + result := EnsureHostname(hostinfo, "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) diff --git a/integration/acl_test.go b/integration/acl_test.go index c746f900..4f4c3eed 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1877,7 +1877,7 @@ func TestACLAutogroupSelf(t *testing.T) { result, err := client.Curl(url) assert.Empty(t, result, "user1 should not be able to access user2's regular devices (autogroup:self isolation)") - assert.Error(t, err, "connection from user1 to user2 regular device should fail") + require.Error(t, err, "connection from user1 to user2 regular device should fail") } } @@ -1896,6 +1896,7 @@ func TestACLAutogroupSelf(t *testing.T) { } } +//nolint:gocyclo // complex integration test scenario func TestACLPolicyPropagationOverTime(t *testing.T) { IntegrationSkip(t) diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index c0631f86..2b5e1726 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -1,6 +1,7 @@ package integration import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -35,6 +36,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -46,6 +48,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Create an API key using the CLI var validAPIKey string + assert.EventuallyWithT(t, func(ct *assert.CollectT) { apiKeyOutput, err := headscale.Execute( []string{ @@ -63,7 +66,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Get the API endpoint endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" // Create HTTP client client := &http.Client{ @@ -76,11 +79,12 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_NoAuthHeader", func(t *testing.T) { // Test 1: Request without any Authorization header // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -99,6 +103,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should NOT contain user data after "Unauthorized" // This is the security bypass - if users array is present, auth was bypassed var jsonCheck map[string]any + jsonErr := json.Unmarshal(body, &jsonCheck) // If we can unmarshal JSON and it contains "users", that's the bypass @@ -126,12 +131,13 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) { // Test 2: Request with invalid Authorization header (missing "Bearer " prefix) // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "InvalidToken") resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -159,12 +165,13 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Test 3: Request with Bearer prefix but invalid token // Expected: Should return 401 with ONLY "Unauthorized" text, no user data // Note: Both malformed and properly formatted invalid tokens should return 401 - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer invalid-token-12345") resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -191,12 +198,13 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_ValidAPIKey", func(t *testing.T) { // Test 4: Request with valid API key // Expected: Should return 200 with user data (this is the authorized case) - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey)) + req.Header.Set("Authorization", "Bearer "+validAPIKey) resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -208,16 +216,19 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should be able to parse as protobuf JSON var response v1.ListUsersResponse + err = protojson.Unmarshal(body, &response) - assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key") + require.NoError(t, err, "Response should be valid protobuf JSON with valid API key") // Should contain our test users users := response.GetUsers() assert.Len(t, users, 3, "Should have 3 users") + userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "user1") assert.Contains(t, userNames, "user2") assert.Contains(t, userNames, "user3") @@ -234,6 +245,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -254,10 +266,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" t.Run("Curl_NoAuth", func(t *testing.T) { // Execute curl from inside the headscale container without auth @@ -274,17 +287,24 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Parse the output lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + + var ( + httpCode string + responseBody string + ) + + var responseBodySb280 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb280.WriteString(line) } } + responseBody += responseBodySb280.String() + // Should return 401 assert.Equal(t, "401", httpCode, "Curl without auth should return 401") @@ -320,17 +340,24 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + + var ( + httpCode string + responseBody string + ) + + var responseBodySb326 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb326.WriteString(line) } } + responseBody += responseBodySb326.String() + assert.Equal(t, "401", httpCode) assert.Contains(t, responseBody, "Unauthorized") assert.NotContains(t, responseBody, "testuser1", @@ -346,7 +373,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { "curl", "-s", "-H", - fmt.Sprintf("Authorization: Bearer %s", validAPIKey), + "Authorization: Bearer " + validAPIKey, "-w", "\nHTTP_CODE:%{http_code}", apiURL, @@ -355,25 +382,34 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + + var ( + httpCode string + responseBody string + ) + + var responseBodySb361 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb361.WriteString(line) } } + responseBody += responseBodySb361.String() + // Should succeed assert.Equal(t, "200", httpCode, "Curl with valid API key should return 200") // Should contain user data var response v1.ListUsersResponse + err = protojson.Unmarshal([]byte(responseBody), &response) - assert.NoError(t, err, "Response should be valid protobuf JSON") + require.NoError(t, err, "Response should be valid protobuf JSON") + users := response.GetUsers() assert.Len(t, users, 2, "Should have 2 users") }) @@ -391,6 +427,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -420,11 +457,12 @@ func TestGRPCAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) // Get the gRPC endpoint // For gRPC, we need to use the hostname and port 50443 - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" t.Run("gRPC_NoAPIKey", func(t *testing.T) { // Test 1: Try to use CLI without API key (should fail) @@ -452,7 +490,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { ) // Should fail with authentication error - assert.Error(t, err, + require.Error(t, err, "gRPC connection with invalid API key should fail") // Should contain authentication error message @@ -481,20 +519,22 @@ func TestGRPCAuthenticationBypass(t *testing.T) { ) // Should succeed - assert.NoError(t, err, + require.NoError(t, err, "gRPC connection with valid API key should succeed, output: %s", output) // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) - assert.NoError(t, err, "Response should be valid JSON array") + require.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "grpcuser1") assert.Contains(t, userNames, "grpcuser2") }) @@ -513,6 +553,7 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -540,9 +581,10 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" // Create a config file for testing configWithoutKey := fmt.Sprintf(` @@ -602,7 +644,7 @@ cli: ) // Should fail - assert.Error(t, err, + require.Error(t, err, "CLI with invalid API key should fail") // Should indicate authentication failure @@ -637,20 +679,22 @@ cli: ) // Should succeed - assert.NoError(t, err, + require.NoError(t, err, "CLI with valid API key should succeed") // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) - assert.NoError(t, err, "Response should be valid JSON array") + require.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "cliuser1") assert.Contains(t, userNames, "cliuser2") }) diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index ba6a195b..862b7d32 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -31,6 +31,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -69,18 +70,24 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -111,6 +118,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -148,6 +156,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after relogin") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -201,6 +210,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, nodeCountBeforeLogout) @@ -255,10 +265,14 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -301,9 +315,11 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -323,21 +339,24 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created // for the new user. The original nodes remain with the original user. var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) - }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname())) + }, 30*time.Second, 2*time.Second, "validating %s is logged in as user1 after auth key user switch", client.Hostname()) } } @@ -352,6 +371,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -377,11 +397,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } @@ -395,10 +417,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 359dd456..dadaa9a4 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -148,6 +148,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -175,6 +176,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { syncCompleteTime := time.Now() err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) + loginDuration := time.Since(syncCompleteTime) t.Logf("Login and sync completed in %v", loginDuration) @@ -206,6 +208,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check each client's status individually to provide better diagnostics expiredCount := 0 + for _, client := range allClients { status, err := client.Status() if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) { @@ -355,6 +358,7 @@ func TestOIDC024UserCreation(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -412,6 +416,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -469,6 +474,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { oidcMockUser("user1", true), }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -507,6 +513,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -527,9 +534,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) @@ -537,14 +547,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -582,6 +597,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users after user2 login") assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -637,10 +653,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user2's node should be active after user switch var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { if node.GetUser().GetId() == 2 { // user2 activeUser2NodeID = types.NodeID(node.GetId()) t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break } } @@ -654,6 +672,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user2 node is online if node, exists := nodeStore[activeUser2NodeID]; exists { assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User2 node should be online after login") } @@ -746,6 +765,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -815,10 +835,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user1's node should be active after relogin var activeUser1NodeID types.NodeID + for _, node := range listNodesAfterLoggingBackIn { if node.GetUser().GetId() == 1 { // user1 activeUser1NodeID = types.NodeID(node.GetId()) t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break } } @@ -832,6 +854,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after relogin") } @@ -906,6 +929,7 @@ func TestOIDCFollowUpUrl(t *testing.T) { time.Sleep(2 * time.Minute) var newUrl *url.URL + assert.EventuallyWithT(t, func(c *assert.CollectT) { st, err := ts.Status() assert.NoError(c, err) @@ -1029,7 +1053,7 @@ func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { require.NotEqual(t, redirect1.String(), redirect2.String()) // complete auth with the first opened "browser tab" - _, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) + _, _, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) require.NoError(t, err) listUsers, err = headscale.ListUsers() @@ -1106,6 +1130,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { oidcMockUser("user1", true), // Relogin with same user }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -1145,6 +1170,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1165,9 +1191,12 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var initialNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + initialNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes)) @@ -1175,14 +1204,19 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -1239,6 +1273,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1259,6 +1294,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle") var finalNodes []*v1.Node + t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { finalNodes, err = headscale.ListNodes() @@ -1282,6 +1318,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Security validation: user1's node should be active after relogin activeUser1NodeID := types.NodeID(finalNodes[0].GetId()) + t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat)) require.EventuallyWithT(t, func(c *assert.CollectT) { nodeStore, err := headscale.DebugNodeStore() @@ -1290,6 +1327,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin") } @@ -1359,6 +1397,7 @@ func TestOIDCExpiryAfterRestart(t *testing.T) { // Verify initial expiry is set var initialExpiry time.Time + assert.EventuallyWithT(t, func(ct *assert.CollectT) { nodes, err := headscale.ListNodes() assert.NoError(ct, err) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 5dd546f3..eba2ebbf 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,7 +1,6 @@ package integration import ( - "fmt" "net/netip" "slices" "testing" @@ -67,6 +66,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -106,22 +106,27 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } @@ -152,6 +157,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web flow logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -226,6 +232,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -240,9 +247,13 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { allClients, err := scenario.ListTailscaleClients() requireNoErrListClients(t, err) - allIps, err := scenario.ListTailscaleClientsIPs() + var allIps []netip.Addr + + allIps, err = scenario.ListTailscaleClientsIPs() requireNoErrListClientIPs(t, err) + _ = allIps // used below after user switch + err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) @@ -256,13 +267,16 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after initial web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -299,7 +313,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Register all clients as user1 (this is where cross-user registration happens) // This simulates: headscale nodes register --user user1 --key - scenario.runHeadscaleRegister("user1", body) + _ = scenario.runHeadscaleRegister("user1", body) } // Wait for all clients to reach running state @@ -313,9 +327,11 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { t.Logf("all clients logged back in as user1") var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -333,21 +349,24 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Validate that user2's old nodes still exist in database (but are expired/offline) // When CLI registration creates new nodes for user1, user2's old nodes remain var user2Nodes []*v1.Node + t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1") t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) - }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow user switch", client.Hostname())) + }, 30*time.Second, 2*time.Second, "validating %s is logged in as user1 after web flow user switch", client.Hostname()) } // Test connectivity after user switch diff --git a/integration/cli_test.go b/integration/cli_test.go index 65d82444..a1174277 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -203,7 +203,7 @@ func TestUserCommand(t *testing.T) { "--identifier=1", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterIDDelete []*v1.User @@ -245,7 +245,7 @@ func TestUserCommand(t *testing.T) { "--name=newname", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterNameDelete []v1.User @@ -571,7 +571,9 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { IntegrationSkip(t) + //nolint:goconst // test data, not worth extracting user1 := "user1" + //nolint:goconst // test data, not worth extracting user2 := "user2" spec := ScenarioSpec{ @@ -829,7 +831,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -907,7 +909,7 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - assert.NoError(t, err) + require.NoError(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } @@ -952,7 +954,7 @@ func TestApiKeyCommand(t *testing.T) { "--prefix", listedAPIKeys[0].GetPrefix(), }) - assert.NoError(t, err) + require.NoError(t, err) var listedAPIKeysAfterDelete []v1.ApiKey @@ -1071,7 +1073,7 @@ func TestNodeCommand(t *testing.T) { } nodes := make([]*v1.Node, len(regIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1089,7 +1091,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1156,7 +1158,7 @@ func TestNodeCommand(t *testing.T) { } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range otherUserRegIDs { _, err := headscale.Execute( @@ -1174,7 +1176,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1281,7 +1283,7 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - assert.NoError(t, err) + require.NoError(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node @@ -1348,7 +1350,7 @@ func TestNodeExpireCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1411,7 +1413,7 @@ func TestNodeExpireCommand(t *testing.T) { strconv.FormatUint(listAll[idx].GetId(), 10), }, ) - assert.NoError(t, err) + require.NoError(t, err) } var listAllAfterExpiry []v1.Node @@ -1467,7 +1469,7 @@ func TestNodeRenameCommand(t *testing.T) { } nodes := make([]*v1.Node, len(regIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1549,7 +1551,7 @@ func TestNodeRenameCommand(t *testing.T) { fmt.Sprintf("newnode-%d", idx+1), }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, res, "Node renamed") } @@ -1590,7 +1592,7 @@ func TestNodeRenameCommand(t *testing.T) { strings.Repeat("t", 64), }, ) - assert.ErrorContains(t, err, "must not exceed 63 characters") + require.ErrorContains(t, err, "must not exceed 63 characters") var listAllAfterRenameAttempt []v1.Node @@ -1658,7 +1660,7 @@ func TestPolicyCommand(t *testing.T) { }, } - pBytes, _ := json.Marshal(p) + pBytes, _ := json.Marshal(p) //nolint:errchkjson policyFilePath := "/etc/headscale/policy.json" @@ -1745,7 +1747,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { }, } - pBytes, _ := json.Marshal(p) + pBytes, _ := json.Marshal(p) //nolint:errchkjson policyFilePath := "/etc/headscale/policy.json" @@ -1763,7 +1765,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, `invalid action "unknown-action"`) + require.ErrorContains(t, err, `invalid ACL action: "unknown-action"`) // The new policy was invalid, the old one should still be in place, which // is none. diff --git a/integration/control.go b/integration/control.go index 58a061e3..f390d080 100644 --- a/integration/control.go +++ b/integration/control.go @@ -15,8 +15,8 @@ import ( type ControlServer interface { Shutdown() (string, string, error) - SaveLog(string) (string, string, error) - SaveProfile(string) error + SaveLog(path string) (string, string, error) + SaveProfile(path string) error Execute(command []string) (string, error) WriteFile(path string, content []byte) error ConnectToNetwork(network *dockertest.Network) error @@ -35,12 +35,12 @@ type ControlServer interface { ListUsers() ([]*v1.User, error) MapUsers() (map[string]*v1.User, error) DeleteUser(userID uint64) error - ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) + ApproveRoutes(nodeID uint64, routes []netip.Prefix) (*v1.Node, error) SetNodeTags(nodeID uint64, tags []string) error GetCert() []byte GetHostname() string GetIPInNetwork(network *dockertest.Network) string - SetPolicy(*policyv2.Policy) error + SetPolicy(pol *policyv2.Policy) error GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) PrimaryRoutes() (*routes.DebugRoutes, error) DebugBatcher() (*hscontrol.DebugBatcherInfo, error) diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 60260bb1..1cc87de3 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -25,6 +25,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Generate random hostname for the headscale instance hash, err := util.GenerateRandomStringDNSSafe(6) require.NoError(t, err) + testName := "derpverify" hostname := fmt.Sprintf("hs-%s-%s", testName, hash) @@ -40,6 +41,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -104,13 +106,16 @@ func DERPVerify( defer c.Close() var result error - if err := c.Connect(t.Context()); err != nil { + + err := c.Connect(t.Context()) + if err != nil { result = fmt.Errorf("client Connect: %w", err) } - if m, err := c.Recv(); err != nil { + + if m, err := c.Recv(); err != nil { //nolint:noinlineerr result = fmt.Errorf("client first Recv: %w", err) } else if v, ok := m.(derp.ServerInfoMessage); !ok { - result = fmt.Errorf("client first Recv was unexpected type %T", v) + result = fmt.Errorf("client first Recv was unexpected type %T", v) //nolint:err113 } if expectSuccess && result != nil { diff --git a/integration/dns_test.go b/integration/dns_test.go index e937a421..648f2049 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -86,14 +86,13 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { const erPath = "/tmp/extra_records.json" - extraRecords := []tailcfg.DNSRecord{ - { - Name: "test.myvpn.example.com", - Type: "A", - Value: "6.6.6.6", - }, - } - b, _ := json.Marshal(extraRecords) + extraRecords := make([]tailcfg.DNSRecord, 0, 2) + extraRecords = append(extraRecords, tailcfg.DNSRecord{ + Name: "test.myvpn.example.com", + Type: "A", + Value: "6.6.6.6", + }) + b, _ := json.Marshal(extraRecords) //nolint:errchkjson err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithPackages("python3", "curl", "bind-tools"), @@ -133,7 +132,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { require.NoError(t, err) // Write the file directly into place from the docker API. - b0, _ := json.Marshal([]tailcfg.DNSRecord{ + b0, _ := json.Marshal([]tailcfg.DNSRecord{ //nolint:errchkjson { Name: "docker.myvpn.example.com", Type: "A", @@ -155,7 +154,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { Type: "A", Value: "7.7.7.7", }) - b2, _ := json.Marshal(extraRecords) + b2, _ := json.Marshal(extraRecords) //nolint:errchkjson err = hs.WriteFile(erPath+"2", b2) require.NoError(t, err) @@ -169,7 +168,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Write a new file and copy it to the path to ensure the reload // works when a file is copied into place. - b3, _ := json.Marshal([]tailcfg.DNSRecord{ + b3, _ := json.Marshal([]tailcfg.DNSRecord{ //nolint:errchkjson { Name: "copy.myvpn.example.com", Type: "A", @@ -187,7 +186,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { } // Write in place to ensure pipe like behaviour works - b4, _ := json.Marshal([]tailcfg.DNSRecord{ + b4, _ := json.Marshal([]tailcfg.DNSRecord{ //nolint:errchkjson { Name: "docker.myvpn.example.com", Type: "A", diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index c0c57a3e..88b2712c 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -34,6 +34,7 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { if opts.Labels == nil { opts.Labels = make(map[string]string) } + opts.Labels["hi.run-id"] = runID opts.Labels["hi.test-type"] = testType } diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index b09e0d40..7f1d0efb 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -38,9 +38,10 @@ type buffer struct { // Write appends the contents of p to the buffer, growing the buffer as needed. It returns // the number of bytes written. -func (b *buffer) Write(p []byte) (n int, err error) { +func (b *buffer) Write(p []byte) (int, error) { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.Write(p) } @@ -49,6 +50,7 @@ func (b *buffer) Write(p []byte) (n int, err error) { func (b *buffer) String() string { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.String() } @@ -66,7 +68,8 @@ func ExecuteCommand( } for _, opt := range options { - if err := opt(&execConfig); err != nil { + err := opt(&execConfig) + if err != nil { return "", "", fmt.Errorf("execute-command/options: %w", err) } } @@ -105,7 +108,6 @@ func ExecuteCommand( // log.Println("Command: ", cmd) // log.Println("stdout: ", stdout.String()) // log.Println("stderr: ", stderr.String()) - return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) } diff --git a/integration/dockertestutil/logs.go b/integration/dockertestutil/logs.go index 7d104e43..3cd3b7a1 100644 --- a/integration/dockertestutil/logs.go +++ b/integration/dockertestutil/logs.go @@ -47,6 +47,7 @@ func SaveLog( } var stdout, stderr bytes.Buffer + err = WriteLog(pool, resource, &stdout, &stderr) if err != nil { return "", "", err @@ -55,6 +56,7 @@ func SaveLog( log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath) stdoutPath := path.Join(basePath, resource.Container.Name+".stdout.log") + err = os.WriteFile( stdoutPath, stdout.Bytes(), @@ -65,6 +67,7 @@ func SaveLog( } stderrPath := path.Join(basePath, resource.Container.Name+".stderr.log") + err = os.WriteFile( stderrPath, stderr.Bytes(), diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 42483247..ab049abf 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -18,8 +18,9 @@ func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Ne if err != nil { return nil, fmt.Errorf("looking up network names: %w", err) } + if len(networks) == 0 { - if _, err := pool.CreateNetwork(name); err == nil { + if _, err := pool.CreateNetwork(name); err == nil { //nolint:noinlineerr // intentional inline check // Create does not give us an updated version of the resource, so we need to // get it again. networks, err := pool.NetworksByName(name) @@ -90,6 +91,7 @@ func RandomFreeHostPort() (int, error) { // CleanUnreferencedNetworks removes networks that are not referenced by any containers. func CleanUnreferencedNetworks(pool *dockertest.Pool) error { filter := "name=hs-" + networks, err := pool.NetworksByName(filter) if err != nil { return fmt.Errorf("getting networks by filter %q: %w", filter, err) @@ -122,6 +124,7 @@ func CleanImagesInCI(pool *dockertest.Pool) error { } removedCount := 0 + for _, image := range images { // Only remove dangling (untagged) images to avoid forcing rebuilds // Dangling images have no RepoTags or only have ":" diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index 36a120bc..e25e2bc4 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -159,10 +159,12 @@ func New( } else { hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) } + tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname) if err != nil { return nil, fmt.Errorf("creating certificates for headscale test: %w", err) } + dsic := &DERPServerInContainer{ version: version, hostname: hostname, @@ -185,6 +187,7 @@ func New( fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort) fmt.Fprintf(&cmdArgs, " --stun=true") fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort) + if dsic.withVerifyClientURL != "" { fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL) } @@ -214,11 +217,13 @@ func New( } var container *dockertest.Resource + buildOptions := &dockertest.BuildOptions{ Dockerfile: "Dockerfile.derper", ContextDir: dockerContextPath, BuildArgs: []docker.BuildArg{}, } + switch version { case "head": buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ @@ -249,6 +254,7 @@ func New( err, ) } + log.Printf("Created %s container\n", hostname) dsic.container = container @@ -259,12 +265,14 @@ func New( return nil, fmt.Errorf("writing TLS certificate to container: %w", err) } } + if len(dsic.tlsCert) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert) if err != nil { return nil, fmt.Errorf("writing TLS certificate to container: %w", err) } } + if len(dsic.tlsKey) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey) if err != nil { diff --git a/integration/helpers.go b/integration/helpers.go index 46e571ae..fa96bcac 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -3,9 +3,12 @@ package integration import ( "bufio" "bytes" + "errors" "fmt" "io" + "maps" "net/netip" + "slices" "strconv" "strings" "sync" @@ -23,8 +26,6 @@ import ( "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "tailscale.com/tailcfg" "tailscale.com/types/ptr" ) @@ -46,9 +47,16 @@ const ( // TimestampFormatRunID is used for generating unique run identifiers // Format: "20060102-150405" provides compact date-time for file/directory names. TimestampFormatRunID = "20060102-150405" + + // stateOnline is the string representation for online state in logs. + stateOnline = "online" + // stateOffline is the string representation for offline state in logs. + stateOffline = "offline" ) -// NodeSystemStatus represents the status of a node across different systems +var errNoNewClientFound = errors.New("no new client found") + +// NodeSystemStatus represents the status of a node across different systems. type NodeSystemStatus struct { Batcher bool BatcherConnCount int @@ -105,7 +113,7 @@ func requireNoErrLogout(t *testing.T, err error) { require.NoError(t, err, "failed to log out tailscale nodes") } -// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes. func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { t.Helper() @@ -114,8 +122,10 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod status := client.MustStatus() nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) } + return expectedNodes } @@ -149,15 +159,17 @@ func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNode } // requireAllClientsOnline validates that all nodes are online/offline across all headscale systems -// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems. func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() startTime := time.Now() - stateStr := "offline" + + stateStr := stateOffline if expectedOnline { - stateStr = "online" + stateStr = stateOnline } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) if expectedOnline { @@ -165,22 +177,26 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode requireAllClientsOnlineWithSingleTimeout(t, headscale, expectedNodes, expectedOnline, message, timeout) } else { // For offline validation, use staged approach with component-specific timeouts - requireAllClientsOfflineStaged(t, headscale, expectedNodes, message, timeout) + requireAllClientsOfflineStaged(t, headscale, expectedNodes) } endTime := time.Now() t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) } -// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state. +// +//nolint:gocyclo // complex validation with multiple node states func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { // Get batcher state debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } @@ -188,6 +204,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get map responses mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -195,6 +212,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -265,6 +283,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if id == nodeID { continue // Skip self-references } + expectedPeerMaps++ if online, exists := peerMap[nodeID]; exists && online { @@ -279,6 +298,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer } } } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") // Update status with map response data @@ -302,10 +322,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Verify all systems show nodes in expected state and report failures allMatch := true + var failureReport strings.Builder - ids := types.NodeIDs(maps.Keys(nodeStatus)) + ids := types.NodeIDs(slices.AppendSeq(make([]types.NodeID, 0, len(nodeStatus)), maps.Keys(nodeStatus))) slices.Sort(ids) + for _, nodeID := range ids { status := nodeStatus[nodeID] systemsMatch := (status.Batcher == expectedOnline) && @@ -314,10 +336,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if !systemsMatch { allMatch = false - stateStr := "offline" + + stateStr := stateOffline if expectedOnline { - stateStr = "online" + stateStr = stateOnline } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) @@ -332,6 +356,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer t.Logf("Previous report:\n%s", prevReport) t.Logf("Current report:\n%s", failureReport.String()) t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() } @@ -341,16 +366,17 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer assert.Fail(c, failureReport.String()) } - stateStr := "offline" + stateStr := stateOffline if expectedOnline { - stateStr = "online" + stateStr = stateOnline } - assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) + + assert.True(c, allMatch, "Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr) }, timeout, 2*time.Second, message) } -// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components -func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { +// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components. +func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() // Stage 1: Verify batcher disconnection (should be immediate) @@ -358,18 +384,22 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } allBatcherOffline := true + for _, nodeID := range expectedNodes { nodeIDStr := fmt.Sprintf("%d", nodeID) if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) } } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") }, 15*time.Second, 1*time.Second, "batcher disconnection validation") @@ -378,20 +408,24 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } allNodeStoreOffline := true + for _, nodeID := range expectedNodes { if node, exists := nodeStore[nodeID]; exists { isOnline := node.IsOnline != nil && *node.IsOnline if isOnline { allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) } } } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") }, 20*time.Second, 1*time.Second, "nodestore offline validation") @@ -400,6 +434,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -412,7 +447,8 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec for nodeID := range onlineMap { if slices.Contains(expectedNodes, nodeID) { allMapResponsesOffline = false - assert.False(c, true, "Node %d should not appear in map responses", nodeID) + + assert.Fail(c, fmt.Sprintf("Node %d should not appear in map responses", nodeID)) } } } else { @@ -422,13 +458,16 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec if id == nodeID { continue // Skip self-references } + if online, exists := peerMap[nodeID]; exists && online { allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) } } } } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") }, 60*time.Second, 2*time.Second, "map response propagation validation") @@ -438,6 +477,8 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec // requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database // and a valid DERP server based on the NetInfo. This function follows the pattern of // requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. +// +//nolint:unparam // timeout is configurable for flexibility even though callers currently use same value func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { t.Helper() @@ -448,6 +489,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -462,12 +504,14 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe for _, nodeID := range expectedNodes { node, exists := nodeStore[nodeID] assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { continue } // Validate that the node has Hostinfo assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) continue @@ -475,6 +519,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has NetInfo assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) continue @@ -482,7 +527,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has a valid DERP server (PreferredDERP should be > 0) preferredDERP := node.Hostinfo.NetInfo.PreferredDERP - assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) + assert.Positive(c, preferredDERP, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) } @@ -496,6 +541,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // assertLastSeenSet validates that a node has a non-nil LastSeen timestamp. // Critical for ensuring node activity tracking is functioning properly. func assertLastSeenSet(t *testing.T, node *v1.Node) { + t.Helper() assert.NotNil(t, node) assert.NotNil(t, node.GetLastSeen()) } @@ -514,7 +560,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { for _, client := range clients { status, err := client.Status() - assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) + assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) //nolint:testifylint // assert.TestingT interface assert.Equal(t, "NeedsLogin", status.BackendState, "client %s should be logged out", client.Hostname()) } @@ -523,8 +569,11 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { // pingAllHelper performs ping tests between all clients and addresses, returning success count. // This is used to validate network connectivity in integration tests. // Returns the total number of successful ping operations. +// +//nolint:unparam // opts is variadic for extensibility even though callers currently don't pass options func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { t.Helper() + success := 0 for _, client := range clients { @@ -546,6 +595,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts // for validating NAT traversal and relay functionality. Returns success count. func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { t.Helper() + success := 0 for _, client := range clients { @@ -596,6 +646,8 @@ func isSelfClient(client TailscaleClient, addr string) bool { // assertClientsState validates the status and netmap of a list of clients for general connectivity. // Runs parallel validation of status, netcheck, and netmap for all clients to ensure // they have proper network configuration for all-to-all connectivity tests. +// +//nolint:unused func assertClientsState(t *testing.T, clients []TailscaleClient) { t.Helper() @@ -603,9 +655,12 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) { for _, client := range clients { wg.Add(1) + c := client // Avoid loop pointer + go func() { defer wg.Done() + assertValidStatus(t, c) assertValidNetcheck(t, c) assertValidNetmap(t, c) @@ -620,6 +675,8 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) { // Checks self node and all peers for essential networking data including hostinfo, addresses, // endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56. // This test is not suitable for ACL/partial connection tests. +// +//nolint:unused func assertValidNetmap(t *testing.T, client TailscaleClient) { t.Helper() @@ -636,6 +693,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { assert.NoError(c, err, "getting netmap for %q", client.Hostname()) assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) } @@ -650,10 +708,11 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) for _, peer := range netmap.Peers { - assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) + assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) //nolint:staticcheck // SA1019: testing legacy field assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) @@ -680,8 +739,11 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { // assertValidStatus validates that a client's status has all required fields for proper operation. // Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints, // and network map presence. This test is not suitable for ACL/partial connection tests. +// +//nolint:unused func assertValidStatus(t *testing.T, client TailscaleClient) { t.Helper() + status, err := client.Status(true) if err != nil { t.Fatalf("getting status for %q: %s", client.Hostname(), err) @@ -737,8 +799,11 @@ func assertValidStatus(t *testing.T, client TailscaleClient) { // assertValidNetcheck validates that a client has a proper DERP relay configured. // Ensures the client has discovered and selected a DERP server for relay functionality, // which is essential for NAT traversal and connectivity in restricted networks. +// +//nolint:unused func assertValidNetcheck(t *testing.T, client TailscaleClient) { t.Helper() + report, err := client.Netcheck() if err != nil { t.Fatalf("getting status for %q: %s", client.Hostname(), err) @@ -764,7 +829,7 @@ func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []stri } if !strings.Contains(stdout, contains) { - return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) + return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) //nolint:err113 } return struct{}{}, nil @@ -793,6 +858,7 @@ func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { t.Helper() buf := &bytes.Buffer{} + err := client.WriteLogs(buf, buf) if err != nil { t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) @@ -816,6 +882,7 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) scanner := bufio.NewScanner(in) { const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB + buff := make([]byte, logBufferInitialSize) scanner.Buffer(buff, len(buff)) scanner.Split(bufio.ScanLines) @@ -885,6 +952,8 @@ func usernameOwner(name string) policyv2.Owner { // groupOwner returns a Group as an Owner for use in TagOwners policies. // Specifies which groups can assign and manage specific tags in ACL configurations. +// +//nolint:unused func groupOwner(name string) policyv2.Owner { return ptr.To(policyv2.Group(name)) } @@ -933,7 +1002,7 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { } } - return nil, fmt.Errorf("user %s not found", username) + return nil, fmt.Errorf("user %s not found", username) //nolint:err113 } // FindNewClient finds a client that is in the new list but not in the original list. @@ -942,17 +1011,20 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) { for _, client := range updated { isOriginal := false + for _, origClient := range original { if client.Hostname() == origClient.Hostname() { isOriginal = true break } } + if !isOriginal { return client, nil } } - return nil, fmt.Errorf("no new client found") + + return nil, errNoNewClientFound } // AddAndLoginClient adds a new tailscale client to a user and logs it in. @@ -960,7 +1032,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) // 1. Creating a new node // 2. Finding the new node in the client list // 3. Getting the user to create a preauth key -// 4. Logging in the new node +// 4. Logging in the new node. func (s *Scenario) AddAndLoginClient( t *testing.T, username string, @@ -992,7 +1064,7 @@ func (s *Scenario) AddAndLoginClient( } if len(updatedClients) != len(originalClients)+1 { - return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients)) + return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients)) //nolint:err113 } newClient, err = FindNewClient(originalClients, updatedClients) @@ -1038,5 +1110,6 @@ func (s *Scenario) MustAddAndLoginClient( client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...) require.NoError(t, err) + return client } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index f1b9feef..3ef4d5d4 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -4,6 +4,7 @@ import ( "archive/tar" "bytes" "cmp" + "context" "crypto/tls" "encoding/json" "errors" @@ -11,6 +12,7 @@ import ( "io" "log" "maps" + "net" "net/http" "net/netip" "os" @@ -46,6 +48,7 @@ const ( tlsKeyPath = "/etc/headscale/tls.key" headscaleDefaultPort = 8080 IntegrationTestDockerFileName = "Dockerfile.integration" + defaultDirPerm = 0o755 ) var ( @@ -198,7 +201,7 @@ func WithPostgres() Option { } } -// WithPolicy sets the policy mode for headscale. +// WithPolicyMode sets the policy mode for headscale. func WithPolicyMode(mode types.PolicyMode) Option { return func(hsic *HeadscaleInContainer) { hsic.policyMode = mode @@ -217,6 +220,8 @@ func WithIPAllocationStrategy(strategy types.IPAllocationStrategy) Option { // and only use the embedded DERP server. // It requires WithTLS and WithHostnameAsServerURL to be // set. +// +//nolint:goconst // env var values like "true" and "headscale" are clearer inline func WithEmbeddedDERPServerOnly() Option { return func(hsic *HeadscaleInContainer) { hsic.env["HEADSCALE_DERP_URLS"] = "" @@ -321,6 +326,8 @@ func (hsic *HeadscaleInContainer) buildEntrypoint() []string { } // New returns a new HeadscaleInContainer instance. +// +//nolint:gocyclo // complex container setup with many options func New( pool *dockertest.Pool, networks []*dockertest.Network, @@ -548,6 +555,7 @@ func New( return nil, fmt.Errorf("starting headscale container: %w\n\nUnable to get diagnostic build output (command may have failed silently)", err) } } + log.Printf("Created %s container\n", hsic.hostname) hsic.container = container @@ -595,7 +603,8 @@ func New( } for _, f := range hsic.filesInContainer { - if err := hsic.WriteFile(f.path, f.contents); err != nil { + err := hsic.WriteFile(f.path, f.contents) + if err != nil { return nil, fmt.Errorf("writing %q: %w", f.path, err) } } @@ -678,7 +687,7 @@ func (t *HeadscaleInContainer) Shutdown() (string, string, error) { // Cleanup postgres container if enabled. if t.postgres { - t.pool.Purge(t.pgContainer) + _ = t.pool.Purge(t.pgContainer) } return stdoutPath, stderrPath, t.pool.Purge(t.container) @@ -697,16 +706,23 @@ func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) { } func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { - resp, err := http.Get(fmt.Sprintf("http://%s:9090/metrics", t.hostname)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://"+net.JoinHostPort(t.hostname, "9090")+"/metrics", nil) + if err != nil { + return fmt.Errorf("creating metrics request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) if err != nil { return fmt.Errorf("getting metrics: %w", err) } defer resp.Body.Close() + out, err := os.Create(savePath) if err != nil { return fmt.Errorf("creating file for metrics: %w", err) } defer out.Close() + _, err = io.Copy(out, resp.Body) if err != nil { return fmt.Errorf("copy response to file: %w", err) @@ -717,20 +733,21 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { // extractTarToDirectory extracts a tar archive to a directory. func extractTarToDirectory(tarData []byte, targetDir string) error { - if err := os.MkdirAll(targetDir, 0o755); err != nil { + err := os.MkdirAll(targetDir, defaultDirPerm) + if err != nil { return fmt.Errorf("creating directory %s: %w", targetDir, err) } - tarReader := tar.NewReader(bytes.NewReader(tarData)) - // Find the top-level directory to strip var topLevelDir string + firstPass := tar.NewReader(bytes.NewReader(tarData)) for { header, err := firstPass.Next() if err == io.EOF { break } + if err != nil { return fmt.Errorf("reading tar header: %w", err) } @@ -741,12 +758,13 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { } } - tarReader = tar.NewReader(bytes.NewReader(tarData)) + tarReader := tar.NewReader(bytes.NewReader(tarData)) for { header, err := tarReader.Next() if err == io.EOF { break } + if err != nil { return fmt.Errorf("reading tar header: %w", err) } @@ -775,12 +793,15 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { switch header.Typeflag { case tar.TypeDir: // Create directory - if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + //nolint:gosec // G115: header.Mode is trusted from tar archive + err := os.MkdirAll(targetPath, os.FileMode(header.Mode)) + if err != nil { return fmt.Errorf("creating directory %s: %w", targetPath, err) } case tar.TypeReg: // Ensure parent directories exist - if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + err := os.MkdirAll(filepath.Dir(targetPath), defaultDirPerm) + if err != nil { return fmt.Errorf("creating parent directories for %s: %w", targetPath, err) } @@ -790,14 +811,15 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { return fmt.Errorf("creating file %s: %w", targetPath, err) } - if _, err := io.Copy(outFile, tarReader); err != nil { + if _, err := io.Copy(outFile, tarReader); err != nil { //nolint:gosec,noinlineerr // trusted tar from test container outFile.Close() return fmt.Errorf("copying file contents: %w", err) } + outFile.Close() // Set file permissions - if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { + if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { //nolint:gosec,noinlineerr // safe mode from tar header return fmt.Errorf("setting file permissions: %w", err) } } @@ -844,10 +866,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if the database file exists and has a schema dbPath := "/tmp/integration_test_db.sqlite3" + fileInfo, err := t.Execute([]string{"ls", "-la", dbPath}) if err != nil { return fmt.Errorf("database file does not exist at %s: %w", dbPath, err) } + log.Printf("Database file info: %s", fileInfo) // Check if the database has any tables (schema) @@ -857,7 +881,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } if strings.TrimSpace(schemaCheck) == "" { - return errors.New("database file exists but has no schema (empty database)") + return errors.New("database file exists but has no schema (empty database)") //nolint:err113 } tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") @@ -872,6 +896,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { if err == io.EOF { break } + if err != nil { return fmt.Errorf("reading tar header: %w", err) } @@ -886,13 +911,15 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Extract the first regular file we find if header.Typeflag == tar.TypeReg { dbPath := path.Join(savePath, t.hostname+".db") + outFile, err := os.Create(dbPath) if err != nil { return fmt.Errorf("creating database file: %w", err) } - written, err := io.Copy(outFile, tarReader) + written, err := io.Copy(outFile, tarReader) //nolint:gosec // trusted tar from test container outFile.Close() + if err != nil { return fmt.Errorf("copying database file: %w", err) } @@ -906,7 +933,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if we actually wrote something if written == 0 { - return fmt.Errorf( + return fmt.Errorf( //nolint:err113 "database file is empty (size: %d, header size: %d)", written, header.Size, @@ -917,7 +944,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } } - return errors.New("no regular file found in database tar archive") + return errors.New("no regular file found in database tar archive") //nolint:err113 } // Execute runs a command inside the Headscale container and returns the @@ -1059,6 +1086,7 @@ func (t *HeadscaleInContainer) CreateUser( } var u v1.User + err = json.Unmarshal([]byte(result), &u) if err != nil { return nil, fmt.Errorf("unmarshalling user: %w", err) @@ -1195,6 +1223,7 @@ func (t *HeadscaleInContainer) ListNodes( users ...string, ) ([]*v1.Node, error) { var ret []*v1.Node + execUnmarshal := func(command []string) error { result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -1206,6 +1235,7 @@ func (t *HeadscaleInContainer) ListNodes( } var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) if err != nil { return fmt.Errorf("unmarshalling nodes: %w", err) @@ -1245,7 +1275,7 @@ func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error { "nodes", "delete", "--identifier", - fmt.Sprintf("%d", nodeID), + strconv.FormatUint(nodeID, 10), "--output", "json", "--force", @@ -1309,6 +1339,7 @@ func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { } var users []*v1.User + err = json.Unmarshal([]byte(result), &users) if err != nil { return nil, fmt.Errorf("unmarshalling nodes: %w", err) @@ -1439,6 +1470,7 @@ func (h *HeadscaleInContainer) PID() (int, error) { if pidInt == 1 { continue } + pids = append(pids, pidInt) } @@ -1494,6 +1526,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( } var node *v1.Node + err = json.Unmarshal([]byte(result), &node) if err != nil { return nil, fmt.Errorf("unmarshalling node response: %q, error: %w", result, err) @@ -1569,7 +1602,7 @@ func (t *HeadscaleInContainer) GetAllMapReponses() (map[types.NodeID][]tailcfg.M } var res map[types.NodeID][]tailcfg.MapResponse - if err := json.Unmarshal([]byte(result), &res); err != nil { + if err := json.Unmarshal([]byte(result), &res); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("decoding routes response: %w", err) } @@ -1589,7 +1622,7 @@ func (t *HeadscaleInContainer) PrimaryRoutes() (*routes.DebugRoutes, error) { } var debugRoutes routes.DebugRoutes - if err := json.Unmarshal([]byte(result), &debugRoutes); err != nil { + if err := json.Unmarshal([]byte(result), &debugRoutes); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("decoding routes response: %w", err) } @@ -1609,7 +1642,7 @@ func (t *HeadscaleInContainer) DebugBatcher() (*hscontrol.DebugBatcherInfo, erro } var debugInfo hscontrol.DebugBatcherInfo - if err := json.Unmarshal([]byte(result), &debugInfo); err != nil { + if err := json.Unmarshal([]byte(result), &debugInfo); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("decoding batcher debug response: %w", err) } @@ -1629,7 +1662,7 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er } var nodeStore map[types.NodeID]types.Node - if err := json.Unmarshal([]byte(result), &nodeStore); err != nil { + if err := json.Unmarshal([]byte(result), &nodeStore); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("decoding nodestore debug response: %w", err) } @@ -1649,7 +1682,7 @@ func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) { } var filterRules []tailcfg.FilterRule - if err := json.Unmarshal([]byte(result), &filterRules); err != nil { + if err := json.Unmarshal([]byte(result), &filterRules); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("decoding filter response: %w", err) } diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index d28b289b..2a155619 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -28,6 +28,7 @@ func PeerSyncTimeout() time.Duration { if util.IsCI() { return 120 * time.Second } + return 60 * time.Second } @@ -205,25 +206,27 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type res := make(map[types.NodeID]map[types.NodeID]bool) for nid, mrs := range all { res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { for _, peer := range mr.Peers { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec // safe conversion for peer ID } } for _, peer := range mr.PeersChanged { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec // safe conversion for peer ID } } for _, peer := range mr.PeersChangedPatch { if peer.Online != nil { - res[nid][types.NodeID(peer.NodeID)] = *peer.Online + res[nid][types.NodeID(peer.NodeID)] = *peer.Online //nolint:gosec // safe conversion for peer ID } } } } + return res } diff --git a/integration/route_test.go b/integration/route_test.go index 828dc003..1b35a224 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -49,6 +49,7 @@ func TestEnablingRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -91,6 +92,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -127,6 +129,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route approvals to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -149,9 +152,11 @@ func TestEnablingRoutes(t *testing.T) { assert.NotNil(c, peerStatus.PrimaryRoutes) assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) } } @@ -172,6 +177,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route state changes to propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -214,6 +220,7 @@ func TestEnablingRoutes(t *testing.T) { } } +//nolint:gocyclo // complex HA failover test scenario func TestHASubnetRouterFailover(t *testing.T) { IntegrationSkip(t) @@ -271,6 +278,7 @@ func TestHASubnetRouterFailover(t *testing.T) { prefp, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) + pref := *prefp t.Logf("usernet1 prefix: %s", pref.String()) @@ -310,6 +318,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String()) t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String()) t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time") + for _, client := range allClients[:3] { command := []string{ "tailscale", @@ -325,6 +334,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Wait for route configuration changes after advertising routes var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -361,13 +371,15 @@ func TestHASubnetRouterFailover(t *testing.T) { ) // Helper function to check test failure and print route map if needed - checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) { + checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) { //nolint:thelper if t.Failed() { t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat)) + status, err := client.Status() if err == nil { printCurrentRouteMap(t, xmaps.Values(status.Peer)...) } + t.FailNow() } } @@ -386,6 +398,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String()) t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes") t.Logf(" Expected: Client can access webservice through router 1 only") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -456,10 +469,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1") @@ -483,6 +498,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - stability preferred)") t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter2.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -494,6 +510,7 @@ func TestHASubnetRouterFailover(t *testing.T) { nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 6) + if len(nodes) >= 3 { requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) @@ -569,10 +586,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode") @@ -598,6 +617,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY") t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -672,12 +692,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") var expectedIP netip.Addr + for _, ip := range ips { if ip.Is4() { expectedIP = ip break } } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address") assertTracerouteViaIPWithCollect(c, tr, expectedIP) @@ -705,6 +727,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 (%s) should automatically become new PRIMARY", subRouter2.Hostname()) t.Logf(" Expected: Router 3 remains STANDBY") t.Logf(" Expected: Traffic seamlessly fails over to router 2") + err = subRouter1.Down() require.NoError(t, err) @@ -754,10 +777,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover") @@ -783,6 +808,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 (%s) should become new PRIMARY (last remaining router)", subRouter3.Hostname()) t.Logf(" Expected: With only 1 router left, HA is effectively disabled") t.Logf(" Expected: Traffic continues through router 3") + err = subRouter2.Down() require.NoError(t, err) @@ -825,10 +851,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover") @@ -853,6 +881,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)") t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)") t.Logf(" Expected: HA is restored with 2 routers available") + err = subRouter1.Up() require.NoError(t, err) @@ -902,10 +931,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery") @@ -932,6 +963,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname()) t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname()) t.Logf(" Expected: Full HA restored with all 3 routers online") + err = subRouter2.Up() require.NoError(t, err) @@ -982,10 +1014,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery") @@ -1067,10 +1101,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable") @@ -1153,10 +1189,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable") @@ -1182,6 +1220,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname()) t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname()) t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY") + r1Node := MustFindNode(subRouter1.Hostname(), nodes) _, err = headscale.ApproveRoutes( r1Node.GetId(), @@ -1237,10 +1276,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable") @@ -1266,6 +1307,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname()) t.Logf(" Expected: Routers 1 & 3 are both STANDBY") t.Logf(" Expected: Full HA restored with all 3 routers available") + r3Node := MustFindNode(subRouter3.Hostname(), nodes) _, err = headscale.ApproveRoutes( r3Node.GetId(), @@ -1315,6 +1357,7 @@ func TestSubnetRouteACL(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1362,6 +1405,7 @@ func TestSubnetRouteACL(t *testing.T) { sort.SliceStable(allClients, func(i, j int) bool { statusI := allClients[i].MustStatus() statusJ := allClients[j].MustStatus() + return statusI.Self.ID < statusJ.Self.ID }) @@ -1391,15 +1435,20 @@ func TestSubnetRouteACL(t *testing.T) { // Wait for route advertisements to propagate to the server var nodes []*v1.Node + require.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) // Find the node that should have the route by checking node IDs - var routeNode *v1.Node - var otherNode *v1.Node + var ( + routeNode *v1.Node + otherNode *v1.Node + ) + for _, node := range nodes { nodeIDStr := strconv.FormatUint(node.GetId(), 10) if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute { @@ -1462,6 +1511,7 @@ func TestSubnetRouteACL(t *testing.T) { srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + if srs1PeerStatus == nil { return } @@ -1552,7 +1602,7 @@ func TestSubnetRouteACL(t *testing.T) { func TestEnablingExitRoutes(t *testing.T) { IntegrationSkip(t) - user := "user2" + user := "user2" //nolint:goconst // test-specific value, not related to userToDelete constant spec := ScenarioSpec{ NodesPerUser: 2, @@ -1560,6 +1610,7 @@ func TestEnablingExitRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) @@ -1581,8 +1632,10 @@ func TestEnablingExitRoutes(t *testing.T) { requireNoErrSync(t, err) var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -1640,6 +1693,7 @@ func TestEnablingExitRoutes(t *testing.T) { peerStatus := status.Peer[peerKey] assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4) assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) @@ -1670,6 +1724,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1700,10 +1755,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { if s.User[s.Self.UserID].LoginName == "user1@test.no" { user1c = c } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { user2c = c } } + require.NotNil(t, user1c) require.NotNil(t, user2c) @@ -1720,6 +1777,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) assert.Len(ct, nodes, 2) @@ -1750,6 +1808,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Wait for route state changes to propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -1767,6 +1826,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref}) } }, 10*time.Second, 500*time.Millisecond, "routes should be visible to client") @@ -1793,10 +1853,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := user2c.Traceroute(webip) assert.NoError(c, err) + ip, err := user1c.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for user1c") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router") } @@ -1817,6 +1879,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1844,10 +1907,12 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { if s.User[s.Self.UserID].LoginName == "user1@test.no" { user1c = c } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { user2c = c } } + require.NotNil(t, user1c) require.NotNil(t, user2c) @@ -1864,6 +1929,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) assert.Len(ct, nodes, 2) @@ -1946,6 +2012,7 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { return node } } + panic("node not found") } @@ -1965,6 +2032,8 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { // - Verify that peers can no longer use node // - Policy is changed back to auto approve route, check that routes already existing is approved. // - Verify that routes can now be seen by peers. +// +//nolint:gocyclo // complex multi-network auto-approve test scenario func TestAutoApproveMultiNetwork(t *testing.T) { IntegrationSkip(t) @@ -2229,10 +2298,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } scenario, err := NewScenario(tt.spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) var nodes []*v1.Node + opts := []hsic.Option{ hsic.WithTestName("autoapprovemulti"), hsic.WithEmbeddedDERPServerOnly(), @@ -2259,7 +2330,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { preAuthKeyTags = []string{tt.approver} if tt.withURL { // For webauth, only user1 can request tags (per tagOwners policy) - webauthTagUser = "user1" + webauthTagUser = "user1" //nolint:goconst // test value, not a constant } } @@ -2288,6 +2359,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the Docker network route to the auto-approvers // Keep existing auto-approvers (like bigRoute) in place var approvers policyv2.AutoApprovers + switch { case strings.HasPrefix(tt.approver, "tag:"): approvers = append(approvers, tagApprover(tt.approver)) @@ -2356,6 +2428,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } else { pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) } + require.NoError(t, err) err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) @@ -2447,11 +2520,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers())) routerPeerFound := false + for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] if peerStatus.ID == routerUsernet1ID.StableID() { routerPeerFound = true + t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v", peerStatus.HostName, peerStatus.ID, @@ -2459,9 +2534,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { peerStatus.PrimaryRoutes) assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2498,10 +2575,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router") @@ -2538,9 +2617,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2560,10 +2641,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change") @@ -2597,6 +2680,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the route back to the auto approver in the policy, the route should // now become available again. var newApprovers policyv2.AutoApprovers + switch { case strings.HasPrefix(tt.approver, "tag:"): newApprovers = append(newApprovers, tagApprover(tt.approver)) @@ -2630,9 +2714,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2652,10 +2738,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval") @@ -2691,11 +2779,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else if peerStatus.ID == "2" { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2733,9 +2823,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2773,6 +2865,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else if peerStatus.ID == "3" { requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) @@ -2791,7 +2884,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) { assert.NotNil(c, tr) assert.True(c, tr.Success) - assert.NoError(c, tr.Err) + assert.NoError(c, tr.Err) //nolint:testifylint // using assert.CollectT assert.NotEmpty(c, tr.Route) // Since we're inside EventuallyWithT, we can't use require.Greater with t // but assert.NotEmpty above ensures len(tr.Route) > 0 @@ -2805,12 +2898,15 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int { } func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) { + t.Helper() t.Logf("== Current routing map ==") slices.SortFunc(routers, SortPeerStatus) + for _, router := range routers { got := filterNonRoutes(router) t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID) t.Logf(" AllowedIPs: %v", got) + if router.PrimaryRoutes != nil { t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice()) } @@ -2823,6 +2919,7 @@ func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix { if tsaddr.IsExitRoute(p) { return true } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) }) } @@ -2874,6 +2971,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -3014,6 +3112,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { // List nodes and verify the router has 3 available routes var err error + nodes, err := headscale.NodesByUser() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -3049,10 +3148,12 @@ func TestSubnetRouteACLFiltering(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := nodeClient.Traceroute(webip) assert.NoError(c, err) + ip, err := routerClient.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerClient") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router") } diff --git a/integration/scenario.go b/integration/scenario.go index e2d96603..dd07f50b 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -96,7 +96,7 @@ type User struct { type Scenario struct { // TODO(kradalby): support multiple headcales for later, currently only // use one. - controlServers *xsync.MapOf[string, ControlServer] + controlServers *xsync.Map[string, ControlServer] derpServers []*dsic.DERPServerInContainer users map[string]*User @@ -169,8 +169,8 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { // Opportunity to clean up unreferenced networks. // This might be a no op, but it is worth a try as we sometime // dont clean up nicely after ourselves. - dockertestutil.CleanUnreferencedNetworks(pool) - dockertestutil.CleanImagesInCI(pool) + _ = dockertestutil.CleanUnreferencedNetworks(pool) + _ = dockertestutil.CleanImagesInCI(pool) if spec.MaxWait == 0 { pool.MaxWait = dockertestMaxWait() @@ -180,7 +180,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { testHashPrefix := "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) s := &Scenario{ - controlServers: xsync.NewMapOf[string, ControlServer](), + controlServers: xsync.NewMap[string, ControlServer](), users: make(map[string]*User), pool: pool, @@ -191,9 +191,11 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { } var userToNetwork map[string]*dockertest.Network + if spec.Networks != nil || len(spec.Networks) != 0 { for name, users := range s.spec.Networks { networkName := testHashPrefix + "-" + name + network, err := s.AddNetwork(networkName) if err != nil { return nil, err @@ -201,8 +203,9 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { for _, user := range users { if n2, ok := userToNetwork[user]; ok { - return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) + return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) //nolint:err113 } + mak.Set(&userToNetwork, user, network) } } @@ -219,6 +222,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if err != nil { return nil, err } + mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc)) } } @@ -230,6 +234,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if spec.OIDCAccessTTL != 0 { ttl = spec.OIDCAccessTTL } + err = s.runMockOIDC(ttl, spec.OIDCUsers) if err != nil { return nil, err @@ -268,13 +273,14 @@ func (s *Scenario) Networks() []*dockertest.Network { if len(s.networks) == 0 { panic("Scenario.Networks called with empty network list") } + return xmaps.Values(s.networks) } func (s *Scenario) Network(name string) (*dockertest.Network, error) { net, ok := s.networks[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("no network named: %s", name) //nolint:err113 } return net, nil @@ -283,11 +289,11 @@ func (s *Scenario) Network(name string) (*dockertest.Network, error) { func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { net, ok := s.networks[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("no network named: %s", name) //nolint:err113 } if len(net.Network.IPAM.Config) == 0 { - return nil, fmt.Errorf("no IPAM config found in network: %s", name) + return nil, fmt.Errorf("no IPAM config found in network: %s", name) //nolint:err113 } pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet) @@ -301,15 +307,17 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { res, ok := s.extraServices[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("no network named: %s", name) //nolint:err113 } return res, nil } func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { - defer dockertestutil.CleanUnreferencedNetworks(s.pool) - defer dockertestutil.CleanImagesInCI(s.pool) + t.Helper() + + defer func() { _ = dockertestutil.CleanUnreferencedNetworks(s.pool) }() + defer func() { _ = dockertestutil.CleanImagesInCI(s.pool) }() s.controlServers.Range(func(_ string, control ControlServer) bool { stdoutPath, stderrPath, err := control.Shutdown() @@ -334,9 +342,11 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { }) s.mu.Lock() + for userName, user := range s.users { for _, client := range user.Clients { log.Printf("removing client %s in user %s", client.Hostname(), userName) + stdoutPath, stderrPath, err := client.Shutdown() if err != nil { log.Printf("tearing down client: %s", err) @@ -353,6 +363,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { } } } + s.mu.Unlock() for _, derp := range s.derpServers { @@ -373,13 +384,16 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { if s.mockOIDC.r != nil { s.mockOIDC.r.Close() - if err := s.mockOIDC.r.Close(); err != nil { + + err := s.mockOIDC.r.Close() + if err != nil { log.Printf("tearing down oidc server: %s", err) } } for _, network := range s.networks { - if err := network.Close(); err != nil { + err := network.Close() + if err != nil { log.Printf("tearing down network: %s", err) } } @@ -395,7 +409,7 @@ func (s *Scenario) Shutdown() { // Users returns the name of all users associated with the Scenario. func (s *Scenario) Users() []string { - users := make([]string, 0) + users := make([]string, 0, len(s.users)) for user := range s.users { users = append(users, user) } @@ -466,7 +480,7 @@ func (s *Scenario) CreatePreAuthKey( reusable bool, ephemeral bool, ) (*v1.PreAuthKey, error) { - if headscale, err := s.Headscale(); err == nil { + if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr key, err := headscale.CreateAuthKey(user, reusable, ephemeral) if err != nil { return nil, fmt.Errorf("creating user: %w", err) @@ -518,7 +532,7 @@ func (s *Scenario) CreatePreAuthKeyWithTags( // CreateUser creates a User to be created in the // Headscale instance on behalf of the Scenario. func (s *Scenario) CreateUser(user string) (*v1.User, error) { - if headscale, err := s.Headscale(); err == nil { + if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr u, err := headscale.CreateUser(user) if err != nil { return nil, fmt.Errorf("creating user: %w", err) @@ -552,6 +566,7 @@ func (s *Scenario) CreateTailscaleNode( s.mu.Lock() defer s.mu.Unlock() + opts = append(opts, tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), @@ -591,6 +606,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) error { if user, ok := s.users[userStr]; ok { var versions []string + for i := range count { version := requestedVersion if requestedVersion == "all" { @@ -600,6 +616,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( version = MustTestVersions[i%len(MustTestVersions)] } } + versions = append(versions, version) headscale, err := s.Headscale() @@ -623,6 +640,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( extraHosts := []string{hostname + ":" + headscaleIP} s.mu.Lock() + opts = append(opts, tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), @@ -639,6 +657,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( opts..., ) s.mu.Unlock() + if err != nil { return fmt.Errorf( "creating tailscale node: %w", @@ -656,13 +675,17 @@ func (s *Scenario) CreateTailscaleNodesInUser( } s.mu.Lock() + user.Clients[tsClient.Hostname()] = tsClient + s.mu.Unlock() return nil }) } - if err := user.createWaitGroup.Wait(); err != nil { + + err := user.createWaitGroup.Wait() + if err != nil { return err } @@ -682,12 +705,14 @@ func (s *Scenario) RunTailscaleUp( if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { c := client + user.joinWaitGroup.Go(func() error { return c.Login(loginServer, authKey) }) } - if err := user.joinWaitGroup.Wait(); err != nil { + err := user.joinWaitGroup.Wait() + if err != nil { return err } @@ -749,11 +774,14 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat for _, client := range user.Clients { c := client expectedCount := expectedPeers + user.syncWaitGroup.Go(func() error { return c.WaitForPeers(expectedCount, timeout, retryInterval) }) } - if err := user.syncWaitGroup.Wait(); err != nil { + + err := user.syncWaitGroup.Wait() + if err != nil { allErrors = append(allErrors, err) } } @@ -773,11 +801,14 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, ret for _, user := range s.users { for _, client := range user.Clients { c := client + user.syncWaitGroup.Go(func() error { return c.WaitForPeers(peerCount, timeout, retryInterval) }) } - if err := user.syncWaitGroup.Wait(); err != nil { + + err := user.syncWaitGroup.Wait() + if err != nil { allErrors = append(allErrors, err) } } @@ -871,6 +902,7 @@ func (s *Scenario) createHeadscaleEnvWithTags( } else { key, err = s.CreatePreAuthKey(u.GetId(), true, false) } + if err != nil { return err } @@ -887,9 +919,11 @@ func (s *Scenario) createHeadscaleEnvWithTags( func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { tsc := client + user.joinWaitGroup.Go(func() error { loginURL, err := tsc.LoginWithURL(loginServer) if err != nil { @@ -904,7 +938,7 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { // If the URL is not a OIDC URL, then we need to // run the register command to fully log in the client. if !strings.Contains(loginURL.String(), "/oidc/") { - s.runHeadscaleRegister(userStr, body) + _ = s.runHeadscaleRegister(userStr, body) } return nil @@ -913,7 +947,8 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("client %s is ready", client.Hostname()) } - if err := user.joinWaitGroup.Wait(); err != nil { + err := user.joinWaitGroup.Wait() + if err != nil { return err } @@ -945,6 +980,7 @@ func newDebugJar() (*debugJar, error) { if err != nil { return nil, err } + return &debugJar{ inner: jar, store: make(map[string]map[string]map[string]*http.Cookie), @@ -961,20 +997,25 @@ func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) { if c == nil || c.Name == "" { continue } + domain := c.Domain if domain == "" { domain = u.Hostname() } + path := c.Path if path == "" { path = "/" } + if _, ok := j.store[domain]; !ok { j.store[domain] = make(map[string]map[string]*http.Cookie) } + if _, ok := j.store[domain][path]; !ok { j.store[domain][path] = make(map[string]*http.Cookie) } + j.store[domain][path][c.Name] = copyCookie(c) } } @@ -989,8 +1030,10 @@ func (j *debugJar) Dump(w io.Writer) { for domain, paths := range j.store { fmt.Fprintf(w, "Domain: %s\n", domain) + for path, byName := range paths { fmt.Fprintf(w, " Path: %s\n", path) + for _, c := range byName { fmt.Fprintf( w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n", @@ -1046,15 +1089,17 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f error, ) { if hc == nil { - return "", nil, fmt.Errorf("%s http client is nil", hostname) + return "", nil, fmt.Errorf("%s http client is nil", hostname) //nolint:err113 } if loginURL == nil { - return "", nil, fmt.Errorf("%s login url is nil", hostname) + return "", nil, fmt.Errorf("%s login url is nil", hostname) //nolint:err113 } log.Printf("%s logging in with url: %s", hostname, loginURL.String()) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) if err != nil { return "", nil, fmt.Errorf("%s creating http request: %w", hostname, err) @@ -1066,6 +1111,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f return http.ErrUseLastResponse } } + defer func() { hc.CheckRedirect = originalRedirect }() @@ -1080,6 +1126,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f if err != nil { return "", nil, fmt.Errorf("%s reading response body: %w", hostname, err) } + body := string(bodyBytes) var redirectURL *url.URL @@ -1093,13 +1140,13 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f if followRedirects && resp.StatusCode != http.StatusOK { log.Printf("body: %s", body) - return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) //nolint:err113 } if resp.StatusCode >= http.StatusBadRequest { log.Printf("body: %s", body) - return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) //nolint:err113 } if hc.Jar != nil { @@ -1117,7 +1164,7 @@ var errParseAuthPage = errors.New("parsing auth page") func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { // see api.go HTML template - codeSep := strings.Split(string(body), "") + codeSep := strings.Split(body, "") if len(codeSep) != 2 { return errParseAuthPage } @@ -1126,11 +1173,12 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { if len(keySep) != 2 { return errParseAuthPage } + key := keySep[1] key = strings.SplitN(key, " ", 2)[0] log.Printf("registering node %s", key) - if headscale, err := s.Headscale(); err == nil { + if headscale, err := s.Headscale(); err == nil { //nolint:noinlineerr _, err = headscale.Execute( []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, ) @@ -1154,6 +1202,7 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error noTls := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint } + resp, err := noTls.RoundTrip(req) if err != nil { return nil, err @@ -1173,12 +1222,14 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error // in a Scenario. func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { var ips []netip.Addr + if ns, ok := s.users[user]; ok { for _, client := range ns.Clients { clientIps, err := client.IPs() if err != nil { return ips, fmt.Errorf("getting IPs: %w", err) } + ips = append(ips, clientIps...) } @@ -1191,6 +1242,7 @@ func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { // GetClients returns all TailscaleClients associated with a User in a Scenario. func (s *Scenario) GetClients(user string) ([]TailscaleClient, error) { var clients []TailscaleClient + if ns, ok := s.users[user]; ok { for _, client := range ns.Clients { clients = append(clients, client) @@ -1290,11 +1342,14 @@ func (s *Scenario) WaitForTailscaleLogout() error { for _, user := range s.users { for _, client := range user.Clients { c := client + user.syncWaitGroup.Go(func() error { return c.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) }) } - if err := user.syncWaitGroup.Wait(); err != nil { + + err := user.syncWaitGroup.Wait() + if err != nil { return err } } @@ -1361,6 +1416,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse if err != nil { log.Fatalf("finding open port: %s", err) } + portNotation := fmt.Sprintf("%d/tcp", port) hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) @@ -1405,7 +1461,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") - if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( + if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( //nolint:noinlineerr headscaleBuildOptions, mockOidcOptions, dockertestutil.DockerRestartPolicy); err == nil { @@ -1421,9 +1477,10 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse ipAddr := s.mockOIDC.r.GetIPInNetwork(network) log.Println("Waiting for headscale mock oidc to be ready for tests") + hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) - if err := s.pool.Retry(func() error { + if err := s.pool.Retry(func() error { //nolint:noinlineerr oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) httpClient := &http.Client{} ctx := context.Background() @@ -1468,14 +1525,13 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { // log.Fatalf("finding open port: %s", err) // } // portNotation := fmt.Sprintf("%d/tcp", port) - hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hostname := "hs-webservice-" + hash network, ok := s.networks[s.prefixedNetworkName(networkName)] if !ok { - return nil, fmt.Errorf("network does not exist: %s", networkName) + return nil, fmt.Errorf("network does not exist: %s", networkName) //nolint:err113 } webOpts := &dockertest.RunOptions{ diff --git a/integration/scenario_test.go b/integration/scenario_test.go index 1e2a151a..71998fca 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -35,6 +35,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -83,6 +84,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 2986bcea..993ac418 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -493,7 +493,7 @@ func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) { t.Helper() - assert.Error(t, err) + require.Error(t, err) if !isSSHNoAccessStdError(stderr) { t.Errorf("expected stderr output suggesting access denied, got: %s", stderr) diff --git a/integration/tags_test.go b/integration/tags_test.go index 16105ea2..f9cd394b 100644 --- a/integration/tags_test.go +++ b/integration/tags_test.go @@ -2502,7 +2502,7 @@ func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClie var actualTagsSlice []string if nm.SelfNode.Valid() { - for _, tag := range nm.SelfNode.Tags().All() { + for _, tag := range nm.SelfNode.Tags().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator actualTagsSlice = append(actualTagsSlice, tag) } } @@ -2647,7 +2647,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { var netmapTagsAfterFirstCall []string if nmErr == nil && nm != nil && nm.SelfNode.Valid() { - for _, tag := range nm.SelfNode.Tags().All() { + for _, tag := range nm.SelfNode.Tags().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator netmapTagsAfterFirstCall = append(netmapTagsAfterFirstCall, tag) } } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index c14163a2..879949d5 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -52,8 +52,6 @@ var ( errTailscaleNotLoggedIn = errors.New("tailscale not logged in") errTailscaleWrongPeerCount = errors.New("wrong peer count") errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey") - errTailscaleNotConnected = errors.New("tailscale not connected") - errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login") errInvalidClientConfig = errors.New("verifiably invalid client config requested") errInvalidTailscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_TAILSCALE_IMAGE format, expected repository:tag") errTailscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE must be set in CI for HEAD version") @@ -297,6 +295,8 @@ func (t *TailscaleInContainer) buildEntrypoint() []string { } // New returns a new TailscaleInContainer instance. +// +//nolint:gocyclo // complex container setup with many options func New( pool *dockertest.Pool, version string, @@ -338,7 +338,7 @@ func New( } if tsic.network == nil { - return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) + return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) //nolint:err113 } tailscaleOptions := &dockertest.RunOptions{ @@ -586,7 +586,7 @@ func (t *TailscaleInContainer) Version() string { return t.version } -// ID returns the Docker container ID of the TailscaleInContainer +// ContainerID returns the Docker container ID of the TailscaleInContainer // instance. func (t *TailscaleInContainer) ContainerID() string { return t.container.Container.ID @@ -621,7 +621,7 @@ func (t *TailscaleInContainer) Execute( return stdout, stderr, nil } -// Retrieve container logs. +// Logs retrieves the container logs. func (t *TailscaleInContainer) Logs(stdout, stderr io.Writer) error { return dockertestutil.WriteLog( t.pool, @@ -673,7 +673,7 @@ func (t *TailscaleInContainer) Login( ) error { command := t.buildLoginCommand(loginServer, authKey) - if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { + if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { //nolint:noinlineerr return fmt.Errorf( "%s failed to join tailscale client (%s): %w", t.hostname, @@ -685,11 +685,11 @@ func (t *TailscaleInContainer) Login( return nil } -// Up runs the login routine on the given Tailscale instance. +// LoginWithURL runs the login routine on the given Tailscale instance. // This login mechanism uses web + command line flow for authentication. func (t *TailscaleInContainer) LoginWithURL( loginServer string, -) (loginURL *url.URL, err error) { +) (*url.URL, error) { command := t.buildLoginCommand(loginServer, "") stdout, stderr, err := t.Execute(command) @@ -703,7 +703,7 @@ func (t *TailscaleInContainer) LoginWithURL( } }() - loginURL, err = util.ParseLoginURLFromCLILogin(stdout + stderr) + loginURL, err := util.ParseLoginURLFromCLILogin(stdout + stderr) if err != nil { return nil, err } @@ -713,14 +713,14 @@ func (t *TailscaleInContainer) LoginWithURL( // Logout runs the logout routine on the given Tailscale instance. func (t *TailscaleInContainer) Logout() error { - stdout, stderr, err := t.Execute([]string{"tailscale", "logout"}) + _, _, err := t.Execute([]string{"tailscale", "logout"}) if err != nil { return err } - stdout, stderr, _ = t.Execute([]string{"tailscale", "status"}) + stdout, stderr, _ := t.Execute([]string{"tailscale", "status"}) if !strings.Contains(stdout+stderr, "Logged out.") { - return fmt.Errorf("logging out, stdout: %s, stderr: %s", stdout, stderr) + return fmt.Errorf("logging out, stdout: %s, stderr: %s", stdout, stderr) //nolint:err113 } return t.waitForBackendState("NeedsLogin", integrationutil.PeerSyncTimeout()) @@ -759,14 +759,14 @@ func (t *TailscaleInContainer) Restart() error { return nil } -// Helper that runs `tailscale up` with no arguments. +// Up runs `tailscale up` with no arguments. func (t *TailscaleInContainer) Up() error { command := []string{ "tailscale", "up", } - if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { + if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { //nolint:noinlineerr return fmt.Errorf( "%s failed to bring tailscale client up (%s): %w", t.hostname, @@ -778,14 +778,14 @@ func (t *TailscaleInContainer) Up() error { return nil } -// Helper that runs `tailscale down` with no arguments. +// Down runs `tailscale down` with no arguments. func (t *TailscaleInContainer) Down() error { command := []string{ "tailscale", "down", } - if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { + if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { //nolint:noinlineerr return fmt.Errorf( "%s failed to bring tailscale client down (%s): %w", t.hostname, @@ -832,7 +832,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { } if len(ips) == 0 { - return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname) + return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname) //nolint:err113 } return ips, nil @@ -866,7 +866,7 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) { } } - return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname) + return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname) //nolint:err113 } func (t *TailscaleInContainer) MustIPv4() netip.Addr { @@ -908,7 +908,7 @@ func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) { return nil, fmt.Errorf("unmarshalling tailscale status: %w", err) } - err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_status.json", t.hostname), []byte(result), 0o755) + err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_status.json", t.hostname), []byte(result), 0o755) //nolint:gosec // test infrastructure log files if err != nil { return nil, fmt.Errorf("status netmap to /tmp/control: %w", err) } @@ -968,7 +968,7 @@ func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { return nil, fmt.Errorf("unmarshalling tailscale netmap: %w", err) } - err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_netmap.json", t.hostname), []byte(result), 0o755) + err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_netmap.json", t.hostname), []byte(result), 0o755) //nolint:gosec // test infrastructure log files if err != nil { return nil, fmt.Errorf("saving netmap to /tmp/control: %w", err) } @@ -1001,6 +1001,8 @@ func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { // watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until // it gets one that has a netmap.NetworkMap. +// +//nolint:unused func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error) { pr, pw := io.Pipe() @@ -1211,7 +1213,7 @@ func (t *TailscaleInContainer) waitForBackendState(state string, timeout time.Du for { select { case <-ctx.Done(): - return fmt.Errorf("timeout waiting for backend state %s on %s after %v", state, t.hostname, timeout) + return fmt.Errorf("timeout waiting for backend state %s on %s after %v", state, t.hostname, timeout) //nolint:err113 case <-ticker.C: status, err := t.Status() if err != nil { @@ -1256,7 +1258,7 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval return fmt.Errorf("timeout waiting for %d peers on %s after %v, errors: %w", expected, t.hostname, timeout, multierr.New(lastErrs...)) } - return fmt.Errorf("timeout waiting for %d peers on %s after %v", expected, t.hostname, timeout) + return fmt.Errorf("timeout waiting for %d peers on %s after %v", expected, t.hostname, timeout) //nolint:err113 case <-ticker.C: status, err := t.Status() if err != nil { @@ -1284,15 +1286,15 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval peer := status.Peer[peerKey] if !peer.Online { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)) //nolint:err113 } if peer.HostName == "" { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)) //nolint:err113 } if peer.Relay == "" { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)) //nolint:err113 } } @@ -1355,14 +1357,14 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err opt(&args) } - command := []string{ + command := make([]string, 0, 6) + command = append(command, "tailscale", "ping", fmt.Sprintf("--timeout=%s", args.timeout), fmt.Sprintf("--c=%d", args.count), - "--until-direct=" + strconv.FormatBool(args.direct), - } - - command = append(command, hostnameOrIP) + "--until-direct="+strconv.FormatBool(args.direct), + hostnameOrIP, + ) result, _, err := t.Execute( command, @@ -1566,19 +1568,19 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) { } if !strings.Contains(path, hdr.Name) { - return nil, fmt.Errorf("file not found in tar archive, looking for: %s, header was: %s", path, hdr.Name) + return nil, fmt.Errorf("file not found in tar archive, looking for: %s, header was: %s", path, hdr.Name) //nolint:err113 } - if _, err := io.Copy(&out, tr); err != nil { + if _, err := io.Copy(&out, tr); err != nil { //nolint:gosec,noinlineerr // trusted tar from test container return nil, fmt.Errorf("copying file to buffer: %w", err) } // Only support reading the first tile - break + break //nolint:staticcheck // SA4004: intentional - only read first file } if out.Len() == 0 { - return nil, errors.New("file is empty") + return nil, errors.New("file is empty") //nolint:err113 } return out.Bytes(), nil @@ -1591,7 +1593,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { } store := &mem.Store{} - if err = store.LoadFromJSON(state); err != nil { + if err = store.LoadFromJSON(state); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("unmarshalling state file: %w", err) } @@ -1606,7 +1608,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { } p := &ipn.Prefs{} - if err = json.Unmarshal(currentProfile, &p); err != nil { + if err = json.Unmarshal(currentProfile, &p); err != nil { //nolint:noinlineerr return nil, fmt.Errorf("unmarshalling current profile state: %w", err) } @@ -1617,7 +1619,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { // This is useful for verifying that policy changes have propagated to the client. func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { - return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version) + return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version) //nolint:err113 } nm, err := t.Netmap() diff --git a/swagger.go b/swagger.go index fa764568..514bbdf7 100644 --- a/swagger.go +++ b/swagger.go @@ -49,7 +49,7 @@ func SwaggerUI( `)) var payload bytes.Buffer - if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil { + if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil { //nolint:noinlineerr log.Error(). Caller(). Err(err). @@ -88,7 +88,7 @@ func SwaggerAPIv1( writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - if _, err := writer.Write(apiV1JSON); err != nil { + if _, err := writer.Write(apiV1JSON); err != nil { //nolint:noinlineerr log.Error(). Caller(). Err(err).