Compare commits


54 Commits

Author SHA1 Message Date
Kristoffer Dalby
d15ec28799 ci: pin Docker to v28 to avoid v29 breaking changes
Docker 29 (shipped with runner-images 20260209.23.1) breaks docker
build via Go client libraries (broken pipe writing build context)
and docker load/save with certain tarball formats. Add Docker's
official apt repository and install docker-ce 28.5.x in all CI
jobs that interact with Docker.

See https://github.com/actions/runner-images/issues/13474

Updates #3058
2026-02-19 08:21:23 +01:00
Kristoffer Dalby
eccf64eb58 all: fix staticcheck SA4006 in types_test.go
Use new(users["name"]) directly instead of extracting to
intermediate variables, which staticcheck does not recognize
as used with the Go 1.26 new(value) syntax.

Updates #3058
2026-02-19 08:21:23 +01:00
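
A minimal sketch of the pattern described above; the map and variable names are illustrative, not the actual types_test.go code:

```go
package main

import "fmt"

func main() {
	users := map[string]string{"name": "alice"}

	// Before: staticcheck SA4006 flags `name` as assigned-but-never-used,
	// because it does not treat Go 1.26's new(value) as a read:
	//   name := users["name"]
	//   namePtr := new(name)

	// After: pass the expression to new() directly.
	namePtr := new(users["name"])
	fmt.Println(*namePtr)
}
```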
Kristoffer Dalby
43afeedde2 all: apply golangci-lint 2.9.0 fixes
Fix issues found by the upgraded golangci-lint:
- wsl_v5: add required whitespace in CLI files
- staticcheck SA4006: replace new(var.Field) with &localVar
  pattern since staticcheck does not recognize Go 1.26
  new(value) as a use of the variable
- staticcheck SA5011: use t.Fatal instead of t.Error for
  nil guard checks so execution stops
- unused: remove dead ptrTo helper function
2026-02-19 08:21:23 +01:00
Kristoffer Dalby
73613d7f53 db: fix database_versions table creation for PostgreSQL
Use GORM AutoMigrate instead of raw SQL to create the
database_versions table, since PostgreSQL does not support the
datetime type used in the raw SQL (it requires timestamp).
2026-02-19 08:21:23 +01:00
Kristoffer Dalby
30d18575be CHANGELOG: document strict version upgrade path 2026-02-19 08:21:23 +01:00
Kristoffer Dalby
70f8141abd all: upgrade from Go 1.26rc2 to Go 1.26.0 2026-02-19 08:21:23 +01:00
Kristoffer Dalby
82958835ce db: enforce strict version upgrade path
Add a version check that runs before database migrations to ensure
users do not skip minor versions or downgrade. This protects database
migrations and allows future cleanup of old migration code.

Rules enforced:
- Same minor version: always allowed (patch changes either way)
- Single minor upgrade (e.g. 0.27 -> 0.28): allowed
- Multi-minor upgrade (e.g. 0.25 -> 0.28): blocked with guidance
- Any minor downgrade: blocked
- Major version change: blocked
- Dev builds: warn but allow, preserve stored version

The version is stored in a purpose-built database_versions table
after migrations succeed. The table is created with raw SQL before
gormigrate runs to avoid circular dependencies.

Updates #3058
2026-02-19 08:21:23 +01:00
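
A minimal sketch of the rule set above, assuming a plain major/minor comparison; the function name and signature are illustrative, not the actual Headscale implementation:

```go
package db

import "fmt"

// checkUpgradePath applies the rules listed above to the version stored in
// database_versions versus the currently running version. Dev-build handling
// (warn but allow) is omitted for brevity.
func checkUpgradePath(storedMajor, storedMinor, runningMajor, runningMinor int) error {
	switch {
	case runningMajor != storedMajor:
		return fmt.Errorf("major version change %d -> %d is blocked", storedMajor, runningMajor)
	case runningMinor == storedMinor:
		return nil // same minor: patch changes either way are allowed
	case runningMinor == storedMinor+1:
		return nil // single minor upgrade, e.g. 0.27 -> 0.28
	case runningMinor > storedMinor+1:
		return fmt.Errorf("upgrading 0.%d -> 0.%d skips minor versions; upgrade one minor at a time", storedMinor, runningMinor)
	default:
		return fmt.Errorf("downgrading 0.%d -> 0.%d is blocked", storedMinor, runningMinor)
	}
}
```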
Kristoffer Dalby
9c3a3c5837 flake: upgrade golangci-lint to 2.9.0 and update nixpkgs 2026-02-19 08:21:23 +01:00
Florian Preinstorfer
faf55f5e8f Document how to use the provider identifier in the policy 2026-02-18 10:24:05 +01:00
Florian Preinstorfer
e3323b65e5 Describe how to set username instead of SPN for Kanidm 2026-02-18 10:24:05 +01:00
Florian Preinstorfer
8f60b819ec Refresh update path 2026-02-16 15:22:46 +01:00
Florian Preinstorfer
c29bcd2eaf Release planning happens in milestones 2026-02-16 15:22:46 +01:00
Florian Preinstorfer
890a044ef6 Add more UIs 2026-02-16 15:22:46 +01:00
Florian Preinstorfer
8028fa5483 No longer consider autogroup:self experimental 2026-02-16 15:22:46 +01:00
Kristoffer Dalby
a7f981e30e github: fix needs-more-info label race condition
Replace tiangolo/issue-manager with custom logic that distinguishes
bot comments from human responses. The issue-manager action treated
all comments equally, so the bot's own instruction comment would
trigger label removal on the next scheduled run.

Split into two jobs:
- remove-label-on-response: triggers on issue_comment from non-bot
  users, removes the needs-more-info label immediately
- close-stale: runs on daily schedule, uses nushell to iterate open
  needs-more-info issues, checks for human comments after the label
  was added, and closes after 3 days with no response
2026-02-15 19:42:47 +01:00
Kristoffer Dalby
e0d8c3c877 github: fix needs-more-info label race condition
Remove the `issues: labeled` trigger from the timer workflow.

When both workflows triggered on label addition, the comment workflow
would post the bot comment, and by the time the timer workflow ran,
issue-manager would see "a comment was added after the label" and
immediately remove the label due to `remove_label_on_comment: true`.

The timer workflow now only runs on:
- Daily cron (to close stale issues)
- issue_comment (to remove label when humans respond)
- workflow_dispatch (for manual testing)
2026-02-09 10:03:12 +01:00
Kristoffer Dalby
c1b468f9f4 github: update issue template contact links
Reorder contact links to show Discord first, then documentation.
Update Discord invite link and docs URL to current values.
2026-02-09 09:51:28 +01:00
Kristoffer Dalby
900f4b7b75 github: add support-request automation workflow
Add workflow that automatically closes issues labeled as
support-request with a message directing users to Discord
for configuration and support questions.

The workflow:
- Triggers when support-request label is added
- Posts a comment explaining this tracker is for bugs/features
- Links to documentation and Discord
- Closes the issue as "not planned"
2026-02-09 09:51:28 +01:00
Kristoffer Dalby
64f23136a2 github: add needs-more-info automation workflow
Add GitHub Actions automation that helps manage issues requiring
additional information from reporters:

- Post an instruction comment when 'needs-more-info' label is added,
  requesting environment details, debug logs from multiple nodes,
  configuration files, and proper formatting
- Automatically remove the label when anyone comments
- Close the issue after 3 days if no response is provided
- Exempt needs-more-info labeled issues from the stale bot

The instruction comment includes guidance on:
- Required environment and debug information
- Collecting logs from both connecting and connected-to nodes
- Proper redaction rules (replace consistently, never remove IPs)
- Formatting requirements for attachments and Markdown
- Encouragement to discuss on Discord before filing issues
2026-02-09 09:51:28 +01:00
Kristoffer Dalby
0f6d312ada all: upgrade to Go 1.26rc2 and modernize codebase
This commit upgrades the codebase from Go 1.25.5 to Go 1.26rc2 and
adopts new language features.

Toolchain updates:
- go.mod: go 1.25.5 → go 1.26rc2
- flake.nix: buildGo125Module → buildGo126Module, go_1_25 → go_1_26
- flake.nix: build golangci-lint from source with Go 1.26
- Dockerfile.integration: golang:1.25-trixie → golang:1.26rc2-trixie
- Dockerfile.tailscale-HEAD: golang:1.25-alpine → golang:1.26rc2-alpine
- Dockerfile.derper: golang:alpine → golang:1.26rc2-alpine
- .goreleaser.yml: go mod tidy -compat=1.25 → -compat=1.26
- cmd/hi/run.go: fallback Go version 1.25 → 1.26rc2
- .pre-commit-config.yaml: simplify golangci-lint hook entry

Code modernization using Go 1.26 features:
- Replace tsaddr.SortPrefixes with slices.SortFunc + netip.Prefix.Compare
- Replace ptr.To(x) with new(x) syntax
- Replace errors.As with errors.AsType[T]

Lint rule updates:
- Add forbidigo rules to prevent regression to old patterns
2026-02-08 12:35:23 +01:00
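
The three modernizations side by side, as a small self-contained sketch (the surrounding code is illustrative, and it assumes the Go 1.26 errors.AsType[T] form returning (T, bool)):

```go
package main

import (
	"errors"
	"fmt"
	"io/fs"
	"net/netip"
	"os"
	"slices"
)

func main() {
	// tsaddr.SortPrefixes -> slices.SortFunc with netip.Prefix.Compare.
	prefixes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/8"),
	}
	slices.SortFunc(prefixes, netip.Prefix.Compare)

	// ptr.To(x) -> new(x): Go 1.26's new() accepts an expression.
	port := new(uint16(443))
	fmt.Println(prefixes, *port)

	// errors.As -> errors.AsType[T]: no out-variable declaration needed.
	_, err := os.Open("does-not-exist")
	if pathErr, ok := errors.AsType[*fs.PathError](err); ok {
		fmt.Println("path:", pathErr.Path)
	}
}
```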
Kristoffer Dalby
20dff82f95 CHANGELOG: add minimum Tailscale version for 0.29.0
Update the 0.29.0 changelog entry to document the minimum
supported Tailscale client version (v1.76.0), which corresponds
to capability version 106 based on the 10-version support window.
2026-02-07 08:23:51 +01:00
Kristoffer Dalby
31c4331a91 capver: regenerate from docker tags
Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
2026-02-07 08:23:51 +01:00
Kristoffer Dalby
ce580f8245 all: fix golangci-lint issues (#3064) 2026-02-06 21:45:32 +01:00
Kristoffer Dalby
bfb6fd80df integration: fixup test
Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
3acce2da87 errors: rewrite errors to follow go best practices
Errors should not start capitalized, and they should not contain the word
"error" or state that they "failed", as we already know it is an error.

Signed-off-by: Kristoffer Dalby <kristoffer@dalby.cc>
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
4a9a329339 all: use lowercase log messages
Go style recommends that log messages and error strings should not be
capitalized (unless beginning with proper nouns or acronyms) and should
not end with punctuation.

This change normalizes all zerolog .Msg() and .Msgf() calls to start
with lowercase letters, following Go conventions and making logs more
consistent across the codebase.
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
dd16567c52 hscontrol/state,db: use zf constants for logging
Replace raw string field names with zf constants in state.go and
db/node.go for consistent, type-safe logging.

state.go changes:
- User creation, hostinfo validation, node registration
- Tag processing during reauth (processReauthTags)
- Auth path and PreAuthKey handling
- Route auto-approval and MapRequest processing

db/node.go changes:
- RegisterNodeForTest logging
- Invalid hostname replacement logging
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
e0a436cefc hscontrol/util/zlog/zf: add tag, authkey, and route constants
Add new zerolog field constants for improved logging consistency:

- Tag fields: CurrentTags, RemovedTags, RejectedTags, NewTags, OldTags,
  IsTagged, WasAuthKeyTagged
- Node fields: ExistingNodeID
- AuthKey fields: AuthKeyID, AuthKeyUsed, AuthKeyExpired, AuthKeyReusable,
  NodeKeyRotation
- Route fields: RoutesApprovedOld, RoutesApprovedNew, OldAnnouncedRoutes,
  NewAnnouncedRoutes, ApprovedRoutes, OldApprovedRoutes, NewApprovedRoutes,
  AutoApprovedRoutes, AllApprovedRoutes, RouteChanged
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
53cdeff129 hscontrol/mapper: use sub-loggers and zf constants
Add sub-logger patterns to worker(), AddNode(), RemoveNode() and
multiChannelNodeConn to eliminate repeated field calls. Use zf.*
constants for consistent field naming.

Changes in batcher_lockfree.go:
- Add wlog sub-logger in worker() with worker.id context
- Add log field to multiChannelNodeConn struct
- Initialize mc.log with node.id in newMultiChannelNodeConn()
- Add nlog sub-loggers in AddNode() and RemoveNode()
- Update all connection methods to use mc.log

Changes in batcher.go:
- Use zf.NodeID and zf.Reason in handleNodeChange()
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
7148a690d0 hscontrol/grpcv1: use EmbedObject and zf constants
Replace manual field extraction with EmbedObject for node logging
in gRPC handlers. Use zf.* constants for consistent field naming.

Changes:
- RegisterNode: use EmbedObject(node), zf.RegistrationKey, etc.
- SetTags: use EmbedObject(node)
- ExpireNode: use EmbedObject(node), zf.ExpiresAt
- RenameNode: use EmbedObject(node), zf.NewName
- SetApprovedRoutes: use zf.NodeID
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
4e73133b9f hscontrol/routes: use sub-logger and zf constants
Add sub-logger pattern to SetRoutes() to eliminate repeated node.id
field calls. Replace raw strings with zf.* constants throughout
the primary routes code for consistent field naming.

Changes:
- Add nlog sub-logger in SetRoutes() with node.id context
- Replace "prefix" with zf.Prefix
- Replace "changed" with zf.Changes
- Replace "newState" with zf.NewState
- Replace "finalState" with zf.FinalState
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
4f8724151e hscontrol/poll: use sub-logger pattern for mapSession
Replace the helper functions (logf, infof, tracef, errf) with a
zerolog sub-logger initialized in newMapSession(). The sub-logger
is pre-populated with session context (component, node, omitPeers,
stream) eliminating repeated field calls throughout the code.

Changes:
- Add log field to mapSession struct
- Initialize sub-logger with EmbedObject(node) and request context
- Remove logf/infof/tracef/errf helper functions
- Update all callers to use m.log.Level().Caller()... pattern
- Update noise.go to use sess.log instead of sess.tracef

This reduces code by ~20 lines and eliminates ~15 repeated field
calls per log statement.
2026-02-06 07:40:29 +01:00
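
A condensed sketch of the sub-logger pattern, with the struct reduced to the logging-relevant part; the real mapSession carries much more state, and real code uses zf constants rather than string literals:

```go
package poll

import "github.com/rs/zerolog"

type mapSession struct {
	log zerolog.Logger
	// connection state elided
}

func newMapSession(base zerolog.Logger, nodeID uint64, omitPeers, stream bool) *mapSession {
	return &mapSession{
		// Pre-populate the session context once; every later call on
		// m.log carries these fields without repeating them.
		log: base.With().
			Str("component", "mapSession").
			Uint64("node.id", nodeID).
			Bool("omitPeers", omitPeers).
			Bool("stream", stream).
			Logger(),
	}
}
```

Call sites then go through the pre-populated logger via the m.log.Level().Caller() pattern instead of the removed logf/infof/tracef/errf helpers.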
Kristoffer Dalby
91730e2a1d hscontrol: use EmbedObject for node logging
Replace manual Uint64("node.id")/Str("node.name") field patterns with
EmbedObject(node) which automatically includes all standard node fields
(id, name, machine key, node key, online status, tags, user).

This reduces code repetition and ensures consistent logging across:
- state.go: Connect/Disconnect, persistNodeToDB, AutoApproveRoutes
- auth.go: handleLogout, handleRegisterWithAuthKey
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
b5090a01ec cmd: use zf constants for zerolog field names
Update CLI logging to use zf.* constants instead of inline strings
for consistency with the rest of the codebase.
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
27f5641341 golangci: add forbidigo rule for zerolog field constants
Add a lint rule to enforce use of zf.* constants for zerolog field
names instead of inline string literals. This catches at lint time
any new code that doesn't follow the convention.

The rule matches common zerolog field methods (Str, Int, Bool, etc.)
and flags any usage with a string literal first argument.
2026-02-06 07:40:29 +01:00
Kristoffer Dalby
cf3d30b6f6 types: add MarshalZerologObject to domain types
Implement zerolog.LogObjectMarshaler interface on domain types
for structured logging:

- Node: logs node.id, node.name, machine.key (short), node.key (short),
  node.is_tagged, node.expired, node.online, node.tags, user.name
- User: logs user.id, user.name, user.display, user.provider
- PreAuthKey: logs pak.id, pak.prefix (masked), pak.reusable,
  pak.ephemeral, pak.used, pak.is_tagged, pak.tags
- APIKey: logs api_key.id, api_key.prefix (masked), api_key.expiration

Security: PreAuthKey and APIKey only log masked prefixes, never full
keys or hashes. Uses zf.* constants for consistent field naming.
2026-02-06 07:40:29 +01:00
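
The marshaler pattern in miniature; Node is trimmed to two fields here, whereas the real implementation logs the full list above:

```go
package types

import "github.com/rs/zerolog"

type Node struct {
	ID   uint64
	Name string
	// remaining fields elided
}

// MarshalZerologObject implements zerolog.LogObjectMarshaler, so callers can
// write log.Info().EmbedObject(node) instead of repeating field extraction.
func (n *Node) MarshalZerologObject(e *zerolog.Event) {
	e.Uint64("node.id", n.ID).Str("node.name", n.Name)
}
```

With that in place, log.Info().EmbedObject(node).Msg("node connected") carries every standard node field automatically.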
Kristoffer Dalby
58020696fe zlog: add utility package for safe and consistent logging
Add hscontrol/util/zlog package with:

- zf subpackage: field name constants for compile-time safety
- SafeHostinfo: wrapper that redacts device fingerprinting data
- SafeMapRequest: wrapper that redacts client endpoints

The zf (zerolog fields) subpackage provides short constant names
(e.g., zf.NodeID instead of inline "node.id" strings) ensuring
consistent field naming across all log statements.

Security considerations:
- SafeHostinfo never logs: OSVersion, DeviceModel, DistroName
- SafeMapRequest only logs endpoint counts, not actual IPs
2026-02-06 07:40:29 +01:00
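
The shape of the zf subpackage, sketched with two constants that the surrounding commits confirm (zf.NodeID from this message, zf.Users from the mockoidc diff below); the real file defines many more:

```go
// Package zf holds zerolog field-name constants so every log statement
// spells a field the same way and typos fail at compile time.
package zf

const (
	NodeID = "node.id" // named in the commit message above
	Users  = "users"   // used in cmd/headscale/cli/mockoidc.go below
)
```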
Kristoffer Dalby
e44b402fe4 integration: update TestSubnetRouteACL for filter merging and IPProto
Update integration test expectations to match current policy behavior:

1. IPProto defaults include all four protocols (TCP, UDP, ICMPv4,
   ICMPv6) for port-range ACL rules, not just TCP and UDP.

2. Filter rules with identical SrcIPs and IPProto are now merged
   into a single rule with combined DstPorts, so the subnet router
   receives one filter rule instead of two.

Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
835b7eb960 policy: autogroup:internet does not generate packet filters
According to Tailscale SaaS behavior, autogroup:internet is handled
by exit node routing via AllowedIPs, not by packet filtering. ACL
rules with autogroup:internet as destination should produce no
filter rules for any node.

Previously, Headscale expanded autogroup:internet to public CIDR
ranges and distributed filters to exit nodes (because 0.0.0.0/0
"covers" internet destinations). This was incorrect.

Add detection for AutoGroupInternet in filter compilation to skip
filter generation for this autogroup. Update test expectations
accordingly.
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
95b1fd636e policy: fix wildcard DstPorts format and proto:icmp handling
Fix two compatibility issues discovered in Tailscale SaaS testing:

1. Wildcard DstPorts format: Headscale was expanding wildcard
   destinations to CGNAT ranges (100.64.0.0/10, fd7a:115c:a1e0::/48)
   while Tailscale uses {IP: "*"} directly. Add detection for
   wildcard (Asterix) alias type in filter compilation to use the
   correct format.

2. proto:icmp handling: The "icmp" protocol name was returning both
   ICMPv4 (1) and ICMPv6 (58), but Tailscale only returns ICMPv4.
   Users should use "ipv6-icmp" or protocol number 58 explicitly
   for IPv6 ICMP.

Update all test expectations accordingly. This significantly reduces
test file line count by replacing duplicated CGNAT range patterns
with single wildcard entries.
2026-02-05 19:29:16 +01:00
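
For the first fix, the target output shape in tailcfg terms; a sketch, assuming the filter is emitted as tailcfg.FilterRule entries as in Tailscale:

```go
package policy

import "tailscale.com/types/tailcfg"

// Before: a wildcard destination expanded to CGNAT/ULA entries, e.g.
//   {IP: "100.64.0.0/10", ...} and {IP: "fd7a:115c:a1e0::/48", ...}.
// After: emit Tailscale's own wildcard form directly.
var wildcardDst = tailcfg.NetPortRange{IP: "*", Ports: tailcfg.PortRangeAny}
```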
Kristoffer Dalby
834ac27779 policy/v2: add subnet routes and exit node compatibility tests
Add comprehensive test file for validating Headscale's ACL engine
behavior for subnet routes and exit nodes against documented
Tailscale SaaS behavior.

Tests cover:
- Category A: Subnet route basics (wildcard includes routes, tag-based
  ACL excludes routes)
- Category B: Exit node behavior (exit routes not in SrcIPs)
- Category F: Filter placement rules (filters on destination nodes)
- Category G: Protocol and port restrictions
- Category R: Route coverage rules
- Category O: Overlapping routes
- Category H: Edge cases (wildcard formats, CGNAT handling)
- Category T: Tag resolution (tags resolve to node IPs only)
- Category I: IPv6 specific behavior

The tests document expected Tailscale SaaS behavior with TODOs marking
areas where Headscale currently differs. This provides a baseline for
compatibility improvements.
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
4a4032a4b0 changelog: document filter rule merging
Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
29aa08df0e policy: update test expectations for merged filter rules
Update test expectations across policy tests to expect merged
FilterRule entries instead of separate ones. Tests now expect:
- Single FilterRule with combined DstPorts for same source
- Reduced matcher counts for exit node tests

Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
0b1727c337 policy: merge filter rules with identical SrcIPs and IPProto
Tailscale merges multiple ACL rules into fewer FilterRule entries
when they have identical SrcIPs and IPProto, combining their DstPorts
arrays. This change implements the same behavior in Headscale.

Add mergeFilterRules() which uses O(n) hash map lookup to merge rules
with identical keys. DstPorts are NOT deduplicated to match Tailscale
behavior.

Also fix DestsIsTheInternet() to handle merged filter rules where
TheInternet is combined with other destinations - now uses superset
check instead of equality check.

Updates #3036
2026-02-05 19:29:16 +01:00
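
A sketch of the O(n) merge described above; the key construction and names are illustrative, not the actual implementation:

```go
package policy

import (
	"fmt"

	"tailscale.com/types/tailcfg"
)

// mergeFilterRules combines rules that share SrcIPs and IPProto by
// concatenating their DstPorts. DstPorts are deliberately not deduplicated,
// matching Tailscale behavior.
func mergeFilterRules(rules []tailcfg.FilterRule) []tailcfg.FilterRule {
	merged := make([]tailcfg.FilterRule, 0, len(rules))
	index := make(map[string]int, len(rules)) // merge key -> index in merged

	for _, rule := range rules {
		key := fmt.Sprintf("%v|%v", rule.SrcIPs, rule.IPProto)
		if i, ok := index[key]; ok {
			merged[i].DstPorts = append(merged[i].DstPorts, rule.DstPorts...)
			continue
		}

		index[key] = len(merged)
		merged = append(merged, rule)
	}

	return merged
}
```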
Kristoffer Dalby
08fe2e4d6c policy: use CIDR format for autogroup:self destinations
Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
cb29cade46 docs: add compatibility test documentation
Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
f27298c759 changelog: document wildcard CGNAT range change
Add breaking change entry for the wildcard resolution change to use
CGNAT/ULA ranges instead of all IPs.
Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
8baa14ef4a policy: use CGNAT/ULA ranges for wildcard resolution
Change Asterix.Resolve() to use Tailscale's CGNAT range (100.64.0.0/10)
and ULA range (fd7a:115c:a1e0::/48) instead of all IPs (0.0.0.0/0 and
::/0).
This better matches Tailscale's security model where wildcard (*) means
"any node in the tailnet" rather than literally "any IP address on the
internet".

Updates #3036
2026-02-05 19:29:16 +01:00
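
In concrete terms, a sketch of what `*` now resolves to:

```go
package policy

import "net/netip"

// What '*' resolves to after this change: "any node in the tailnet"
// rather than literally any IP address.
var wildcardRanges = []netip.Prefix{
	netip.MustParsePrefix("100.64.0.0/10"),       // Tailscale CGNAT range
	netip.MustParsePrefix("fd7a:115c:a1e0::/48"), // Tailscale ULA range
}
```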
Kristoffer Dalby
ebdbe03639 policy: validate autogroup:self sources in ACL rules
Tailscale validates that autogroup:self destinations in ACL rules can
only be used when ALL sources are users, groups, autogroup:member, or
wildcard (*). Previously, Headscale only performed this validation for
SSH rules.
Add validateACLSrcDstCombination() to enforce that tags, autogroup:tagged,
hosts, and raw IPs cannot be used as sources with autogroup:self
destinations. Invalid policies like `tag:client → autogroup:self:*` are
now rejected at validation time, matching Tailscale behavior.
Wildcard (*) is allowed because autogroup:self evaluation narrows it
per-node to only the node's own IPs.

Updates #3036
2026-02-05 19:29:16 +01:00
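
A sketch of the source check; the real validateACLSrcDstCombination works on parsed alias types rather than strings, so the matching here is simplified:

```go
package policy

import (
	"fmt"
	"strings"
)

// validSelfSource reports whether src may appear as a source in a rule whose
// destination uses autogroup:self.
func validSelfSource(src string) error {
	switch {
	case src == "*": // narrowed per-node by autogroup:self evaluation
		return nil
	case src == "autogroup:member":
		return nil
	case strings.HasPrefix(src, "group:"):
		return nil
	case strings.Contains(src, "@"): // a user (simplified check)
		return nil
	default: // tags, autogroup:tagged, hosts, raw IPs
		return fmt.Errorf("%q cannot be used as a source for autogroup:self destinations", src)
	}
}
```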
Kristoffer Dalby
f735502eae policy: add ICMP protocols to default and export constants
When ACL rules don't specify a protocol, Headscale now defaults to
[TCP, UDP, ICMP, ICMPv6] instead of just [TCP, UDP], matching
Tailscale's behavior.
Also export protocol number constants (ProtocolTCP, ProtocolUDP, etc.)
for use in external test packages, renaming the string protocol
constants to ProtoNameTCP, ProtoNameUDP, etc. to avoid conflicts.
This resolves 78 ICMP-related TODOs in the Tailscale compatibility
tests, reducing the total from 165 to 87.

Updates #3036
2026-02-05 19:29:16 +01:00
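
The IANA protocol numbers behind the new default, as a sketch; the exported constant names in the real code are the ones mentioned above:

```go
package policy

// Default IPProto list when an ACL rule omits "proto", matching Tailscale.
var defaultProtocols = []int{
	6,  // ProtocolTCP
	17, // ProtocolUDP
	1,  // ProtocolICMP (ICMPv4)
	58, // ProtocolIPv6ICMP
}
```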
Kristoffer Dalby
53d17aa321 policy: add comprehensive Tailscale ACL compatibility tests
Add extensive test coverage verifying Headscale's ACL policy behavior
matches Tailscale's coordination server. Tests cover:
- Source/destination resolution for users, groups, tags, hosts, IPs
- autogroup:member, autogroup:tagged, autogroup:self behavior
- Filter rule deduplication and merging semantics
- Multi-rule interaction patterns
- Error case validation
Key behavioral differences documented:
- Headscale creates separate filter entries per ACL rule; Tailscale
  merges rules with identical sources
- Headscale deduplicates Dsts within a rule; Tailscale does not
- Headscale does not validate autogroup:self source restrictions for
  ACL rules (only SSH rules); Tailscale rejects invalid sources
Tests are based on real Tailscale coordination server responses
captured from a test environment with 5 nodes (1 user-owned, 4 tagged).

Updates #3036
2026-02-05 19:29:16 +01:00
Kristoffer Dalby
14f833bdb9 policy: fix autogroup:self handling for tagged nodes
Skip autogroup:self destination processing for tagged nodes since they
can never match autogroup:self (which only applies to user-owned nodes).
Also reorder the IsTagged() check to short-circuit before accessing
User() to avoid potential nil pointer access on tagged nodes.

Updates #3036
2026-02-05 19:29:16 +01:00
Florian Preinstorfer
9e50071df9 Link Fosdem 2026 talk 2026-02-05 08:01:02 +01:00
Florian Preinstorfer
c907b0d323 Fix version in mkdocs 2026-02-05 08:01:02 +01:00
181 changed files with 24747 additions and 2901 deletions

View File

@@ -6,8 +6,7 @@ body:
   - type: checkboxes
     attributes:
       label: Is this a support request?
-      description:
-        This issue tracker is for bugs and feature requests only. If you need
+      description: This issue tracker is for bugs and feature requests only. If you need
         help, please use ask in our Discord community
     options:
       - label: This is not a support request
@@ -15,8 +14,7 @@ body:
   - type: checkboxes
     attributes:
       label: Is there an existing issue for this?
-      description:
-        Please search to see if an issue already exists for the bug you
+      description: Please search to see if an issue already exists for the bug you
         encountered.
     options:
       - label: I have searched the existing issues

View File

@@ -3,9 +3,9 @@ blank_issues_enabled: false
 # Contact links
 contact_links:
-  - name: "headscale usage documentation"
-    url: "https://github.com/juanfont/headscale/blob/main/docs"
-    about: "Find documentation about how to configure and run headscale."
   - name: "headscale Discord community"
-    url: "https://discord.gg/xGj2TuqyxY"
+    url: "https://discord.gg/c84AZQhmpx"
     about: "Please ask and answer questions about usage of headscale here."
+  - name: "headscale usage documentation"
+    url: "https://headscale.net/"
+    about: "Find documentation about how to configure and run headscale."

View File

@@ -0,0 +1,80 @@
Thank you for taking the time to report this issue.

To help us investigate and resolve this, we need more information. Please provide the following:

> [!TIP]
> Most issues turn out to be configuration errors rather than bugs. We encourage you to discuss your problem in our [Discord community](https://discord.gg/c84AZQhmpx) **before** opening an issue. The community can often help identify misconfigurations quickly, saving everyone time.

## Required Information

### Environment Details

- **Headscale version**: (run `headscale version`)
- **Tailscale client version**: (run `tailscale version`)
- **Operating System**: (e.g., Ubuntu 24.04, macOS 14, Windows 11)
- **Deployment method**: (binary, Docker, Kubernetes, etc.)
- **Reverse proxy**: (if applicable: nginx, Traefik, Caddy, etc. - include configuration)

### Debug Information

Please follow our [Debugging and Troubleshooting Guide](https://headscale.net/stable/ref/debug/) and provide:

1. **Client netmap dump** (from affected Tailscale client):

```bash
tailscale debug netmap > netmap.json
```

2. **Client status dump** (from affected Tailscale client):

```bash
tailscale status --json > status.json
```

3. **Tailscale client logs** (if experiencing client issues):

```bash
tailscale debug daemon-logs
```

> [!IMPORTANT]
> We need logs from **multiple nodes** to understand the full picture:
>
> - The node(s) initiating connections
> - The node(s) being connected to
>
> Without logs from both sides, we cannot diagnose connectivity issues.

4. **Headscale server logs** with `log.level: trace` enabled
5. **Headscale configuration** (with sensitive values redacted - see rules below)
6. **ACL/Policy configuration** (if using ACLs)
7. **Proxy/Docker configuration** (if applicable - nginx.conf, docker-compose.yml, Traefik config, etc.)

## Formatting Requirements

- **Attach long files** - Do not paste large logs or configurations inline. Use GitHub file attachments or GitHub Gists.
- **Use proper Markdown** - Format code blocks, logs, and configurations with appropriate syntax highlighting.
- **Structure your response** - Use the headings above to organize your information clearly.

## Redaction Rules

> [!CAUTION]
> **Replace, do not remove.** Removing information makes debugging impossible.

When redacting sensitive information:

- ✅ **Replace consistently** - If you change `alice@company.com` to `user1@example.com`, use `user1@example.com` everywhere (logs, config, policy, etc.)
- ✅ **Use meaningful placeholders** - `user1@example.com`, `bob@example.com`, `my-secret-key` are acceptable
- ❌ **Never remove information** - Gaps in data prevent us from correlating events across logs
- ❌ **Never redact IP addresses** - We need the actual IPs to trace network paths and identify issues

**If redaction rules are not followed, we will be unable to debug the issue and will have to close it.**

---

**Note:** This issue will be automatically closed in 3 days if no additional information is provided. Once you reply with the requested information, the `needs-more-info` label will be removed automatically.

If you need help gathering this information, please visit our [Discord community](https://discord.gg/c84AZQhmpx).

View File

@@ -0,0 +1,15 @@
Thank you for reaching out.

This issue tracker is used for **bug reports and feature requests** only. Your question appears to be a support or configuration question rather than a bug report.

For help with setup, configuration, or general questions, please visit our [Discord community](https://discord.gg/c84AZQhmpx) where the community and maintainers can assist you in real-time.

**Before posting in Discord, please check:**

- [Documentation](https://headscale.net/)
- [FAQ](https://headscale.net/stable/faq/)
- [Debugging and Troubleshooting Guide](https://headscale.net/stable/ref/debug/)

If after troubleshooting you determine this is actually a bug, please open a new issue with the required debug information from the troubleshooting guide.

This issue has been automatically closed.

View File

@@ -67,6 +67,24 @@ jobs:
         with:
           name: postgres-image
           path: /tmp/artifacts
+      - name: Pin Docker to v28 (avoid v29 breaking changes)
+        run: |
+          # Docker 29 breaks docker build via Go client libraries and
+          # docker load/save with certain tarball formats.
+          # Pin to Docker 28.x until our tooling is updated.
+          # https://github.com/actions/runner-images/issues/13474
+          sudo install -m 0755 -d /etc/apt/keyrings
+          curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
+            | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
+          echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
+            https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
+            | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+          sudo apt-get update -qq
+          VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
+          sudo apt-get install -y --allow-downgrades \
+            "docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
+          sudo systemctl restart docker
+          docker version
       - name: Load Docker images, Go cache, and prepare binary
         run: |
           gunzip -c /tmp/artifacts/headscale-image.tar.gz | docker load

View File

@@ -0,0 +1,28 @@
name: Needs More Info - Post Comment

on:
  issues:
    types: [labeled]

jobs:
  post-comment:
    if: >-
      github.event.label.name == 'needs-more-info' &&
      github.repository == 'juanfont/headscale'
    runs-on: ubuntu-latest
    permissions:
      issues: write
      contents: read
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          sparse-checkout: .github/label-response/needs-more-info.md
          sparse-checkout-cone-mode: false
      - name: Post instruction comment
        run: gh issue comment "$NUMBER" --body-file .github/label-response/needs-more-info.md
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          NUMBER: ${{ github.event.issue.number }}

View File

@@ -0,0 +1,98 @@
name: Needs More Info - Timer

on:
  schedule:
    - cron: "0 0 * * *" # Daily at midnight UTC
  issue_comment:
    types: [created]
  workflow_dispatch:

jobs:
  # When a non-bot user comments on a needs-more-info issue, remove the label.
  remove-label-on-response:
    if: >-
      github.repository == 'juanfont/headscale' &&
      github.event_name == 'issue_comment' &&
      github.event.comment.user.type != 'Bot' &&
      contains(github.event.issue.labels.*.name, 'needs-more-info')
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - name: Remove needs-more-info label
        run: gh issue edit "$NUMBER" --remove-label needs-more-info
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          NUMBER: ${{ github.event.issue.number }}

  # On schedule, close issues that have had no human response for 3 days.
  close-stale:
    if: >-
      github.repository == 'juanfont/headscale' &&
      github.event_name != 'issue_comment'
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - uses: hustcer/setup-nu@920172d92eb04671776f3ba69d605d3b09351c30 # v3.22
        with:
          version: "*"
      - name: Close stale needs-more-info issues
        shell: nu {0}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
        run: |
          let issues = (gh issue list
            --repo $env.GH_REPO
            --label "needs-more-info"
            --state open
            --json number
            | from json)

          for issue in $issues {
            let number = $issue.number
            print $"Checking issue #($number)"

            # Find when needs-more-info was last added
            let events = (gh api $"repos/($env.GH_REPO)/issues/($number)/events"
              --paginate | from json | flatten)
            let label_event = ($events
              | where event == "labeled" and label.name == "needs-more-info"
              | last)
            let label_added_at = ($label_event.created_at | into datetime)

            # Check for non-bot comments after the label was added
            let comments = (gh api $"repos/($env.GH_REPO)/issues/($number)/comments"
              --paginate | from json | flatten)
            let human_responses = ($comments
              | where user.type != "Bot"
              | where { ($in.created_at | into datetime) > $label_added_at })

            if ($human_responses | length) > 0 {
              print $"  Human responded, removing label"
              gh issue edit $number --repo $env.GH_REPO --remove-label needs-more-info
              continue
            }

            # Check if 3 days have passed
            let elapsed = (date now) - $label_added_at
            if $elapsed < 3day {
              print $"  Only ($elapsed | format duration day) elapsed, skipping"
              continue
            }

            print $"  No response for ($elapsed | format duration day), closing"

            let message = [
              "This issue has been automatically closed because no additional information was provided within 3 days."
              ""
              "If you have the requested information, please open a new issue and include the debug information requested above."
              ""
              "Thank you for your understanding."
            ] | str join "\n"

            gh issue comment $number --repo $env.GH_REPO --body $message
            gh issue close $number --repo $env.GH_REPO --reason "not planned"
            gh issue edit $number --repo $env.GH_REPO --remove-label needs-more-info
          }

View File

@@ -17,6 +17,25 @@ jobs:
         with:
           fetch-depth: 0
+      - name: Pin Docker to v28 (avoid v29 breaking changes)
+        run: |
+          # Docker 29 breaks docker build via Go client libraries and
+          # docker load/save with certain tarball formats.
+          # Pin to Docker 28.x until our tooling is updated.
+          # https://github.com/actions/runner-images/issues/13474
+          sudo install -m 0755 -d /etc/apt/keyrings
+          curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
+            | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
+          echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
+            https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
+            | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+          sudo apt-get update -qq
+          VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
+          sudo apt-get install -y --allow-downgrades \
+            "docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
+          sudo systemctl restart docker
+          docker version
       - name: Login to DockerHub
         uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
         with:

View File

@@ -23,5 +23,5 @@ jobs:
             since being marked as stale."
           days-before-pr-stale: -1
           days-before-pr-close: -1
-          exempt-issue-labels: "no-stale-bot"
+          exempt-issue-labels: "no-stale-bot,needs-more-info"
           repo-token: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/support-request.yml (new file)
View File

@@ -0,0 +1,30 @@
name: Support Request - Close Issue

on:
  issues:
    types: [labeled]

jobs:
  close-support-request:
    if: >-
      github.event.label.name == 'support-request' &&
      github.repository == 'juanfont/headscale'
    runs-on: ubuntu-latest
    permissions:
      issues: write
      contents: read
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          sparse-checkout: .github/label-response/support-request.md
          sparse-checkout-cone-mode: false
      - name: Post comment and close issue
        run: |
          gh issue comment "$NUMBER" --body-file .github/label-response/support-request.md
          gh issue close "$NUMBER" --reason "not planned"
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GH_REPO: ${{ github.repository }}
          NUMBER: ${{ github.event.issue.number }}

View File

@@ -69,6 +69,25 @@ jobs:
           name: go-cache
           path: go-cache.tar.gz
           retention-days: 10
+      - name: Pin Docker to v28 (avoid v29 breaking changes)
+        if: steps.changed-files.outputs.files == 'true'
+        run: |
+          # Docker 29 breaks docker build via Go client libraries and
+          # docker load/save with certain tarball formats.
+          # Pin to Docker 28.x until our tooling is updated.
+          # https://github.com/actions/runner-images/issues/13474
+          sudo install -m 0755 -d /etc/apt/keyrings
+          curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
+            | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
+          echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
+            https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
+            | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+          sudo apt-get update -qq
+          VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
+          sudo apt-get install -y --allow-downgrades \
+            "docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
+          sudo systemctl restart docker
+          docker version
       - name: Build headscale image
         if: steps.changed-files.outputs.files == 'true'
         run: |
@@ -104,6 +123,24 @@ jobs:
     needs: build
     if: needs.build.outputs.files-changed == 'true'
     steps:
+      - name: Pin Docker to v28 (avoid v29 breaking changes)
+        run: |
+          # Docker 29 breaks docker build via Go client libraries and
+          # docker load/save with certain tarball formats.
+          # Pin to Docker 28.x until our tooling is updated.
+          # https://github.com/actions/runner-images/issues/13474
+          sudo install -m 0755 -d /etc/apt/keyrings
+          curl -fsSL https://download.docker.com/linux/ubuntu/gpg \
+            | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
+          echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] \
+            https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" \
+            | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
+          sudo apt-get update -qq
+          VERSION=$(apt-cache madison docker-ce | grep '28\.5' | head -1 | awk '{print $3}')
+          sudo apt-get install -y --allow-downgrades \
+            "docker-ce=${VERSION}" "docker-ce-cli=${VERSION}"
+          sudo systemctl restart docker
+          docker version
       - name: Pull and save postgres image
         run: |
           docker pull postgres:latest

View File

@@ -18,6 +18,7 @@ linters:
     - lll
     - maintidx
     - makezero
+    - mnd
     - musttag
     - nestif
     - nolintlint
@@ -37,6 +38,23 @@ linters:
           time.Sleep is forbidden.
           In tests: use assert.EventuallyWithT for polling/waiting patterns.
           In production code: use a backoff strategy (e.g., cenkalti/backoff) or proper synchronization primitives.
+      # Forbid inline string literals in zerolog field methods - use zf.* constants
+      - pattern: '\.(Str|Int|Int8|Int16|Int32|Int64|Uint|Uint8|Uint16|Uint32|Uint64|Float32|Float64|Bool|Dur|Time|TimeDiff|Strs|Ints|Uints|Floats|Bools|Any|Interface)\("[^"]+"'
+        msg: >-
+          Use zf.* constants for zerolog field names instead of string literals.
+          Import "github.com/juanfont/headscale/hscontrol/util/zlog/zf" and use
+          constants like zf.NodeID, zf.UserName, etc. Add new constants to
+          hscontrol/util/zlog/zf/fields.go if needed.
+      # Forbid ptr.To - use Go 1.26 new(expr) instead
+      - pattern: 'ptr\.To\('
+        msg: >-
+          ptr.To is forbidden. Use Go 1.26's new(expr) syntax instead.
+          Example: ptr.To(value) → new(value)
+      # Forbid tsaddr.SortPrefixes - use slices.SortFunc with netip.Prefix.Compare
+      - pattern: 'tsaddr\.SortPrefixes'
+        msg: >-
+          tsaddr.SortPrefixes is forbidden. Use Go 1.26's netip.Prefix.Compare instead.
+          Example: slices.SortFunc(prefixes, netip.Prefix.Compare)
       analyze-types: true
   gocritic:
     disabled-checks:

View File

@@ -2,7 +2,7 @@
 version: 2
 before:
   hooks:
-    - go mod tidy -compat=1.25
+    - go mod tidy -compat=1.26
     - go mod vendor
 release:

View File

@@ -43,26 +43,12 @@ repos:
         entry: prettier --write --list-different
         language: system
         exclude: ^docs/
-        types_or:
-          [
-            javascript,
-            jsx,
-            ts,
-            tsx,
-            yaml,
-            json,
-            toml,
-            html,
-            css,
-            scss,
-            sass,
-            markdown,
-          ]
+        types_or: [javascript, jsx, ts, tsx, yaml, json, toml, html, css, scss, sass, markdown]
   # golangci-lint for Go code quality
   - id: golangci-lint
     name: golangci-lint
-    entry: nix develop --command golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
+    entry: golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix
     language: system
     types: [go]
     pass_filenames: false

View File

@@ -2,6 +2,38 @@
 ## 0.29.0 (202x-xx-xx)
 
+**Minimum supported Tailscale client version: v1.76.0**
+
+### Tailscale ACL compatibility improvements
+
+Extensive test cases were systematically generated using Tailscale clients and the official SaaS
+to understand how the packet filter should be generated. We discovered a few differences, but
+overall our implementation was very close.
+[#3036](https://github.com/juanfont/headscale/pull/3036)
+
+### BREAKING
+
+- **ACL Policy**: Wildcard (`*`) in ACL sources and destinations now resolves to Tailscale's CGNAT range (`100.64.0.0/10`) and ULA range (`fd7a:115c:a1e0::/48`) instead of all IPs (`0.0.0.0/0` and `::/0`) [#3036](https://github.com/juanfont/headscale/pull/3036)
+  - This better matches Tailscale's security model where `*` means "any node in the tailnet" rather than "any IP address"
+  - Policies relying on wildcard to match non-Tailscale IPs will need to use explicit CIDR ranges instead
+  - **Note**: Users with non-standard IP ranges configured in `prefixes.ipv4` or `prefixes.ipv6` (which is unsupported and produces a warning) will need to explicitly specify their CIDR ranges in ACL rules instead of using `*`
+- **ACL Policy**: Validate autogroup:self source restrictions matching Tailscale behavior - tags, hosts, and IPs are rejected as sources for autogroup:self destinations [#3036](https://github.com/juanfont/headscale/pull/3036)
+  - Policies using tags, hosts, or IP addresses as sources for autogroup:self destinations will now fail validation
+- **Upgrade path**: Headscale now enforces a strict version upgrade path [#3083](https://github.com/juanfont/headscale/pull/3083)
+  - Skipping minor versions (e.g. 0.27 → 0.29) is blocked; upgrade one minor version at a time
+  - Downgrading to a previous minor version is blocked
+  - Patch version changes within the same minor are always allowed
+- **ACL Policy**: The `proto:icmp` protocol name now only includes ICMPv4 (protocol 1), matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
+  - Previously, `proto:icmp` included both ICMPv4 and ICMPv6
+  - Use `proto:ipv6-icmp` or protocol number `58` explicitly for ICMPv6
+
+### Changes
+
+- **ACL Policy**: Add ICMP and IPv6-ICMP protocols to default filter rules when no protocol is specified [#3036](https://github.com/juanfont/headscale/pull/3036)
+- **ACL Policy**: Fix autogroup:self handling for tagged nodes - tagged nodes no longer incorrectly receive autogroup:self filter rules [#3036](https://github.com/juanfont/headscale/pull/3036)
+- **ACL Policy**: Use CIDR format for autogroup:self destination IPs matching Tailscale behavior [#3036](https://github.com/juanfont/headscale/pull/3036)
+- **ACL Policy**: Merge filter rules with identical SrcIPs and IPProto matching Tailscale behavior - multiple ACL rules with the same source now produce a single FilterRule with combined DstPorts [#3036](https://github.com/juanfont/headscale/pull/3036)
+
 ## 0.28.0 (2026-02-04)
 
 **Minimum supported Tailscale client version: v1.74.0**

View File

@@ -1,6 +1,6 @@
 # For testing purposes only
-FROM golang:alpine AS build-env
+FROM golang:1.26.0-alpine AS build-env
 
 WORKDIR /go/src

View File

@@ -2,7 +2,7 @@
 # and are in no way endorsed by Headscale's maintainers as an
 # official nor supported release or distribution.
-FROM docker.io/golang:1.25-trixie AS builder
+FROM docker.io/golang:1.26.0-trixie AS builder
 
 ARG VERSION=dev
 ENV GOPATH /go
 WORKDIR /go/src/headscale

View File

@@ -4,7 +4,7 @@
 # This Dockerfile is more or less lifted from tailscale/tailscale
 # to ensure a similar build process when testing the HEAD of tailscale.
-FROM golang:1.25-alpine AS build-env
+FROM golang:1.26.0-alpine AS build-env
 
 WORKDIR /go/src

View File

@@ -67,6 +67,8 @@ For NixOS users, a module is available in [`nix/`](./nix/).
 ## Talks
 
+- Fosdem 2026 (video): [Headscale & Tailscale: The complementary open source clone](https://fosdem.org/2026/schedule/event/KYQ3LL-headscale-the-complementary-open-source-clone/)
+  - presented by Kristoffer Dalby
 - Fosdem 2023 (video): [Headscale: How we are using integration testing to reimplement Tailscale](https://fosdem.org/2023/schedule/event/goheadscale/)
   - presented by Juan Font Alonso and Kristoffer Dalby

View File

@@ -14,7 +14,7 @@ import (
 )
 
 const (
-    // 90 days.
+    // DefaultAPIKeyExpiry is 90 days.
     DefaultAPIKeyExpiry = "90d"
 )
@@ -71,6 +71,7 @@ var listAPIKeys = &cobra.Command{
     tableData := pterm.TableData{
         {"ID", "Prefix", "Expiration", "Created"},
     }
+
     for _, key := range response.GetApiKeys() {
         expiration := "-"
@@ -84,8 +85,8 @@ var listAPIKeys = &cobra.Command{
             expiration,
             key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
         })
     }
+
     err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
     if err != nil {
         ErrorOutput(
View File

@@ -16,7 +16,7 @@ var configTestCmd = &cobra.Command{
     Run: func(cmd *cobra.Command, args []string) {
         _, err := newHeadscaleServerWithConfig()
         if err != nil {
-            log.Fatal().Caller().Err(err).Msg("Error initializing")
+            log.Fatal().Caller().Err(err).Msg("error initializing")
         }
     },
 }

View File

@@ -19,10 +19,12 @@ func init() {
     rootCmd.AddCommand(debugCmd)
 
     createNodeCmd.Flags().StringP("name", "", "", "Name")
+
     err := createNodeCmd.MarkFlagRequired("name")
     if err != nil {
         log.Fatal().Err(err).Msg("")
     }
+
     createNodeCmd.Flags().StringP("user", "u", "", "User")
     createNodeCmd.Flags().StringP("namespace", "n", "", "User")
@@ -34,11 +36,14 @@ func init() {
     if err != nil {
         log.Fatal().Err(err).Msg("")
     }
+
     createNodeCmd.Flags().StringP("key", "k", "", "Key")
+
     err = createNodeCmd.MarkFlagRequired("key")
     if err != nil {
         log.Fatal().Err(err).Msg("")
     }
+
     createNodeCmd.Flags().
         StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")

View File

@@ -15,6 +15,7 @@ var healthCmd = &cobra.Command{
Long: "Check the health of the Headscale server. This command will return an exit code of 0 if the server is healthy, or 1 if it is not.", Long: "Check the health of the Headscale server. This command will return an exit code of 0 if the server is healthy, or 1 if it is not.",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig() ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel() defer cancel()
defer conn.Close() defer conn.Close()

View File

@@ -1,8 +1,8 @@
 package cli
 
 import (
+    "context"
     "encoding/json"
-    "errors"
     "fmt"
     "net"
     "net/http"
@@ -10,6 +10,7 @@ import (
     "strconv"
     "time"
 
+    "github.com/juanfont/headscale/hscontrol/util/zlog/zf"
     "github.com/oauth2-proxy/mockoidc"
     "github.com/rs/zerolog/log"
     "github.com/spf13/cobra"
@@ -19,6 +20,7 @@ const (
     errMockOidcClientIDNotDefined     = Error("MOCKOIDC_CLIENT_ID not defined")
     errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
     errMockOidcPortNotDefined         = Error("MOCKOIDC_PORT not defined")
+    errMockOidcUsersNotDefined        = Error("MOCKOIDC_USERS not defined")
 
     refreshTTL = 60 * time.Minute
 )
@@ -35,7 +37,7 @@ var mockOidcCmd = &cobra.Command{
     Run: func(cmd *cobra.Command, args []string) {
         err := mockOIDC()
         if err != nil {
-            log.Error().Err(err).Msgf("Error running mock OIDC server")
+            log.Error().Err(err).Msgf("error running mock OIDC server")
             os.Exit(1)
         }
     },
@@ -46,41 +48,47 @@ func mockOIDC() error {
     if clientID == "" {
         return errMockOidcClientIDNotDefined
     }
+
     clientSecret := os.Getenv("MOCKOIDC_CLIENT_SECRET")
     if clientSecret == "" {
         return errMockOidcClientSecretNotDefined
     }
+
     addrStr := os.Getenv("MOCKOIDC_ADDR")
     if addrStr == "" {
         return errMockOidcPortNotDefined
     }
+
     portStr := os.Getenv("MOCKOIDC_PORT")
     if portStr == "" {
         return errMockOidcPortNotDefined
     }
+
     accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
     if accessTTLOverride != "" {
         newTTL, err := time.ParseDuration(accessTTLOverride)
         if err != nil {
             return err
         }
+
         accessTTL = newTTL
     }
 
     userStr := os.Getenv("MOCKOIDC_USERS")
     if userStr == "" {
-        return errors.New("MOCKOIDC_USERS not defined")
+        return errMockOidcUsersNotDefined
     }
 
     var users []mockoidc.MockUser
+
     err := json.Unmarshal([]byte(userStr), &users)
     if err != nil {
         return fmt.Errorf("unmarshalling users: %w", err)
     }
 
-    log.Info().Interface("users", users).Msg("loading users from JSON")
-    log.Info().Msgf("Access token TTL: %s", accessTTL)
+    log.Info().Interface(zf.Users, users).Msg("loading users from JSON")
+    log.Info().Msgf("access token TTL: %s", accessTTL)
 
     port, err := strconv.Atoi(portStr)
     if err != nil {
@@ -92,7 +100,7 @@ func mockOIDC() error {
         return err
     }
 
-    listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", addrStr, port))
+    listener, err := new(net.ListenConfig).Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", addrStr, port))
     if err != nil {
         return err
     }
@@ -101,8 +109,10 @@ func mockOIDC() error {
     if err != nil {
         return err
     }
 
-    log.Info().Msgf("Mock OIDC server listening on %s", listener.Addr().String())
-    log.Info().Msgf("Issuer: %s", mock.Issuer())
+    log.Info().Msgf("mock OIDC server listening on %s", listener.Addr().String())
+    log.Info().Msgf("issuer: %s", mock.Issuer())
 
     c := make(chan struct{})
     <-c
@@ -133,12 +143,13 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
         ErrorQueue: &mockoidc.ErrorQueue{},
     }
 
-    mock.AddMiddleware(func(h http.Handler) http.Handler {
+    _ = mock.AddMiddleware(func(h http.Handler) http.Handler {
         return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-            log.Info().Msgf("Request: %+v", r)
+            log.Info().Msgf("request: %+v", r)
+
             h.ServeHTTP(w, r)
 
             if r.Response != nil {
-                log.Info().Msgf("Response: %+v", r.Response)
+                log.Info().Msgf("response: %+v", r.Response)
             }
         })
     })

View File

@@ -26,6 +26,7 @@ func init() {
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace") listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
listNodesNamespaceFlag.Hidden = true listNodesNamespaceFlag.Hidden = true
nodeCmd.AddCommand(listNodesCmd) nodeCmd.AddCommand(listNodesCmd)
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
@@ -42,42 +43,51 @@ func init() {
	if err != nil {
		log.Fatal(err.Error())
	}
	registerNodeCmd.Flags().StringP("key", "k", "", "Key")
	err = registerNodeCmd.MarkFlagRequired("key")
	if err != nil {
		log.Fatal(err.Error())
	}
	nodeCmd.AddCommand(registerNodeCmd)
	expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
	expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.")
	err = expireNodeCmd.MarkFlagRequired("identifier")
	if err != nil {
		log.Fatal(err.Error())
	}
	nodeCmd.AddCommand(expireNodeCmd)
	renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
	err = renameNodeCmd.MarkFlagRequired("identifier")
	if err != nil {
		log.Fatal(err.Error())
	}
	nodeCmd.AddCommand(renameNodeCmd)
	deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
	err = deleteNodeCmd.MarkFlagRequired("identifier")
	if err != nil {
		log.Fatal(err.Error())
	}
	nodeCmd.AddCommand(deleteNodeCmd)
	tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
-	tagCmd.MarkFlagRequired("identifier")
+	_ = tagCmd.MarkFlagRequired("identifier")
	tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
	nodeCmd.AddCommand(tagCmd)
	approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
-	approveRoutesCmd.MarkFlagRequired("identifier")
+	_ = approveRoutesCmd.MarkFlagRequired("identifier")
	approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
	nodeCmd.AddCommand(approveRoutesCmd)
@@ -95,6 +105,7 @@ var registerNodeCmd = &cobra.Command{
	Short: "Registers a node to your network",
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		user, err := cmd.Flags().GetString("user")
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
@@ -142,6 +153,7 @@ var listNodesCmd = &cobra.Command{
	Aliases: []string{"ls", "show"},
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		user, err := cmd.Flags().GetString("user")
		if err != nil {
			ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
@@ -190,6 +202,7 @@ var listNodeRoutesCmd = &cobra.Command{
	Aliases: []string{"lsr", "routes"},
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		identifier, err := cmd.Flags().GetUint64("identifier")
		if err != nil {
			ErrorOutput(
@@ -233,10 +246,7 @@ var listNodeRoutesCmd = &cobra.Command{
		return
	}
-	tableData, err := nodeRoutesToPtables(nodes)
-	if err != nil {
-		ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
-	}
+	tableData := nodeRoutesToPtables(nodes)
	err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
	if err != nil {
@@ -276,7 +286,9 @@ var expireNodeCmd = &cobra.Command{
		return
	}
	now := time.Now()
	expiryTime := now
	if expiry != "" {
		expiryTime, err = time.Parse(time.RFC3339, expiry)
@@ -343,6 +355,7 @@ var renameNodeCmd = &cobra.Command{
	if len(args) > 0 {
		newName = args[0]
	}
	request := &v1.RenameNodeRequest{
		NodeId:  identifier,
		NewName: newName,
@@ -402,6 +415,7 @@ var deleteNodeCmd = &cobra.Command{
	}
	confirm := false
	force, _ := cmd.Flags().GetBool("force")
	if !force {
		confirm = util.YesNo(fmt.Sprintf(
@@ -417,6 +431,7 @@ var deleteNodeCmd = &cobra.Command{
		return
	}
	if err != nil {
		ErrorOutput(
			err,
@@ -424,6 +439,7 @@ var deleteNodeCmd = &cobra.Command{
			output,
		)
	}
	SuccessOutput(
		map[string]string{"Result": "Node deleted"},
		"Node deleted",
@@ -506,15 +522,21 @@ func nodesToPtables(
		ephemeral = true
	}
-	var lastSeen time.Time
-	var lastSeenTime string
+	var (
+		lastSeen     time.Time
+		lastSeenTime string
+	)
	if node.GetLastSeen() != nil {
		lastSeen = node.GetLastSeen().AsTime()
		lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
	}
-	var expiry time.Time
-	var expiryTime string
+	var (
+		expiry     time.Time
+		expiryTime string
+	)
	if node.GetExpiry() != nil {
		expiry = node.GetExpiry().AsTime()
		expiryTime = expiry.Format("2006-01-02 15:04:05")
@@ -523,6 +545,7 @@ func nodesToPtables(
	}
	var machineKey key.MachinePublic
	err := machineKey.UnmarshalText(
		[]byte(node.GetMachineKey()),
	)
@@ -531,6 +554,7 @@ func nodesToPtables(
	}
	var nodeKey key.NodePublic
	err = nodeKey.UnmarshalText(
		[]byte(node.GetNodeKey()),
	)
@@ -572,8 +596,11 @@ func nodesToPtables(
		user = pterm.LightYellow(node.GetUser().GetName())
	}
-	var IPV4Address string
-	var IPV6Address string
+	var (
+		IPV4Address string
+		IPV6Address string
+	)
	for _, addr := range node.GetIpAddresses() {
		if netip.MustParseAddr(addr).Is4() {
			IPV4Address = addr
@@ -608,7 +635,7 @@ func nodesToPtables(
func nodeRoutesToPtables(
	nodes []*v1.Node,
-) (pterm.TableData, error) {
+) pterm.TableData {
	tableHeader := []string{
		"ID",
		"Hostname",
@@ -632,7 +659,7 @@ func nodeRoutesToPtables(
		)
	}
-	return tableData, nil
+	return tableData
}

var tagCmd = &cobra.Command{
@@ -641,6 +668,7 @@ var tagCmd = &cobra.Command{
	Aliases: []string{"tags", "t"},
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
		defer cancel()
		defer conn.Close()
@@ -654,6 +682,7 @@ var tagCmd = &cobra.Command{
			output,
		)
	}
	tagsToSet, err := cmd.Flags().GetStringSlice("tags")
	if err != nil {
		ErrorOutput(
@@ -668,6 +697,7 @@ var tagCmd = &cobra.Command{
		NodeId: identifier,
		Tags:   tagsToSet,
	}
	resp, err := client.SetTags(ctx, request)
	if err != nil {
		ErrorOutput(
@@ -692,6 +722,7 @@ var approveRoutesCmd = &cobra.Command{
	Short: "Manage the approved routes of a node",
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
		defer cancel()
		defer conn.Close()
@@ -705,6 +736,7 @@ var approveRoutesCmd = &cobra.Command{
			output,
		)
	}
	routes, err := cmd.Flags().GetStringSlice("routes")
	if err != nil {
		ErrorOutput(
@@ -719,6 +751,7 @@ var approveRoutesCmd = &cobra.Command{
		NodeId: identifier,
		Routes: routes,
	}
	resp, err := client.SetApprovedRoutes(ctx, request)
	if err != nil {
		ErrorOutput(
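Note: nodeRoutesToPtables never produced a non-nil error, so the hunks above drop the error return entirely rather than silencing staticcheck at each call site. A minimal sketch of the pattern, with illustrative names rather than the real headscale ones:

```go
package main

import "fmt"

// Before: the error return is always nil, so every caller carries
// dead error handling (this mirrors the old nodeRoutesToPtables shape).
func rowsWithErr(items []string) ([][]string, error) {
	rows := [][]string{{"Hostname"}}
	for _, it := range items {
		rows = append(rows, []string{it})
	}
	return rows, nil
}

// After: dropping the impossible error simplifies every call site.
func rows(items []string) [][]string {
	out := [][]string{{"Hostname"}}
	for _, it := range items {
		out = append(out, []string{it})
	}
	return out
}

func main() {
	fmt.Println(rows([]string{"node-1", "node-2"}))
}
```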

View File

@@ -16,7 +16,7 @@ import (
)

const (
-	bypassFlag = "bypass-grpc-and-access-database-directly"
+	bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // not a credential
)

func init() {
@@ -26,16 +26,22 @@ func init() {
	policyCmd.AddCommand(getPolicy)
	setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
-	if err := setPolicy.MarkFlagRequired("file"); err != nil {
+	err := setPolicy.MarkFlagRequired("file")
+	if err != nil {
		log.Fatal().Err(err).Msg("")
	}
	setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
	policyCmd.AddCommand(setPolicy)
	checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
-	if err := checkPolicy.MarkFlagRequired("file"); err != nil {
+	err = checkPolicy.MarkFlagRequired("file")
+	if err != nil {
		log.Fatal().Err(err).Msg("")
	}
	policyCmd.AddCommand(checkPolicy)
}
@@ -50,9 +56,12 @@ var getPolicy = &cobra.Command{
	Aliases: []string{"show", "view", "fetch"},
	Run: func(cmd *cobra.Command, args []string) {
		output, _ := cmd.Flags().GetString("output")
		var policy string
		if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
			confirm := false
			force, _ := cmd.Flags().GetBool("force")
			if !force {
				confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
@@ -128,6 +137,7 @@ var setPolicy = &cobra.Command{
		if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass {
			confirm := false
			force, _ := cmd.Flags().GetBool("force")
			if !force {
				confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?")
@@ -173,7 +183,7 @@ var setPolicy = &cobra.Command{
		defer cancel()
		defer conn.Close()
-		if _, err := client.SetPolicy(ctx, request); err != nil {
+		if _, err := client.SetPolicy(ctx, request); err != nil { //nolint:noinlineerr
			ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
		}
	}
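Note: a recurring change in this compare is unwrapping `if err := f(); err != nil` into two statements, apparently to satisfy a `noinlineerr` linter; where the inline form is deliberately kept, it is annotated with `//nolint:noinlineerr`. A minimal sketch of the two spellings (the linter name is taken from the nolint directives in the diff):

```go
package main

import (
	"errors"
	"fmt"
)

func work() error { return errors.New("boom") }

func main() {
	// Inline form: err is scoped to the if statement.
	if err := work(); err != nil {
		fmt.Println("inline:", err)
	}

	// Unwrapped form preferred by the linter: assignment and check
	// are separate statements, which reads better in long functions.
	err := work()
	if err != nil {
		fmt.Println("unwrapped:", err)
	}
}
```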

View File

@@ -80,6 +80,7 @@ var listPreAuthKeys = &cobra.Command{
		"Owner",
	},
}
for _, key := range response.GetPreAuthKeys() {
	expiration := "-"
	if key.GetExpiration() != nil {
@@ -105,8 +106,8 @@ var listPreAuthKeys = &cobra.Command{
		key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
		owner,
	})
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
	ErrorOutput(

View File

@@ -45,15 +45,16 @@ func initConfig() {
	if cfgFile == "" {
		cfgFile = os.Getenv("HEADSCALE_CONFIG")
	}
	if cfgFile != "" {
		err := types.LoadConfig(cfgFile, true)
		if err != nil {
-			log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile)
+			log.Fatal().Caller().Err(err).Msgf("error loading config file %s", cfgFile)
		}
	} else {
		err := types.LoadConfig("", false)
		if err != nil {
-			log.Fatal().Caller().Err(err).Msgf("Error loading config")
+			log.Fatal().Caller().Err(err).Msgf("error loading config")
		}
	}
@@ -80,6 +81,7 @@ func initConfig() {
		Repository: "headscale",
		TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
	}
	res, err := latest.Check(githubTag, versionInfo.Version)
	if err == nil && res.Outdated {
		//nolint
@@ -101,6 +103,7 @@ func isPreReleaseVersion(version string) bool {
			return true
		}
	}
	return false
}
@@ -140,7 +143,8 @@ https://github.com/juanfont/headscale`,
}

func Execute() {
-	if err := rootCmd.Execute(); err != nil {
+	err := rootCmd.Execute()
+	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

View File

@@ -23,18 +23,17 @@ var serveCmd = &cobra.Command{
	Run: func(cmd *cobra.Command, args []string) {
		app, err := newHeadscaleServerWithConfig()
		if err != nil {
-			var squibbleErr squibble.ValidationError
-			if errors.As(err, &squibbleErr) {
+			if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok {
				fmt.Printf("SQLite schema failed to validate:\n")
				fmt.Println(squibbleErr.Diff)
			}
-			log.Fatal().Caller().Err(err).Msg("Error initializing")
+			log.Fatal().Caller().Err(err).Msg("error initializing")
		}
		err = app.Serve()
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
-			log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.")
+			log.Fatal().Caller().Err(err).Msg("headscale ran into an error and had to shut down")
		}
	},
}
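Note: the serve command switches from the two-step `errors.As` pattern to the generic `errors.AsType`, which, as used in the hunk above, returns the matched error value and a boolean in one expression; this compare also upgrades the toolchain to Go 1.26, where that helper appears. A minimal sketch, assuming the `errors.AsType[T](err) (T, bool)` shape shown in the diff:

```go
package main

import (
	"errors"
	"fmt"
)

// ValidationError stands in for squibble.ValidationError.
type ValidationError struct{ Diff string }

func (e ValidationError) Error() string { return "schema validation failed" }

func main() {
	err := fmt.Errorf("startup: %w", ValidationError{Diff: "- old\n+ new"})

	// Pre-1.26 pattern: declare a target variable, then errors.As.
	var vErr ValidationError
	if errors.As(err, &vErr) {
		fmt.Println(vErr.Diff)
	}

	// Go 1.26 pattern as used above: one expression, no pre-declared target.
	if vErr, ok := errors.AsType[ValidationError](err); ok {
		fmt.Println(vErr.Diff)
	}
}
```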

View File

@@ -8,12 +8,19 @@ import (
	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
	"github.com/pterm/pterm"
	"github.com/rs/zerolog/log"
	"github.com/spf13/cobra"
	"google.golang.org/grpc/status"
)

+// CLI user errors.
+var (
+	errFlagRequired       = errors.New("--name or --identifier flag is required")
+	errMultipleUsersMatch = errors.New("multiple users match query, specify an ID")
+)

func usernameAndIDFlag(cmd *cobra.Command) {
	cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
	cmd.Flags().StringP("name", "n", "", "Username")
@@ -23,12 +30,12 @@ func usernameAndIDFlag(cmd *cobra.Command) {
// If both are empty, it will exit the program with an error.
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
	username, _ := cmd.Flags().GetString("name")
	identifier, _ := cmd.Flags().GetInt64("identifier")
	if username == "" && identifier < 0 {
-		err := errors.New("--name or --identifier flag is required")
		ErrorOutput(
-			err,
-			"Cannot rename user: "+status.Convert(err).Message(),
+			errFlagRequired,
+			"Cannot rename user: "+status.Convert(errFlagRequired).Message(),
			"",
		)
	}
@@ -50,7 +57,8 @@ func init() {
	userCmd.AddCommand(renameUserCmd)
	usernameAndIDFlag(renameUserCmd)
	renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
-	renameNodeCmd.MarkFlagRequired("new-name")
+	_ = renameNodeCmd.MarkFlagRequired("new-name")
}

var errMissingParameter = errors.New("missing parameters")
@@ -81,7 +89,7 @@ var createUserCmd = &cobra.Command{
	defer cancel()
	defer conn.Close()
-	log.Trace().Interface("client", client).Msg("Obtained gRPC client")
+	log.Trace().Interface(zf.Client, client).Msg("obtained gRPC client")
	request := &v1.CreateUserRequest{Name: userName}
@@ -94,7 +102,7 @@ var createUserCmd = &cobra.Command{
	}
	if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
-		if _, err := url.Parse(pictureURL); err != nil {
+		if _, err := url.Parse(pictureURL); err != nil { //nolint:noinlineerr
			ErrorOutput(
				err,
				fmt.Sprintf(
@@ -104,10 +112,12 @@ var createUserCmd = &cobra.Command{
				output,
			)
		}
		request.PictureUrl = pictureURL
	}
-	log.Trace().Interface("request", request).Msg("Sending CreateUser request")
+	log.Trace().Interface(zf.Request, request).Msg("sending CreateUser request")
	response, err := client.CreateUser(ctx, request)
	if err != nil {
		ErrorOutput(
@@ -148,7 +158,7 @@ var destroyUserCmd = &cobra.Command{
	}
	if len(users.GetUsers()) != 1 {
-		err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
+		err := errMultipleUsersMatch
		ErrorOutput(
			err,
			"Error: "+status.Convert(err).Message(),
@@ -159,6 +169,7 @@ var destroyUserCmd = &cobra.Command{
	user := users.GetUsers()[0]
	confirm := false
	force, _ := cmd.Flags().GetBool("force")
	if !force {
		confirm = util.YesNo(fmt.Sprintf(
@@ -178,6 +189,7 @@ var destroyUserCmd = &cobra.Command{
				output,
			)
		}
		SuccessOutput(response, "User destroyed", output)
	} else {
		SuccessOutput(map[string]string{"Result": "User not destroyed"}, "User not destroyed", output)
@@ -238,6 +250,7 @@ var listUsersCmd = &cobra.Command{
			},
		)
	}
	err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
	if err != nil {
		ErrorOutput(
@@ -276,7 +289,7 @@ var renameUserCmd = &cobra.Command{
	}
	if len(users.GetUsers()) != 1 {
-		err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
+		err := errMultipleUsersMatch
		ErrorOutput(
			err,
			"Error: "+status.Convert(err).Message(),

View File

@@ -11,6 +11,7 @@ import (
	"github.com/juanfont/headscale/hscontrol"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
	"github.com/rs/zerolog/log"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
@@ -57,7 +58,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
	ctx, cancel := context.WithTimeout(context.Background(), cfg.CLI.Timeout)
	grpcOptions := []grpc.DialOption{
-		grpc.WithBlock(),
+		grpc.WithBlock(), //nolint:staticcheck // SA1019: deprecated but supported in 1.x
	}
	address := cfg.CLI.Address
@@ -81,6 +82,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
			Msgf("Unable to read/write to headscale socket, do you have the correct permissions?")
		}
	}
	socket.Close()
	grpcOptions = append(
@@ -92,8 +94,9 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
	// If we are not connecting to a local server, require an API key for authentication
	apiKey := cfg.CLI.APIKey
	if apiKey == "" {
-		log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set.")
+		log.Fatal().Caller().Msgf("HEADSCALE_CLI_API_KEY environment variable needs to be set")
	}
	grpcOptions = append(grpcOptions,
		grpc.WithPerRPCCredentials(tokenAuth{
			token: apiKey,
@@ -118,10 +121,11 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
		}
	}
-	log.Trace().Caller().Str("address", address).Msg("Connecting via gRPC")
-	conn, err := grpc.DialContext(ctx, address, grpcOptions...)
+	log.Trace().Caller().Str(zf.Address, address).Msg("connecting via gRPC")
+	conn, err := grpc.DialContext(ctx, address, grpcOptions...) //nolint:staticcheck // SA1019: deprecated but supported in 1.x
	if err != nil {
-		log.Fatal().Caller().Err(err).Msgf("Could not connect: %v", err)
+		log.Fatal().Caller().Err(err).Msgf("could not connect: %v", err)
		os.Exit(-1) // we get here if logging is suppressed (i.e., json output)
	}
@@ -131,23 +135,26 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
}

func output(result any, override string, outputFormat string) string {
-	var jsonBytes []byte
-	var err error
+	var (
+		jsonBytes []byte
+		err       error
+	)
	switch outputFormat {
	case "json":
		jsonBytes, err = json.MarshalIndent(result, "", "\t")
		if err != nil {
-			log.Fatal().Err(err).Msg("failed to unmarshal output")
+			log.Fatal().Err(err).Msg("unmarshalling output")
		}
	case "json-line":
		jsonBytes, err = json.Marshal(result)
		if err != nil {
-			log.Fatal().Err(err).Msg("failed to unmarshal output")
+			log.Fatal().Err(err).Msg("unmarshalling output")
		}
	case "yaml":
		jsonBytes, err = yaml.Marshal(result)
		if err != nil {
-			log.Fatal().Err(err).Msg("failed to unmarshal output")
+			log.Fatal().Err(err).Msg("unmarshalling output")
		}
	default:
		// nolint
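Note: rather than migrating off the deprecated `grpc.DialContext`/`grpc.WithBlock` in the same change, the diff pins the deprecation warnings with targeted `//nolint:staticcheck` directives that name the rule (SA1019) and the reason, so the check stays active everywhere else. A minimal sketch of the directive style (the deprecated function here is a stand-in, not the real gRPC API):

```go
package main

import "fmt"

// Deprecated: use greetV2 instead. Stands in for a deprecated
// third-party API such as grpc.DialContext.
func greet(name string) string { return "hello " + name }

func greetV2(name string) string { return "hello, " + name }

func main() {
	// Scoped suppression: names the check and justifies it, instead
	// of disabling staticcheck for the whole file or package.
	msg := greet("headscale") //nolint:staticcheck // SA1019: deprecated but supported in 1.x
	fmt.Println(msg, greetV2("headscale"))
}
```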

View File

@@ -12,6 +12,7 @@ import (
func main() {
	var colors bool
	switch l := termcolor.SupportLevel(os.Stderr); l {
	case termcolor.Level16M:
		colors = true

View File

@@ -14,9 +14,7 @@ import (
)

func TestConfigFileLoading(t *testing.T) {
-	tmpDir, err := os.MkdirTemp("", "headscale")
-	require.NoError(t, err)
-	defer os.RemoveAll(tmpDir)
+	tmpDir := t.TempDir()
	path, err := os.Getwd()
	require.NoError(t, err)
@@ -48,9 +46,7 @@ func TestConfigFileLoading(t *testing.T) {
}

func TestConfigLoading(t *testing.T) {
-	tmpDir, err := os.MkdirTemp("", "headscale")
-	require.NoError(t, err)
-	defer os.RemoveAll(tmpDir)
+	tmpDir := t.TempDir()
	path, err := os.Getwd()
	require.NoError(t, err)
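Note: `t.TempDir()` replaces the manual `os.MkdirTemp` + `defer os.RemoveAll` dance: the testing package creates a unique directory per test and removes it automatically when the test and its subtests finish, with no error to check. A minimal sketch:

```go
package config_test

import (
	"os"
	"path/filepath"
	"testing"
)

func TestWriteConfig(t *testing.T) {
	// Unique per test, cleaned up automatically by the framework;
	// no error return and no deferred RemoveAll.
	tmpDir := t.TempDir()

	cfgPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(cfgPath, []byte("listen_addr: :8080\n"), 0o644); err != nil {
		t.Fatalf("writing config: %v", err)
	}
}
```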

View File

@@ -22,11 +22,11 @@ import (
func cleanupBeforeTest(ctx context.Context) error {
	err := cleanupStaleTestContainers(ctx)
	if err != nil {
-		return fmt.Errorf("failed to clean stale test containers: %w", err)
+		return fmt.Errorf("cleaning stale test containers: %w", err)
	}
-	if err := pruneDockerNetworks(ctx); err != nil {
-		return fmt.Errorf("failed to prune networks: %w", err)
+	if err := pruneDockerNetworks(ctx); err != nil { //nolint:noinlineerr
+		return fmt.Errorf("pruning networks: %w", err)
	}
	return nil
@@ -39,14 +39,14 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
		Force: true,
	})
	if err != nil {
-		return fmt.Errorf("failed to remove test container: %w", err)
+		return fmt.Errorf("removing test container: %w", err)
	}
	// Clean up integration test containers for this run only
	if runID != "" {
		err := killTestContainersByRunID(ctx, runID)
		if err != nil {
-			return fmt.Errorf("failed to clean up containers for run %s: %w", runID, err)
+			return fmt.Errorf("cleaning up containers for run %s: %w", runID, err)
		}
	}
@@ -55,9 +55,9 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
// killTestContainers terminates and removes all test containers.
func killTestContainers(ctx context.Context) error {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
	if err != nil {
-		return fmt.Errorf("failed to create Docker client: %w", err)
+		return fmt.Errorf("creating Docker client: %w", err)
	}
	defer cli.Close()
@@ -65,12 +65,14 @@ func killTestContainers(ctx context.Context) error {
		All: true,
	})
	if err != nil {
-		return fmt.Errorf("failed to list containers: %w", err)
+		return fmt.Errorf("listing containers: %w", err)
	}
	removed := 0
	for _, cont := range containers {
		shouldRemove := false
		for _, name := range cont.Names {
			if strings.Contains(name, "headscale-test-suite") ||
				strings.Contains(name, "hs-") ||
// This function filters containers by the hi.run-id label to only affect containers // This function filters containers by the hi.run-id label to only affect containers
// belonging to the specified test run, leaving other concurrent test runs untouched. // belonging to the specified test run, leaving other concurrent test runs untouched.
func killTestContainersByRunID(ctx context.Context, runID string) error { func killTestContainersByRunID(ctx context.Context, runID string) error {
cli, err := createDockerClient() cli, err := createDockerClient(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err) return fmt.Errorf("creating Docker client: %w", err)
} }
defer cli.Close() defer cli.Close()
@@ -121,7 +123,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
), ),
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to list containers for run %s: %w", runID, err) return fmt.Errorf("listing containers for run %s: %w", runID, err)
} }
removed := 0 removed := 0
@@ -149,9 +151,9 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs // This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
// without interfering with currently running concurrent tests. // without interfering with currently running concurrent tests.
func cleanupStaleTestContainers(ctx context.Context) error { func cleanupStaleTestContainers(ctx context.Context) error {
cli, err := createDockerClient() cli, err := createDockerClient(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err) return fmt.Errorf("creating Docker client: %w", err)
} }
defer cli.Close() defer cli.Close()
@@ -164,7 +166,7 @@ func cleanupStaleTestContainers(ctx context.Context) error {
), ),
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to list stopped containers: %w", err) return fmt.Errorf("listing stopped containers: %w", err)
} }
removed := 0 removed := 0
@@ -223,15 +225,15 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container
// pruneDockerNetworks removes unused Docker networks.
func pruneDockerNetworks(ctx context.Context) error {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
	if err != nil {
-		return fmt.Errorf("failed to create Docker client: %w", err)
+		return fmt.Errorf("creating Docker client: %w", err)
	}
	defer cli.Close()
	report, err := cli.NetworksPrune(ctx, filters.Args{})
	if err != nil {
-		return fmt.Errorf("failed to prune networks: %w", err)
+		return fmt.Errorf("pruning networks: %w", err)
	}
	if len(report.NetworksDeleted) > 0 {
@@ -245,9 +247,9 @@ func pruneDockerNetworks(ctx context.Context) error {
// cleanOldImages removes test-related and old dangling Docker images.
func cleanOldImages(ctx context.Context) error {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
	if err != nil {
-		return fmt.Errorf("failed to create Docker client: %w", err)
+		return fmt.Errorf("creating Docker client: %w", err)
	}
	defer cli.Close()
@@ -255,12 +257,14 @@ func cleanOldImages(ctx context.Context) error {
		All: true,
	})
	if err != nil {
-		return fmt.Errorf("failed to list images: %w", err)
+		return fmt.Errorf("listing images: %w", err)
	}
	removed := 0
	for _, img := range images {
		shouldRemove := false
		for _, tag := range img.RepoTags {
			if strings.Contains(tag, "hs-") ||
				strings.Contains(tag, "headscale-integration") ||
// cleanCacheVolume removes the Docker volume used for Go module cache. // cleanCacheVolume removes the Docker volume used for Go module cache.
func cleanCacheVolume(ctx context.Context) error { func cleanCacheVolume(ctx context.Context) error {
cli, err := createDockerClient() cli, err := createDockerClient(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err) return fmt.Errorf("creating Docker client: %w", err)
} }
defer cli.Close() defer cli.Close()
volumeName := "hs-integration-go-cache" volumeName := "hs-integration-go-cache"
err = cli.VolumeRemove(ctx, volumeName, true) err = cli.VolumeRemove(ctx, volumeName, true)
if err != nil { if err != nil {
if errdefs.IsNotFound(err) { if errdefs.IsNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
fmt.Printf("Go module cache volume not found: %s\n", volumeName) fmt.Printf("Go module cache volume not found: %s\n", volumeName)
} else if errdefs.IsConflict(err) { } else if errdefs.IsConflict(err) { //nolint:staticcheck // SA1019: deprecated but functional
fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName) fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
} else { } else {
fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err) fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)
@@ -330,7 +335,7 @@ func cleanCacheVolume(ctx context.Context) error {
func cleanupSuccessfulTestArtifacts(logsDir string, verbose bool) error { func cleanupSuccessfulTestArtifacts(logsDir string, verbose bool) error {
entries, err := os.ReadDir(logsDir) entries, err := os.ReadDir(logsDir)
if err != nil { if err != nil {
return fmt.Errorf("failed to read logs directory: %w", err) return fmt.Errorf("reading logs directory: %w", err)
} }
var ( var (
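Note: the `failed to X: %w` → `Xing: %w` rewrites throughout this file follow the Go style guidance that wrap messages should name the operation, since a chain of wrapped errors otherwise stacks redundant "failed to" prefixes. A minimal sketch of the difference as an error bubbles up two layers:

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	base := errors.New("permission denied")

	// Old style: every layer repeats "failed to".
	old := fmt.Errorf("failed to clean stale test containers: %w",
		fmt.Errorf("failed to list containers: %w", base))
	fmt.Println(old)
	// failed to clean stale test containers: failed to list containers: permission denied

	// New style: each layer names the operation once.
	cur := fmt.Errorf("cleaning stale test containers: %w",
		fmt.Errorf("listing containers: %w", base))
	fmt.Println(cur)
	// cleaning stale test containers: listing containers: permission denied
}
```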

View File

@@ -22,17 +22,22 @@ import (
	"github.com/juanfont/headscale/integration/dockertestutil"
)

+const defaultDirPerm = 0o755

var (
	ErrTestFailed              = errors.New("test failed")
	ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
	ErrNoDockerContext         = errors.New("no docker context found")
+	ErrMemoryLimitViolations   = errors.New("container(s) exceeded memory limits")
)

// runTestContainer executes integration tests in a Docker container.
+//
+//nolint:gocyclo // complex test orchestration function
func runTestContainer(ctx context.Context, config *RunConfig) error {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
	if err != nil {
-		return fmt.Errorf("failed to create Docker client: %w", err)
+		return fmt.Errorf("creating Docker client: %w", err)
	}
	defer cli.Close()
@@ -48,19 +53,21 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	absLogsDir, err := filepath.Abs(logsDir)
	if err != nil {
-		return fmt.Errorf("failed to get absolute path for logs directory: %w", err)
+		return fmt.Errorf("getting absolute path for logs directory: %w", err)
	}
	const dirPerm = 0o755
-	if err := os.MkdirAll(absLogsDir, dirPerm); err != nil {
-		return fmt.Errorf("failed to create logs directory: %w", err)
+	if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { //nolint:noinlineerr
+		return fmt.Errorf("creating logs directory: %w", err)
	}
	if config.CleanBefore {
		if config.Verbose {
			log.Printf("Running pre-test cleanup...")
		}
-		if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
+		err := cleanupBeforeTest(ctx)
+		if err != nil && config.Verbose {
			log.Printf("Warning: pre-test cleanup failed: %v", err)
		}
	}
@@ -71,21 +78,21 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	}
	imageName := "golang:" + config.GoVersion
-	if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil {
-		return fmt.Errorf("failed to ensure image availability: %w", err)
+	if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { //nolint:noinlineerr
+		return fmt.Errorf("ensuring image availability: %w", err)
	}
	resp, err := createGoTestContainer(ctx, cli, config, containerName, absLogsDir, goTestCmd)
	if err != nil {
-		return fmt.Errorf("failed to create container: %w", err)
+		return fmt.Errorf("creating container: %w", err)
	}
	if config.Verbose {
		log.Printf("Created container: %s", resp.ID)
	}
-	if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
-		return fmt.Errorf("failed to start container: %w", err)
+	if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { //nolint:noinlineerr
+		return fmt.Errorf("starting container: %w", err)
	}
	log.Printf("Starting test: %s", config.TestPattern)
@@ -95,13 +102,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	// Start stats collection for container resource monitoring (if enabled)
	var statsCollector *StatsCollector
	if config.Stats {
		var err error
-		statsCollector, err = NewStatsCollector()
+		statsCollector, err = NewStatsCollector(ctx)
		if err != nil {
			if config.Verbose {
				log.Printf("Warning: failed to create stats collector: %v", err)
			}
			statsCollector = nil
		}
@@ -110,7 +120,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	// Start stats collection immediately - no need for complex retry logic
	// The new implementation monitors Docker events and will catch containers as they start
-	if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
+	err := statsCollector.StartCollection(ctx, runID, config.Verbose)
+	if err != nil {
		if config.Verbose {
			log.Printf("Warning: failed to start stats collection: %v", err)
		}
@@ -122,12 +133,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	exitCode, err := streamAndWait(ctx, cli, resp.ID)
	// Ensure all containers have finished and logs are flushed before extracting artifacts
-	if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose {
+	waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose)
+	if waitErr != nil && config.Verbose {
		log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
	}
	// Extract artifacts from test containers before cleanup
-	if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose {
+	if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { //nolint:noinlineerr
		log.Printf("Warning: failed to extract artifacts from containers: %v", err)
	}
@@ -140,12 +152,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	if len(violations) > 0 {
		log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
		log.Printf("=================================")
		for _, violation := range violations {
			log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
				violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
		}
-		return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
+		return fmt.Errorf("test failed: %d %w", len(violations), ErrMemoryLimitViolations)
	}
	}
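Note: replacing the literal string in the memory-violation return with a `%w`-wrapped sentinel keeps the dynamic count in the message while letting callers match the failure class with `errors.Is`, and it satisfies the lint rule against one-off dynamic errors. A minimal sketch:

```go
package main

import (
	"errors"
	"fmt"
)

var ErrMemoryLimitViolations = errors.New("container(s) exceeded memory limits")

func checkViolations(n int) error {
	if n > 0 {
		// Dynamic detail plus a comparable sentinel via %w.
		return fmt.Errorf("test failed: %d %w", n, ErrMemoryLimitViolations)
	}
	return nil
}

func main() {
	err := checkViolations(2)
	fmt.Println(err)                                      // test failed: 2 container(s) exceeded memory limits
	fmt.Println(errors.Is(err, ErrMemoryLimitViolations)) // true
}
```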
@@ -176,7 +189,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
	}
	if err != nil {
-		return fmt.Errorf("test execution failed: %w", err)
+		return fmt.Errorf("executing test: %w", err)
	}
	if exitCode != 0 {
@@ -210,7 +223,7 @@ func buildGoTestCommand(config *RunConfig) []string {
func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunConfig, containerName, logsDir string, goTestCmd []string) (container.CreateResponse, error) {
	pwd, err := os.Getwd()
	if err != nil {
-		return container.CreateResponse{}, fmt.Errorf("failed to get working directory: %w", err)
+		return container.CreateResponse{}, fmt.Errorf("getting working directory: %w", err)
	}
	projectRoot := findProjectRoot(pwd)
@@ -312,7 +325,7 @@ func streamAndWait(ctx context.Context, cli *client.Client, containerID string)
		Follow: true,
	})
	if err != nil {
-		return -1, fmt.Errorf("failed to get container logs: %w", err)
+		return -1, fmt.Errorf("getting container logs: %w", err)
	}
	defer out.Close()
@@ -324,7 +337,7 @@ func streamAndWait(ctx context.Context, cli *client.Client, containerID string)
	select {
	case err := <-errCh:
		if err != nil {
-			return -1, fmt.Errorf("error waiting for container: %w", err)
+			return -1, fmt.Errorf("waiting for container: %w", err)
		}
	case status := <-statusCh:
		return int(status.StatusCode), nil
@@ -338,7 +351,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
	// First, get all related test containers
	containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
	if err != nil {
-		return fmt.Errorf("failed to list containers: %w", err)
+		return fmt.Errorf("listing containers: %w", err)
	}
	testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
@@ -347,6 +360,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
	maxWaitTime := 10 * time.Second
	checkInterval := 500 * time.Millisecond
	timeout := time.After(maxWaitTime)
	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()
@@ -356,6 +370,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
		if verbose {
			log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
		}
		return nil
	case <-ticker.C:
		allFinalized := true
@@ -366,12 +381,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
			if verbose {
				log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
			}
			continue
		}
		// Check if container is in a final state
		if !isContainerFinalized(inspect.State) {
			allFinalized = false
			if verbose {
				log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
			}
@@ -384,6 +401,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
		if verbose {
			log.Printf("All test containers finalized, ready for artifact extraction")
		}
		return nil
	}
}
@@ -400,13 +418,15 @@ func isContainerFinalized(state *container.State) bool {
func findProjectRoot(startPath string) string {
	current := startPath
	for {
-		if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
+		if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { //nolint:noinlineerr
			return current
		}
		parent := filepath.Dir(current)
		if parent == current {
			return startPath
		}
		current = parent
	}
}
@@ -416,6 +436,7 @@ func boolToInt(b bool) int {
	if b {
		return 1
	}
	return 0
}
@@ -428,13 +449,14 @@ type DockerContext struct {
}

// createDockerClient creates a Docker client with context detection.
-func createDockerClient() (*client.Client, error) {
-	contextInfo, err := getCurrentDockerContext()
+func createDockerClient(ctx context.Context) (*client.Client, error) {
+	contextInfo, err := getCurrentDockerContext(ctx)
	if err != nil {
		return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	}
	var clientOpts []client.Opt
	clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
	if contextInfo != nil {
@@ -444,6 +466,7 @@ func createDockerClient() (*client.Client, error) {
		if runConfig.Verbose {
			log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
		}
		clientOpts = append(clientOpts, client.WithHost(host))
	}
	}
@@ -458,16 +481,17 @@ func createDockerClient() (*client.Client, error) {
}

// getCurrentDockerContext retrieves the current Docker context information.
-func getCurrentDockerContext() (*DockerContext, error) {
-	cmd := exec.Command("docker", "context", "inspect")
+func getCurrentDockerContext(ctx context.Context) (*DockerContext, error) {
+	cmd := exec.CommandContext(ctx, "docker", "context", "inspect")
	output, err := cmd.Output()
	if err != nil {
-		return nil, fmt.Errorf("failed to get docker context: %w", err)
+		return nil, fmt.Errorf("getting docker context: %w", err)
	}
	var contexts []DockerContext
-	if err := json.Unmarshal(output, &contexts); err != nil {
-		return nil, fmt.Errorf("failed to parse docker context: %w", err)
+	if err := json.Unmarshal(output, &contexts); err != nil { //nolint:noinlineerr
+		return nil, fmt.Errorf("parsing docker context: %w", err)
	}
	if len(contexts) > 0 {
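Note: threading a `context.Context` into `createDockerClient` and switching `exec.Command` to `exec.CommandContext` means the `docker context inspect` subprocess is killed if the surrounding operation is cancelled or times out. A minimal sketch of the pattern:

```go
package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

func dockerContextJSON(ctx context.Context) ([]byte, error) {
	// CommandContext kills the child process when ctx is cancelled,
	// so a hung docker CLI cannot stall the whole run.
	cmd := exec.CommandContext(ctx, "docker", "context", "inspect")

	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("getting docker context: %w", err)
	}
	return out, nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	out, err := dockerContextJSON(ctx)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("%d bytes of context metadata\n", len(out))
}
```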
@@ -486,12 +510,13 @@ func getDockerSocketPath() string {
// checkImageAvailableLocally checks if the specified Docker image is available locally.
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
-	_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
+	_, _, err := cli.ImageInspectWithRaw(ctx, imageName) //nolint:staticcheck // SA1019: deprecated but functional
	if err != nil {
-		if client.IsErrNotFound(err) {
+		if client.IsErrNotFound(err) { //nolint:staticcheck // SA1019: deprecated but functional
			return false, nil
		}
-		return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
+		return false, fmt.Errorf("inspecting image %s: %w", imageName, err)
	}
	return true, nil
@@ -502,13 +527,14 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
	// First check if image is available locally
	available, err := checkImageAvailableLocally(ctx, cli, imageName)
	if err != nil {
-		return fmt.Errorf("failed to check local image availability: %w", err)
+		return fmt.Errorf("checking local image availability: %w", err)
	}
	if available {
		if verbose {
			log.Printf("Image %s is available locally", imageName)
		}
		return nil
	}
@@ -519,20 +545,21 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
	reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
	if err != nil {
-		return fmt.Errorf("failed to pull image %s: %w", imageName, err)
+		return fmt.Errorf("pulling image %s: %w", imageName, err)
	}
	defer reader.Close()
	if verbose {
		_, err = io.Copy(os.Stdout, reader)
		if err != nil {
-			return fmt.Errorf("failed to read pull output: %w", err)
+			return fmt.Errorf("reading pull output: %w", err)
		}
	} else {
		_, err = io.Copy(io.Discard, reader)
		if err != nil {
-			return fmt.Errorf("failed to read pull output: %w", err)
+			return fmt.Errorf("reading pull output: %w", err)
		}
		log.Printf("Image %s pulled successfully", imageName)
	}
@@ -547,9 +574,11 @@ func listControlFiles(logsDir string) {
		return
	}
-	var logFiles []string
-	var dataFiles []string
-	var dataDirs []string
+	var (
+		logFiles  []string
+		dataFiles []string
+		dataDirs  []string
+	)
	for _, entry := range entries {
		name := entry.Name()
@@ -578,6 +607,7 @@ func listControlFiles(logsDir string) {
	if len(logFiles) > 0 {
		log.Printf("Headscale logs:")
		for _, file := range logFiles {
			log.Printf("  %s", file)
		}
@@ -585,9 +615,11 @@ func listControlFiles(logsDir string) {
	if len(dataFiles) > 0 || len(dataDirs) > 0 {
		log.Printf("Headscale data:")
		for _, file := range dataFiles {
			log.Printf("  %s", file)
		}
		for _, dir := range dataDirs {
			log.Printf("  %s/", dir)
		}
@@ -596,25 +628,27 @@ func listControlFiles(logsDir string) {
// extractArtifactsFromContainers collects container logs and files from the specific test run.
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
	if err != nil {
-		return fmt.Errorf("failed to create Docker client: %w", err)
+		return fmt.Errorf("creating Docker client: %w", err)
	}
	defer cli.Close()
	// List all containers
	containers, err := cli.ContainerList(ctx, container.ListOptions{All: true})
	if err != nil {
-		return fmt.Errorf("failed to list containers: %w", err)
+		return fmt.Errorf("listing containers: %w", err)
	}
	// Get containers from the specific test run
	currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
	extractedCount := 0
	for _, cont := range currentTestContainers {
		// Extract container logs and tar files
-		if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
+		err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose)
+		if err != nil {
			if verbose {
				log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
			}
@@ -622,6 +656,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
			if verbose {
				log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
			}
			extractedCount++
		}
	}
@@ -645,11 +680,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
	// Find the test container to get its run ID label
	var runID string
	for _, cont := range containers {
		if cont.ID == testContainerID {
			if cont.Labels != nil {
				runID = cont.Labels["hi.run-id"]
			}
			break
		}
	}
@@ -690,18 +727,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
// extractContainerArtifacts saves logs and tar files from a container.
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
	// Ensure the logs directory exists
-	if err := os.MkdirAll(logsDir, 0o755); err != nil {
-		return fmt.Errorf("failed to create logs directory: %w", err)
+	err := os.MkdirAll(logsDir, defaultDirPerm)
+	if err != nil {
+		return fmt.Errorf("creating logs directory: %w", err)
	}
	// Extract container logs
-	if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
-		return fmt.Errorf("failed to extract logs: %w", err)
+	err = extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose)
+	if err != nil {
+		return fmt.Errorf("extracting logs: %w", err)
	}
	// Extract tar files for headscale containers only
	if strings.HasPrefix(containerName, "hs-") {
-		if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
+		err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose)
+		if err != nil {
			if verbose {
				log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
			}
@@ -723,7 +763,7 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
		Tail: "all",
	})
	if err != nil {
-		return fmt.Errorf("failed to get container logs: %w", err)
+		return fmt.Errorf("getting container logs: %w", err)
	}
	defer logReader.Close()
@@ -737,17 +777,17 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
	// Demultiplex the Docker logs stream to separate stdout and stderr
	_, err = stdcopy.StdCopy(&stdoutBuf, &stderrBuf, logReader)
	if err != nil {
-		return fmt.Errorf("failed to demultiplex container logs: %w", err)
+		return fmt.Errorf("demultiplexing container logs: %w", err)
	}
	// Write stdout logs
-	if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
+	if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
		return fmt.Errorf("writing stdout log: %w", err)
	}
	// Write stderr logs
-	if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
+	if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { //nolint:gosec,noinlineerr // log files should be readable
		return fmt.Errorf("writing stderr log: %w", err)
	}
	if verbose {

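The `extractContainerLogs` hunk above leans on `stdcopy.StdCopy` to split Docker's multiplexed log stream: when a container runs without a TTY, stdout and stderr arrive interleaved on a single connection, each frame tagged by an 8-byte header. A minimal standalone sketch of the same demultiplexing step; the container name is a placeholder and the exact name of the logs options struct varies across Docker client versions:

```go
package main

import (
	"bytes"
	"context"
	"log"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/client"
	"github.com/docker/docker/pkg/stdcopy"
)

func main() {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	// ContainerLogs returns one multiplexed stream for non-TTY containers;
	// each frame's header says whether the payload is stdout or stderr.
	rc, err := cli.ContainerLogs(context.Background(), "my-container", container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Tail:       "all",
	})
	if err != nil {
		log.Fatal(err)
	}
	defer rc.Close()

	var stdoutBuf, stderrBuf bytes.Buffer
	// StdCopy reads the framed stream and routes each payload to the
	// matching writer, mirroring what extractContainerLogs does above.
	if _, err := stdcopy.StdCopy(&stdoutBuf, &stderrBuf, rc); err != nil {
		log.Fatal(err)
	}

	log.Printf("stdout: %d bytes, stderr: %d bytes", stdoutBuf.Len(), stderrBuf.Len())
}
```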

@@ -38,13 +38,13 @@ func runDoctorCheck(ctx context.Context) error {
 	}
 
 	// Check 3: Go installation
-	results = append(results, checkGoInstallation())
+	results = append(results, checkGoInstallation(ctx))
 
 	// Check 4: Git repository
-	results = append(results, checkGitRepository())
+	results = append(results, checkGitRepository(ctx))
 
 	// Check 5: Required files
-	results = append(results, checkRequiredFiles())
+	results = append(results, checkRequiredFiles(ctx))
 
 	// Display results
 	displayDoctorResults(results)
@@ -86,7 +86,7 @@ func checkDockerBinary() DoctorResult {
 // checkDockerDaemon verifies Docker daemon is running and accessible.
 func checkDockerDaemon(ctx context.Context) DoctorResult {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
 	if err != nil {
 		return DoctorResult{
 			Name: "Docker Daemon",
@@ -124,8 +124,8 @@ func checkDockerDaemon(ctx context.Context) DoctorResult {
 }
 
 // checkDockerContext verifies Docker context configuration.
-func checkDockerContext(_ context.Context) DoctorResult {
-	contextInfo, err := getCurrentDockerContext()
+func checkDockerContext(ctx context.Context) DoctorResult {
+	contextInfo, err := getCurrentDockerContext(ctx)
 	if err != nil {
 		return DoctorResult{
 			Name: "Docker Context",
@@ -155,7 +155,7 @@ func checkDockerContext(_ context.Context) DoctorResult {
 // checkDockerSocket verifies Docker socket accessibility.
 func checkDockerSocket(ctx context.Context) DoctorResult {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
 	if err != nil {
 		return DoctorResult{
 			Name: "Docker Socket",
@@ -192,7 +192,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
 // checkGolangImage verifies the golang Docker image is available locally or can be pulled.
 func checkGolangImage(ctx context.Context) DoctorResult {
-	cli, err := createDockerClient()
+	cli, err := createDockerClient(ctx)
 	if err != nil {
 		return DoctorResult{
 			Name: "Golang Image",
@@ -251,7 +251,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
 }
 
 // checkGoInstallation verifies Go is installed and working.
-func checkGoInstallation() DoctorResult {
+func checkGoInstallation(ctx context.Context) DoctorResult {
 	_, err := exec.LookPath("go")
 	if err != nil {
 		return DoctorResult{
@@ -265,7 +265,8 @@ func checkGoInstallation() DoctorResult {
 		}
 	}
 
-	cmd := exec.Command("go", "version")
+	cmd := exec.CommandContext(ctx, "go", "version")
+
 	output, err := cmd.Output()
 	if err != nil {
 		return DoctorResult{
@@ -285,8 +286,9 @@ func checkGoInstallation() DoctorResult {
 }
 
 // checkGitRepository verifies we're in a git repository.
-func checkGitRepository() DoctorResult {
-	cmd := exec.Command("git", "rev-parse", "--git-dir")
+func checkGitRepository(ctx context.Context) DoctorResult {
+	cmd := exec.CommandContext(ctx, "git", "rev-parse", "--git-dir")
+
 	err := cmd.Run()
 	if err != nil {
 		return DoctorResult{
@@ -308,7 +310,7 @@ func checkGitRepository() DoctorResult {
 }
 
 // checkRequiredFiles verifies required files exist.
-func checkRequiredFiles() DoctorResult {
+func checkRequiredFiles(ctx context.Context) DoctorResult {
 	requiredFiles := []string{
 		"go.mod",
 		"integration/",
@@ -316,9 +318,12 @@ func checkRequiredFiles() DoctorResult {
 	}
 
 	var missingFiles []string
+
 	for _, file := range requiredFiles {
-		cmd := exec.Command("test", "-e", file)
-		if err := cmd.Run(); err != nil {
+		cmd := exec.CommandContext(ctx, "test", "-e", file)
+
+		err := cmd.Run()
+		if err != nil {
 			missingFiles = append(missingFiles, file)
 		}
 	}
@@ -350,6 +355,7 @@ func displayDoctorResults(results []DoctorResult) {
 	for _, result := range results {
 		var icon string
+
 		switch result.Status {
 		case "PASS":
 			icon = "✅"

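The doctor checks now accept a `context.Context` and run external tools through `exec.CommandContext`, so a cancelled or timed-out run kills the child process instead of orphaning it. A minimal sketch of the pattern:

```go
package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

func main() {
	// A deadline on the context bounds the whole check; when it fires,
	// CommandContext terminates the child process for us.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "go", "version")

	out, err := cmd.Output()
	if err != nil {
		fmt.Println("check failed:", err)
		return
	}

	fmt.Print(string(out))
}
```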

@@ -79,13 +79,18 @@ func main() {
 }
 
 func cleanAll(ctx context.Context) error {
-	if err := killTestContainers(ctx); err != nil {
+	err := killTestContainers(ctx)
+	if err != nil {
 		return err
 	}
 
-	if err := pruneDockerNetworks(ctx); err != nil {
+	err = pruneDockerNetworks(ctx)
+	if err != nil {
 		return err
 	}
 
-	if err := cleanOldImages(ctx); err != nil {
+	err = cleanOldImages(ctx)
+	if err != nil {
 		return err
 	}

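The `cleanAll` rewrite above is the `noinlineerr` style applied mechanically: the assignment moves out of the `if` condition onto its own statement. A side-by-side sketch, with a hypothetical `doWork` helper standing in for any fallible call:

```go
package main

import (
	"errors"
	"fmt"
)

// doWork is a stand-in for any call that can fail.
func doWork() error { return errors.New("boom") }

// before shows the inline form the upgraded linter now flags.
func before() error {
	if err := doWork(); err != nil { //nolint:noinlineerr
		return err
	}
	return nil
}

// after shows the split form used throughout this change set: the error
// lives in the enclosing scope and the condition stays minimal.
func after() error {
	err := doWork()
	if err != nil {
		return err
	}
	return nil
}

func main() {
	fmt.Println(before(), after())
}
```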

@@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error {
 	if runConfig.Verbose {
 		log.Printf("Running pre-flight system checks...")
 	}
 
-	if err := runDoctorCheck(env.Context()); err != nil {
+	err := runDoctorCheck(env.Context())
+	if err != nil {
 		return fmt.Errorf("pre-flight checks failed: %w", err)
 	}
@@ -66,15 +68,15 @@ func runIntegrationTest(env *command.Env) error {
 func detectGoVersion() string {
 	goModPath := filepath.Join("..", "..", "go.mod")
 
-	if _, err := os.Stat("go.mod"); err == nil {
+	if _, err := os.Stat("go.mod"); err == nil { //nolint:noinlineerr
 		goModPath = "go.mod"
-	} else if _, err := os.Stat("../../go.mod"); err == nil {
+	} else if _, err := os.Stat("../../go.mod"); err == nil { //nolint:noinlineerr
 		goModPath = "../../go.mod"
 	}
 
 	content, err := os.ReadFile(goModPath)
 	if err != nil {
-		return "1.25"
+		return "1.26.0"
 	}
 
 	lines := splitLines(string(content))
@@ -89,13 +91,15 @@ func detectGoVersion() string {
 		}
 	}
 
-	return "1.25"
+	return "1.26.0"
 }
 
 // splitLines splits a string into lines without using strings.Split.
 func splitLines(s string) []string {
-	var lines []string
-	var current string
+	var (
+		lines   []string
+		current string
+	)
 
 	for _, char := range s {
 		if char == '\n' {

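`detectGoVersion` reads the `go` directive out of `go.mod` and now falls back to `1.26.0` when it cannot. A sketch of the same idea using `bufio.Scanner`; the path and fallback value are illustrative, not the tool's exact implementation:

```go
package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

// goVersionFromModFile returns the version from the first "go" directive
// in the given go.mod, or the fallback when the file or directive is missing.
func goVersionFromModFile(path, fallback string) string {
	f, err := os.Open(path)
	if err != nil {
		return fallback
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		// The directive looks like: go 1.26.0
		if v, ok := strings.CutPrefix(line, "go "); ok {
			return strings.TrimSpace(v)
		}
	}

	return fallback
}

func main() {
	fmt.Println(goVersionFromModFile("go.mod", "1.26.0"))
}
```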

@@ -18,6 +18,9 @@ import (
 	"github.com/docker/docker/client"
 )
 
+// ErrStatsCollectionAlreadyStarted is returned when trying to start stats collection that is already running.
+var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started")
+
 // ContainerStats represents statistics for a single container.
 type ContainerStats struct {
 	ContainerID string
@@ -44,10 +47,10 @@ type StatsCollector struct {
 }
 
 // NewStatsCollector creates a new stats collector instance.
-func NewStatsCollector() (*StatsCollector, error) {
-	cli, err := createDockerClient()
+func NewStatsCollector(ctx context.Context) (*StatsCollector, error) {
+	cli, err := createDockerClient(ctx)
 	if err != nil {
-		return nil, fmt.Errorf("failed to create Docker client: %w", err)
+		return nil, fmt.Errorf("creating Docker client: %w", err)
 	}
 
 	return &StatsCollector{
@@ -63,17 +66,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
 	defer sc.mutex.Unlock()
 
 	if sc.collectionStarted {
-		return errors.New("stats collection already started")
+		return ErrStatsCollectionAlreadyStarted
 	}
 
 	sc.collectionStarted = true
 
 	// Start monitoring existing containers
 	sc.wg.Add(1)
+
 	go sc.monitorExistingContainers(ctx, runID, verbose)
 
 	// Start Docker events monitoring for new containers
 	sc.wg.Add(1)
+
 	go sc.monitorDockerEvents(ctx, runID, verbose)
 
 	if verbose {
@@ -87,10 +92,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
 func (sc *StatsCollector) StopCollection() {
 	// Check if already stopped without holding lock
 	sc.mutex.RLock()
+
 	if !sc.collectionStarted {
 		sc.mutex.RUnlock()
 		return
 	}
+
 	sc.mutex.RUnlock()
 
 	// Signal stop to all goroutines
@@ -114,6 +121,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
 		if verbose {
 			log.Printf("Failed to list existing containers: %v", err)
 		}
+
 		return
 	}
@@ -147,13 +155,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
 		case event := <-events:
 			if event.Type == "container" && event.Action == "start" {
 				// Get container details
-				containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
+				containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) //nolint:staticcheck // SA1019: use Actor.ID
 				if err != nil {
 					continue
 				}
 
 				// Convert to types.Container format for consistency
-				cont := types.Container{
+				cont := types.Container{ //nolint:staticcheck // SA1019: use container.Summary
 					ID:     containerInfo.ID,
 					Names:  []string{containerInfo.Name},
 					Labels: containerInfo.Config.Labels,
@@ -167,13 +175,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
 			if verbose {
 				log.Printf("Error in Docker events stream: %v", err)
 			}
+
 			return
 		}
 	}
 }
 
 // shouldMonitorContainer determines if a container should be monitored.
-func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
+func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { //nolint:staticcheck // SA1019: use container.Summary
 	// Check if it has the correct run ID label
 	if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
 		return false
@@ -213,6 +222,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
 	}
 
 	sc.wg.Add(1)
+
 	go sc.collectStatsForContainer(ctx, containerID, verbose)
 }
@@ -226,12 +236,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
 		if verbose {
 			log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
 		}
+
 		return
 	}
 	defer statsResponse.Body.Close()
 
 	decoder := json.NewDecoder(statsResponse.Body)
-	var prevStats *container.Stats
+
+	var prevStats *container.Stats //nolint:staticcheck // SA1019: use StatsResponse
 
 	for {
 		select {
@@ -240,12 +252,15 @@
 		case <-ctx.Done():
 			return
 		default:
-			var stats container.Stats
-			if err := decoder.Decode(&stats); err != nil {
+			var stats container.Stats //nolint:staticcheck // SA1019: use StatsResponse
+
+			err := decoder.Decode(&stats)
+			if err != nil {
 				// EOF is expected when container stops or stream ends
 				if err.Error() != "EOF" && verbose {
 					log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
 				}
+
 				return
 			}
@@ -261,8 +276,10 @@
 			// Store the sample (skip first sample since CPU calculation needs previous stats)
 			if prevStats != nil {
 				// Get container stats reference without holding the main mutex
-				var containerStats *ContainerStats
-				var exists bool
+				var (
+					containerStats *ContainerStats
+					exists         bool
+				)
 
 				sc.mutex.RLock()
 				containerStats, exists = sc.containers[containerID]
@@ -286,7 +303,7 @@
 }
 
 // calculateCPUPercent calculates CPU usage percentage from Docker stats.
-func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
+func calculateCPUPercent(prevStats, stats *container.Stats) float64 { //nolint:staticcheck // SA1019: use StatsResponse
 	// CPU calculation based on Docker's implementation
 	cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
 	systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
@@ -331,10 +348,12 @@ type StatsSummary struct {
 func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
 	// Take snapshot of container references without holding main lock long
 	sc.mutex.RLock()
+
 	containerRefs := make([]*ContainerStats, 0, len(sc.containers))
 	for _, containerStats := range sc.containers {
 		containerRefs = append(containerRefs, containerStats)
 	}
+
 	sc.mutex.RUnlock()
 
 	summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
@@ -384,23 +403,25 @@ func calculateStatsSummary(values []float64) StatsSummary {
 		return StatsSummary{}
 	}
 
-	min := values[0]
-	max := values[0]
+	minVal := values[0]
+	maxVal := values[0]
 	sum := 0.0
 
 	for _, value := range values {
-		if value < min {
-			min = value
+		if value < minVal {
+			minVal = value
 		}
-		if value > max {
-			max = value
+
+		if value > maxVal {
+			maxVal = value
 		}
+
 		sum += value
 	}
 
 	return StatsSummary{
-		Min:     min,
-		Max:     max,
+		Min:     minVal,
+		Max:     maxVal,
 		Average: sum / float64(len(values)),
 	}
 }
@@ -434,6 +455,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
 	}
 
 	summaries := sc.GetSummary()
 
 	var violations []MemoryViolation
+
 	for _, summary := range summaries {

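`calculateCPUPercent` follows Docker's reference calculation: the container's CPU delta between two samples divided by the system-wide delta, scaled by the number of online CPUs. Worked in isolation with illustrative numbers:

```go
package main

import "fmt"

// cpuPercent mirrors Docker's reference calculation: the container's share
// of the system CPU delta between two samples, scaled to all online CPUs.
func cpuPercent(prevTotal, total, prevSystem, system uint64, onlineCPUs float64) float64 {
	cpuDelta := float64(total) - float64(prevTotal)
	systemDelta := float64(system) - float64(prevSystem)

	if systemDelta <= 0 || cpuDelta < 0 {
		return 0
	}

	return (cpuDelta / systemDelta) * onlineCPUs * 100.0
}

func main() {
	// Illustrative numbers: the container used 2% of one sampling interval
	// on a 4-CPU machine, so the result is 8%.
	fmt.Printf("%.1f%%\n", cpuPercent(1_000_000, 3_000_000, 100_000_000, 200_000_000, 4))
}
```

This is also why the collector skips the first sample: without a previous reading there is no delta to divide.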

@@ -2,6 +2,7 @@ package main
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"os"
@@ -15,7 +16,10 @@ type MapConfig struct {
 	Directory string `flag:"directory,Directory to read map responses from"`
 }
 
-var mapConfig MapConfig
+var (
+	mapConfig            MapConfig
+	errDirectoryRequired = errors.New("directory is required")
+)
 
 func main() {
 	root := command.C{
@@ -40,7 +44,7 @@ func main() {
 // runIntegrationTest executes the integration test workflow.
 func runOnline(env *command.Env) error {
 	if mapConfig.Directory == "" {
-		return fmt.Errorf("directory is required")
+		return errDirectoryRequired
 	}
 
 	resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
@@ -57,5 +61,6 @@ func runOnline(env *command.Env) error {
 	os.Stderr.Write(out)
 	os.Stderr.Write([]byte("\n"))
+
 	return nil
 }

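The swap from `fmt.Errorf` to a package-level sentinel in the map tool is the usual err113-style fix: a sentinel is a comparable value, so callers can branch on it with `errors.Is` even after wrapping. A compact sketch:

```go
package main

import (
	"errors"
	"fmt"
)

// A sentinel declared once at package level; callers can compare against it.
var errDirectoryRequired = errors.New("directory is required")

func run(directory string) error {
	if directory == "" {
		return errDirectoryRequired
	}
	return nil
}

func main() {
	err := fmt.Errorf("running tool: %w", run(""))

	// errors.Is matches the sentinel even through layers of %w wrapping,
	// which a freshly constructed fmt.Errorf value never would.
	if errors.Is(err, errDirectoryRequired) {
		fmt.Println("usage: pass --directory")
	}
}
```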

@@ -24,9 +24,12 @@ We are more than happy to exchange emails, or to have dedicated calls before a P
 
 ## When/Why is Feature X going to be implemented?
 
-We don't know. We might be working on it. If you're interested in contributing, please post a feature request about it.
+We use [GitHub Milestones to plan for upcoming Headscale releases](https://github.com/juanfont/headscale/milestones).
+Have a look at [our current plan](https://github.com/juanfont/headscale/milestones) to get an idea when a specific
+feature is about to be implemented. The release plan is subject to change at any time.
 
-Please be aware that there are a number of reasons why we might not accept specific contributions:
+If you're interested in contributing, please post a feature request about it. Please be aware that there are a number of
+reasons why we might not accept specific contributions:
 
 - It is not possible to implement the feature in a way that makes sense in a self-hosted environment.
 - Given that we are reverse-engineering Tailscale to satisfy our own curiosity, we might be interested in implementing the feature ourselves.
@@ -47,7 +50,7 @@ we have a "docker-issues" channel where you can ask for Docker-specific help to
 
 ## What is the recommended update path? Can I skip multiple versions while updating?
 
 Please follow the steps outlined in the [upgrade guide](../setup/upgrade.md) to update your existing Headscale
-installation. Its best to update from one stable version to the next (e.g. 0.24.0 &rarr; 0.25.1 &rarr; 0.26.1) in case
+installation. It's best to update from one stable version to the next (e.g. 0.26.0 &rarr; 0.27.1 &rarr; 0.28.0) in case
 you are multiple releases behind. You should always pick the latest available patch release.
 
 Be sure to check the [changelog](https://github.com/juanfont/headscale/blob/main/CHANGELOG.md) for version specific


@@ -245,7 +245,6 @@ Includes all devices that [have at least one tag](registration.md/#identity-mode
 ```
 
 ### `autogroup:self`
 
-**(EXPERIMENTAL)**
 
 !!! warning "The current implementation of `autogroup:self` is inefficient"


@@ -20,5 +20,7 @@ Headscale doesn't provide a built-in web interface but users may pick one from t
 - [headscale-console](https://github.com/rickli-cloud/headscale-console) - WebAssembly-based client supporting SSH, VNC
   and RDP with optional self-service capabilities
 - [headscale-piying](https://github.com/wszgrcy/headscale-piying) - headscale web ui,support visual ACL configuration
+- [HeadControl](https://github.com/ahmadzip/HeadControl) - Minimal Headscale admin dashboard, built with Go and HTMX
+- [Headscale Manager](https://github.com/hkdone/headscalemanager) - Headscale UI for Android
 
 You can ask for support on our [Discord server](https://discord.gg/c84AZQhmpx) in the "web-interfaces" channel.


@@ -185,7 +185,8 @@ You may refer to users in the Headscale policy via:
 
 - Email address
 - Username
-- Provider identifier (only available in the database or from your identity provider)
+- Provider identifier (this value is currently only available from the [API](api.md), database or directly from your
+  identity provider)
 
 !!! note "A user identifier in the policy must contain a single `@`"
@@ -200,6 +201,34 @@
     consequences for Headscale where a policy might no longer work or a user might obtain more access by hijacking an
     existing username or email address.
 
+!!! tip "How to use the provider identifier in the policy"
+
+    The provider identifier uniquely identifies an OIDC user and a well-behaving identity provider guarantees that this
+    value never changes for a particular user. It is usually an opaque and long string and its value is currently only
+    available from the [API](api.md), database or directly from your identity provider.
+
+    Use the [API](api.md) with the `/api/v1/user` endpoint to fetch the provider identifier (`providerId`). The value
+    (be sure to append an `@` in case the provider identifier doesn't already contain an `@` somewhere) can be used
+    directly to reference a user in the policy. To improve readability of the policy, one may use the `groups` section
+    as an alias:
+
+    ```json
+    {
+      "groups": {
+        "group:alice": [
+          "https://soo.example.com/oauth2/openid/59ac9125-c31b-46c5-814e-06242908cf57@"
+        ]
+      },
+      "acls": [
+        {
+          "action": "accept",
+          "src": ["group:alice"],
+          "dst": ["*:*"]
+        }
+      ]
+    }
+    ```
+
 ## Supported OIDC claims
 
 Headscale uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to
@@ -289,6 +318,14 @@ Console.
 
 - Kanidm is fully supported by Headscale.
 - Groups for the [allowed groups filter](#authorize-users-with-filters) need to be specified with their full SPN, for
   example: `headscale_users@sso.example.com`.
+- Kanidm sends the full SPN (`alice@sso.example.com`) as `preferred_username` by default. Headscale stores this value as
+  username which might be confusing as the username and email fields now contain values that look like an email address.
+  [Kanidm can be configured to send the short username as `preferred_username` attribute
+  instead](https://kanidm.github.io/kanidm/stable/integrations/oauth2.html#short-names):
+
+  ```console
+  kanidm system oauth2 prefer-short-username <client name>
+  ```
+
+  Once configured, the short username in Headscale will be `alice` and can be referred to as `alice@` in the policy.
 
 ### Keycloak

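The OIDC documentation above points at the `/api/v1/user` endpoint for looking up `providerId`. A hedged Go sketch of that lookup; the server URL and API key are placeholders, and the response shape shown (a top-level `users` array) is an assumption based on the docs rather than a verified schema:

```go
package main

import (
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Assumed inputs: adjust the server URL and use a real API key
	// created with `headscale apikeys create`.
	const server = "https://headscale.example.com"
	const apiKey = "<your-api-key>"

	req, err := http.NewRequest(http.MethodGet, server+"/api/v1/user", nil)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}

	// providerId is the value to copy into the policy
	// (append a trailing "@" if it contains none).
	var out struct {
		Users []struct {
			Name       string `json:"name"`
			ProviderID string `json:"providerId"`
		} `json:"users"`
	}
	if err := json.Unmarshal(body, &out); err != nil {
		log.Fatal(err)
	}

	for _, u := range out.Users {
		fmt.Printf("%s\t%s\n", u.Name, u.ProviderID)
	}
}
```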
flake.lock (generated)

@@ -20,11 +20,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1770141374,
-        "narHash": "sha256-yD4K/vRHPwXbJf5CK3JkptBA6nFWUKNX/jlFp2eKEQc=",
+        "lastModified": 1771177547,
+        "narHash": "sha256-trTtk3WTOHz7hSw89xIIvahkgoFJYQ0G43IlqprFoMA=",
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "41965737c1797c1d83cfb0b644ed0840a6220bd1",
+        "rev": "ac055f38c798b0d87695240c7b761b82fc7e5bc2",
         "type": "github"
       },
       "original": {


@@ -26,8 +26,8 @@
     overlays.default = _: prev:
       let
         pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system};
-        buildGo = pkgs.buildGo125Module;
-        vendorHash = "sha256-jkeB9XUTEGt58fPOMpE4/e3+JQoMQTgf0RlthVBmfG0=";
+        buildGo = pkgs.buildGo126Module;
+        vendorHash = "sha256-9BvphYDAxzwooyVokI3l+q1wRuRsWn/qM+NpWUgqJH0=";
       in
       {
         headscale = buildGo {
@@ -94,14 +94,46 @@
           subPackages = [ "." ];
         };
 
-        # Upstream does not override buildGoModule properly,
-        # importing a specific module, so comment out for now.
-        # golangci-lint = prev.golangci-lint.override {
-        #   buildGoModule = buildGo;
-        # };
-        # golangci-lint-langserver = prev.golangci-lint.override {
-        #   buildGoModule = buildGo;
-        # };
+        # Build golangci-lint with Go 1.26 (upstream uses hardcoded Go version)
+        golangci-lint = buildGo rec {
+          pname = "golangci-lint";
+          version = "2.9.0";
+
+          src = pkgs.fetchFromGitHub {
+            owner = "golangci";
+            repo = "golangci-lint";
+            rev = "v${version}";
+            hash = "sha256-8LEtm1v0slKwdLBtS41OilKJLXytSxcI9fUlZbj5Gfw=";
+          };
+
+          vendorHash = "sha256-w8JfF6n1ylrU652HEv/cYdsOdDZz9J2uRQDqxObyhkY=";
+
+          subPackages = [ "cmd/golangci-lint" ];
+
+          nativeBuildInputs = [ pkgs.installShellFiles ];
+
+          ldflags = [
+            "-s"
+            "-w"
+            "-X main.version=${version}"
+            "-X main.commit=v${version}"
+            "-X main.date=1970-01-01T00:00:00Z"
+          ];
+
+          postInstall = ''
+            for shell in bash zsh fish; do
+              HOME=$TMPDIR $out/bin/golangci-lint completion $shell > golangci-lint.$shell
+              installShellCompletion golangci-lint.$shell
+            done
+          '';
+
+          meta = {
+            description = "Fast linters runner for Go";
+            homepage = "https://golangci-lint.run/";
+            changelog = "https://github.com/golangci/golangci-lint/blob/v${version}/CHANGELOG.md";
+            mainProgram = "golangci-lint";
+          };
+        };
 
         # The package uses buildGo125Module, not the convention.
         # goreleaser = prev.goreleaser.override {
@@ -132,7 +164,7 @@
           overlays = [ self.overlays.default ];
           inherit system;
         };
 
-        buildDeps = with pkgs; [ git go_1_25 gnumake ];
+        buildDeps = with pkgs; [ git go_1_26 gnumake ];
 
         devDeps = with pkgs;
           buildDeps
           ++ [

go.mod

@@ -1,6 +1,6 @@
 module github.com/juanfont/headscale
 
-go 1.25.5
+go 1.26.0
 
 require (
 	github.com/arl/statsviz v0.8.0


@@ -115,13 +115,14 @@ var (
 func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 	var err error
 
 	if profilingEnabled {
 		runtime.SetBlockProfileRate(1)
 	}
 
 	noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
 	if err != nil {
-		return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
+		return nil, fmt.Errorf("reading or creating Noise protocol private key: %w", err)
 	}
 
 	s, err := state.NewState(cfg)
@@ -140,27 +141,30 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 	ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) {
 		node, ok := app.state.GetNodeByID(ni)
 		if !ok {
-			log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
-			log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
+			log.Error().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed")
+			log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("ephemeral node deletion failed because node not found in NodeStore")
+
 			return
 		}
 
 		policyChanged, err := app.state.DeleteNode(node)
 		if err != nil {
-			log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed")
+			log.Error().Err(err).EmbedObject(node).Msg("ephemeral node deletion failed")
+
 			return
 		}
 
 		app.Change(policyChanged)
-		log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached")
+		log.Debug().Caller().EmbedObject(node).Msg("ephemeral node deleted because garbage collection timeout reached")
 	})
 	app.ephemeralGC = ephemeralGC
 
 	var authProvider AuthProvider
+
 	authProvider = NewAuthProviderWeb(cfg.ServerURL)
+
 	if cfg.OIDC.Issuer != "" {
 		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 		defer cancel()
 
 		oidcProvider, err := NewAuthProviderOIDC(
 			ctx,
 			&app,
@@ -177,17 +181,18 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 			authProvider = oidcProvider
 		}
 	}
 
 	app.authProvider = authProvider
 
 	if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
 		// TODO(kradalby): revisit why this takes a list.
 		var magicDNSDomains []dnsname.FQDN
+
 		if cfg.PrefixV4 != nil {
 			magicDNSDomains = append(
 				magicDNSDomains,
 				util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
 		}
 
 		if cfg.PrefixV6 != nil {
 			magicDNSDomains = append(
 				magicDNSDomains,
@@ -198,6 +203,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		if app.cfg.TailcfgDNSConfig.Routes == nil {
 			app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver)
 		}
+
 		for _, d := range magicDNSDomains {
 			app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil
 		}
@@ -206,7 +212,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 	if cfg.DERP.ServerEnabled {
 		derpServerKey, err := readOrCreatePrivateKey(cfg.DERP.ServerPrivateKeyPath)
 		if err != nil {
-			return nil, fmt.Errorf("failed to read or create DERP server private key: %w", err)
+			return nil, fmt.Errorf("reading or creating DERP server private key: %w", err)
 		}
 
 		if derpServerKey.Equal(*noisePrivateKey) {
@@ -232,6 +238,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		if err != nil {
 			return nil, err
 		}
+
 		app.DERPServer = embeddedDERPServer
 	}
@@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 	lastExpiryCheck := time.Unix(0, 0)
 
 	derpTickerChan := make(<-chan time.Time)
+
 	if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
 		derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
 		defer derpTicker.Stop()
+
 		derpTickerChan = derpTicker.C
 	}
@@ -271,8 +280,10 @@
 			return
 
 		case <-expireTicker.C:
-			var expiredNodeChanges []change.Change
-			var changed bool
+			var (
+				expiredNodeChanges []change.Change
+				changed            bool
+			)
 
 			lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
@@ -286,12 +297,14 @@
 			}
 
 		case <-derpTickerChan:
-			log.Info().Msg("Fetching DERPMap updates")
-			derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
+			log.Info().Msg("fetching DERPMap updates")
+
+			derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { //nolint:contextcheck
 				derpMap, err := derp.GetDERPMap(h.cfg.DERP)
 				if err != nil {
 					return nil, err
 				}
+
 				if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
 					region, _ := h.DERPServer.GenerateRegion()
 					derpMap.Regions[region.RegionID] = &region
@@ -303,6 +316,7 @@
 				log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
 				continue
 			}
+
 			h.state.SetDERPMap(derpMap)
 			h.Change(change.DERPMap())
@@ -311,6 +325,7 @@
 			if !ok {
 				continue
 			}
+
 			h.cfg.TailcfgDNSConfig.ExtraRecords = records
 			h.Change(change.ExtraRecords())
@@ -339,7 +354,7 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
 	if !ok {
 		return ctx, status.Errorf(
 			codes.InvalidArgument,
-			"Retrieving metadata is failed",
+			"retrieving metadata",
 		)
 	}
@@ -347,7 +362,7 @@
 	if !ok {
 		return ctx, status.Errorf(
 			codes.Unauthenticated,
-			"Authorization token is not supplied",
+			"authorization token not supplied",
 		)
 	}
@@ -362,7 +377,7 @@
 	valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
 	if err != nil {
-		return ctx, status.Error(codes.Internal, "failed to validate token")
+		return ctx, status.Error(codes.Internal, "validating token")
 	}
 
 	if !valid {
@@ -390,7 +405,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
 		writeUnauthorized := func(statusCode int) {
 			writer.WriteHeader(statusCode)
-			if _, err := writer.Write([]byte("Unauthorized")); err != nil {
+
+			if _, err := writer.Write([]byte("Unauthorized")); err != nil { //nolint:noinlineerr
 				log.Error().Err(err).Msg("writing HTTP response failed")
 			}
 		}
@@ -401,6 +417,7 @@
 				Str("client_address", req.RemoteAddr).
 				Msg(`missing "Bearer " prefix in "Authorization" header`)
 			writeUnauthorized(http.StatusUnauthorized)
+
 			return
 		}
@@ -412,6 +429,7 @@
 				Str("client_address", req.RemoteAddr).
 				Msg("failed to validate token")
 			writeUnauthorized(http.StatusUnauthorized)
+
 			return
 		}
@@ -420,6 +438,7 @@
 				Str("client_address", req.RemoteAddr).
 				Msg("invalid token")
 			writeUnauthorized(http.StatusUnauthorized)
+
 			return
 		}
@@ -431,7 +450,7 @@
 // and will remove it if it is not.
 func (h *Headscale) ensureUnixSocketIsAbsent() error {
 	// File does not exist, all fine
-	if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
+	if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { //nolint:noinlineerr
 		return nil
 	}
@@ -455,6 +474,7 @@
 	if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
 		router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
 	}
+
 	router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
 	router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
 		Methods(http.MethodGet)
@@ -484,8 +504,11 @@
 }
 
 // Serve launches the HTTP and gRPC server service Headscale and the API.
+//
+//nolint:gocyclo // complex server startup function
 func (h *Headscale) Serve() error {
 	var err error
 
 	capver.CanOldCodeBeCleanedUp()
 
 	if profilingEnabled {
@@ -506,12 +529,13 @@
 	}
 
 	versionInfo := types.GetVersionInfo()
-	log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("Starting Headscale")
+	log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("starting headscale")
 	log.Info().
 		Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
 		Msg("Clients with a lower minimum version will be rejected")
 
 	h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
 	h.mapBatcher.Start()
 	defer h.mapBatcher.Close()
@@ -526,7 +550,7 @@
 	derpMap, err := derp.GetDERPMap(h.cfg.DERP)
 	if err != nil {
-		return fmt.Errorf("failed to get DERPMap: %w", err)
+		return fmt.Errorf("getting DERPMap: %w", err)
 	}
 
 	if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
@@ -545,6 +569,7 @@
 	// around between restarts, they will reconnect and the GC will
 	// be cancelled.
 	go h.ephemeralGC.Start()
+
 	ephmNodes := h.state.ListEphemeralNodes()
 	for _, node := range ephmNodes.All() {
 		h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
@@ -555,7 +580,9 @@
 		if err != nil {
 			return fmt.Errorf("setting up extrarecord manager: %w", err)
 		}
+
 		h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
+
 		go h.extraRecordMan.Run()
 		defer h.extraRecordMan.Close()
 	}
@@ -564,6 +591,7 @@
 	// records updates
 	scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
 	defer scheduleCancel()
+
 	go h.scheduledTasks(scheduleCtx)
 
 	if zl.GlobalLevel() == zl.TraceLevel {
@@ -576,6 +604,7 @@
 	errorGroup := new(errgroup.Group)
+
 	ctx := context.Background()
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
@@ -586,29 +615,30 @@
 	err = h.ensureUnixSocketIsAbsent()
 	if err != nil {
-		return fmt.Errorf("unable to remove old socket file: %w", err)
+		return fmt.Errorf("removing old socket file: %w", err)
 	}
 
 	socketDir := filepath.Dir(h.cfg.UnixSocket)
 
 	err = util.EnsureDir(socketDir)
 	if err != nil {
 		return fmt.Errorf("setting up unix socket: %w", err)
 	}
 
-	socketListener, err := net.Listen("unix", h.cfg.UnixSocket)
+	socketListener, err := new(net.ListenConfig).Listen(context.Background(), "unix", h.cfg.UnixSocket)
 	if err != nil {
-		return fmt.Errorf("failed to set up gRPC socket: %w", err)
+		return fmt.Errorf("setting up gRPC socket: %w", err)
 	}
 
 	// Change socket permissions
-	if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil {
-		return fmt.Errorf("failed change permission of gRPC socket: %w", err)
+	if err := os.Chmod(h.cfg.UnixSocket, h.cfg.UnixSocketPermission); err != nil { //nolint:noinlineerr
+		return fmt.Errorf("changing gRPC socket permission: %w", err)
 	}
 
 	grpcGatewayMux := grpcRuntime.NewServeMux()
 
 	// Make the grpc-gateway connect to grpc over socket
-	grpcGatewayConn, err := grpc.Dial(
+	grpcGatewayConn, err := grpc.Dial( //nolint:staticcheck // SA1019: deprecated but supported in 1.x
 		h.cfg.UnixSocket,
 		[]grpc.DialOption{
 			grpc.WithTransportCredentials(insecure.NewCredentials()),
@@ -659,10 +689,13 @@
 	// https://github.com/soheilhy/cmux/issues/68
 	// https://github.com/soheilhy/cmux/issues/91
-	var grpcServer *grpc.Server
-	var grpcListener net.Listener
+	var (
+		grpcServer   *grpc.Server
+		grpcListener net.Listener
+	)
 
 	if tlsConfig != nil || h.cfg.GRPCAllowInsecure {
-		log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
+		log.Info().Msgf("enabling remote gRPC at %s", h.cfg.GRPCAddr)
 
 		grpcOptions := []grpc.ServerOption{
 			grpc.ChainUnaryInterceptor(
@@ -685,9 +718,9 @@
 		v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
 		reflection.Register(grpcServer)
 
-		grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
+		grpcListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.GRPCAddr)
 		if err != nil {
-			return fmt.Errorf("failed to bind to TCP address: %w", err)
+			return fmt.Errorf("binding to TCP address: %w", err)
 		}
 
 		errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
@@ -715,14 +748,16 @@
 	}
 
 	var httpListener net.Listener
+
 	if tlsConfig != nil {
 		httpServer.TLSConfig = tlsConfig
+
 		httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
 	} else {
-		httpListener, err = net.Listen("tcp", h.cfg.Addr)
+		httpListener, err = new(net.ListenConfig).Listen(context.Background(), "tcp", h.cfg.Addr)
 	}
 
 	if err != nil {
-		return fmt.Errorf("failed to bind to TCP address: %w", err)
+		return fmt.Errorf("binding to TCP address: %w", err)
 	}
 
 	errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
@@ -738,7 +773,7 @@
 	if h.cfg.MetricsAddr != "" {
 		debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr)
 		if err != nil {
-			return fmt.Errorf("failed to bind to TCP address: %w", err)
+			return fmt.Errorf("binding to TCP address: %w", err)
 		}
 
 		debugHTTPServer = h.debugHTTPServer()
@@ -751,19 +786,24 @@
 		log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
 	}
 
 	var tailsqlContext context.Context
+
 	if tailsqlEnabled {
 		if h.cfg.Database.Type != types.DatabaseSqlite {
+			//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
 			log.Fatal().
 				Str("type", h.cfg.Database.Type).
 				Msgf("tailsql only support %q", types.DatabaseSqlite)
 		}
+
 		if tailsqlTSKey == "" {
+			//nolint:gocritic // exitAfterDefer: Fatal exits during initialization before servers start
 			log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set")
 		}
+
 		tailsqlContext = context.Background()
-		go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path)
+
+		go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) //nolint:errcheck
 	}
 
 	// Handle common process-killing signals so we can gracefully shut down:
@@ -774,6 +814,7 @@
 		syscall.SIGTERM,
 		syscall.SIGQUIT,
 		syscall.SIGHUP)
+
 	sigFunc := func(c chan os.Signal) {
 		// Wait for a SIGINT or SIGKILL:
 		for {
@@ -798,6 +839,7 @@
 			default:
 				info := func(msg string) { log.Info().Msg(msg) }
+
 				log.Info().
 					Str("signal", sig.String()).
 					Msg("Received signal to stop, shutting down gracefully")
@@ -854,6 +896,7 @@
 				if debugHTTPListener != nil {
 					debugHTTPListener.Close()
 				}
+
 				httpListener.Close()
 				grpcGatewayConn.Close()
@@ -863,6 +906,7 @@
 				// Close state connections
 				info("closing state and database")
+
 				err = h.state.Close()
 				if err != nil {
 					log.Error().Err(err).Msg("failed to close state")
@@ -875,6 +919,7 @@
 			}
 		}
 	}
+
 	errorGroup.Go(func() error {
 		sigFunc(sigc)
@@ -886,6 +931,7 @@
 func (h *Headscale) getTLSSettings() (*tls.Config, error) {
 	var err error
 
 	if h.cfg.TLS.LetsEncrypt.Hostname != "" {
 		if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
 			log.Warn().
@@ -918,7 +964,6 @@
 		// Configuration via autocert with HTTP-01. This requires listening on
 		// port 80 for the certificate validation in addition to the headscale
 		// service, which can be configured to run on any other port.
-
 		server := &http.Server{
 			Addr:    h.cfg.TLS.LetsEncrypt.Listen,
 			Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
@@ -940,13 +985,13 @@
 		}
 	} else if h.cfg.TLS.CertPath == "" {
 		if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
-			log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
+			log.Warn().Msg("listening without TLS but ServerURL does not start with http://")
 		}
 
 		return nil, err
 	} else {
 		if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
-			log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
+			log.Warn().Msg("listening with TLS but ServerURL does not start with https://")
 		}
 
 		tlsConfig := &tls.Config{
@@ -963,6 +1008,7 @@
 func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
 	dir := filepath.Dir(path)
 
 	err := util.EnsureDir(dir)
 	if err != nil {
 		return nil, fmt.Errorf("ensuring private key directory: %w", err)
@@ -970,21 +1016,22 @@
 	privateKey, err := os.ReadFile(path)
 	if errors.Is(err, os.ErrNotExist) {
-		log.Info().Str("path", path).Msg("No private key file at path, creating...")
+		log.Info().Str("path", path).Msg("no private key file at path, creating...")
 
 		machineKey := key.NewMachine()
 
 		machineKeyStr, err := machineKey.MarshalText()
 		if err != nil {
 			return nil, fmt.Errorf(
-				"failed to convert private key to string for saving: %w",
+				"converting private key to string for saving: %w",
 				err,
 			)
 		}
 
 		err = os.WriteFile(path, machineKeyStr, privateKeyFileMode)
 		if err != nil {
 			return nil, fmt.Errorf(
-				"failed to save private key to disk at path %q: %w",
+				"saving private key to disk at path %q: %w",
 				path,
 				err,
 			)
@@ -992,14 +1039,14 @@
 		return &machineKey, nil
 	} else if err != nil {
-		return nil, fmt.Errorf("failed to read private key file: %w", err)
+		return nil, fmt.Errorf("reading private key file: %w", err)
 	}
 
 	trimmedPrivateKey := strings.TrimSpace(string(privateKey))
 
 	var machineKey key.MachinePrivate
-	if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil {
-		return nil, fmt.Errorf("failed to parse private key: %w", err)
+	if err = machineKey.UnmarshalText([]byte(trimmedPrivateKey)); err != nil { //nolint:noinlineerr
+		return nil, fmt.Errorf("parsing private key: %w", err)
 	}
 
 	return &machineKey, nil
@@ -1023,7 +1070,7 @@ type acmeLogger struct {
 func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
 	resp, err := l.rt.RoundTrip(req)
 	if err != nil {
-		log.Error().Err(err).Str("url", req.URL.String()).Msg("ACME request failed")
+		log.Error().Err(err).Str("url", req.URL.String()).Msg("acme request failed")
 
 		return nil, err
 	}
@@ -1031,7 +1078,7 @@
 		defer resp.Body.Close()
 		body, _ := io.ReadAll(resp.Body)
 
-		log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("ACME request returned error")
+		log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("acme request returned error")
 	}
 
 	return resp, nil


@@ -16,12 +16,11 @@ import (
     "gorm.io/gorm"
     "tailscale.com/tailcfg"
     "tailscale.com/types/key"
-    "tailscale.com/types/ptr"
 )
 type AuthProvider interface {
-    RegisterHandler(http.ResponseWriter, *http.Request)
-    AuthURL(types.RegistrationID) string
+    RegisterHandler(w http.ResponseWriter, r *http.Request)
+    AuthURL(regID types.RegistrationID) string
 }
 func (h *Headscale) handleRegister(
@@ -42,8 +41,7 @@ func (h *Headscale) handleRegister(
     // This is a logout attempt (expiry in the past)
     if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok {
         log.Debug().
-            Uint64("node.id", node.ID().Uint64()).
-            Str("node.name", node.Hostname()).
+            EmbedObject(node).
             Bool("is_ephemeral", node.IsEphemeral()).
             Bool("has_authkey", node.AuthKey().Valid()).
             Msg("Found existing node for logout, calling handleLogout")
@@ -52,6 +50,7 @@ func (h *Headscale) handleRegister(
         if err != nil {
             return nil, fmt.Errorf("handling logout: %w", err)
         }
         if resp != nil {
             return resp, nil
         }
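A note on the EmbedObject rewrite above: zerolog can inline all of a value's log fields through a single MarshalZerologObject method, so call sites stop hand-rolling the node.id/node.name pairs. A minimal sketch of the pattern, assuming the node view implements zerolog.LogObjectMarshaler (the nodeView type below is illustrative, not the actual headscale type):

    // nodeView stands in for the real node view; implementing
    // zerolog.LogObjectMarshaler is what makes EmbedObject work.
    type nodeView struct {
        id       uint64
        hostname string
    }

    func (n nodeView) MarshalZerologObject(e *zerolog.Event) {
        e.Uint64("node.id", n.id).Str("node.name", n.hostname)
    }

    // At the call site, one EmbedObject replaces the repeated field chains:
    log.Debug().EmbedObject(nodeView{1, "node-a"}).Msg("found existing node")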
@@ -113,8 +112,7 @@ func (h *Headscale) handleRegister(
     resp, err := h.handleRegisterWithAuthKey(req, machineKey)
     if err != nil {
         // Preserve HTTPError types so they can be handled properly by the HTTP layer
-        var httpErr HTTPError
-        if errors.As(err, &httpErr) {
+        if httpErr, ok := errors.AsType[HTTPError](err); ok {
             return nil, httpErr
         }
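The errors.AsType rewrite here is the Go 1.26 generic form of errors.As: it returns the matched error and a bool instead of filling a pre-declared variable. The two shapes, side by side:

    // Before: declare a zero value, then errors.As.
    var httpErr HTTPError
    if errors.As(err, &httpErr) {
        return nil, httpErr
    }

    // After (Go 1.26): one expression, scoped to the if.
    if httpErr, ok := errors.AsType[HTTPError](err); ok {
        return nil, httpErr
    }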
@@ -133,7 +131,7 @@ func (h *Headscale) handleRegister(
 }
 // handleLogout checks if the [tailcfg.RegisterRequest] is a
-// logout attempt from a node. If the node is not attempting to
+// logout attempt from a node. If the node is not attempting to.
 func (h *Headscale) handleLogout(
     node types.NodeView,
     req tailcfg.RegisterRequest,
@@ -155,11 +153,12 @@ func (h *Headscale) handleLogout(
     // force the client to re-authenticate.
     // TODO(kradalby): I wonder if this is a path we ever hit?
     if node.IsExpired() {
-        log.Trace().Str("node.name", node.Hostname()).
-            Uint64("node.id", node.ID().Uint64()).
+        log.Trace().
+            EmbedObject(node).
             Interface("reg.req", req).
             Bool("unexpected", true).
             Msg("Node key expired, forcing re-authentication")
         return &tailcfg.RegisterResponse{
             NodeKeyExpired:    true,
             MachineAuthorized: false,
@@ -182,8 +181,7 @@ func (h *Headscale) handleLogout(
     // Zero expiry is handled in handleRegister() before calling this function.
     if req.Expiry.Before(time.Now()) {
         log.Debug().
-            Uint64("node.id", node.ID().Uint64()).
-            Str("node.name", node.Hostname()).
+            EmbedObject(node).
             Bool("is_ephemeral", node.IsEphemeral()).
             Bool("has_authkey", node.AuthKey().Valid()).
             Time("req.expiry", req.Expiry).
@@ -191,8 +189,7 @@ func (h *Headscale) handleLogout(
     if node.IsEphemeral() {
         log.Info().
-            Uint64("node.id", node.ID().Uint64()).
-            Str("node.name", node.Hostname()).
+            EmbedObject(node).
             Msg("Deleting ephemeral node during logout")
         c, err := h.state.DeleteNode(node)
@@ -209,8 +206,7 @@ func (h *Headscale) handleLogout(
     }
     log.Debug().
-        Uint64("node.id", node.ID().Uint64()).
-        Str("node.name", node.Hostname()).
+        EmbedObject(node).
         Msg("Node is not ephemeral, setting expiry instead of deleting")
 }
@@ -279,6 +275,7 @@ func (h *Headscale) waitForFollowup(
         // registration is expired in the cache, instruct the client to try a new registration
         return h.reqToNewRegisterResponse(req, machineKey)
     }
     return nodeToRegisterResponse(node.View()), nil
     }
 }
@@ -316,7 +313,7 @@ func (h *Headscale) reqToNewRegisterResponse(
     MachineKey: machineKey,
     NodeKey:    req.NodeKey,
     Hostinfo:   hostinfo,
-    LastSeen:   ptr.To(time.Now()),
+    LastSeen:   new(time.Now()),
 },
 )
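LastSeen: new(time.Now()) uses Go 1.26's extended new, which accepts an expression and returns a pointer to a freshly allocated copy of its value; it replaces the tailscale.com/types/ptr helper dropped from the imports above. Roughly:

    ts := new(time.Now()) // *time.Time, equivalent to ptr.To(time.Now())

    // Pre-1.26 spelling without a helper:
    now := time.Now()
    tsOld := &now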
@@ -324,7 +321,7 @@ func (h *Headscale) reqToNewRegisterResponse(
     nodeToRegister.Node.Expiry = &req.Expiry
 }
-log.Info().Msgf("New followup node registration using key: %s", newRegID)
+log.Info().Msgf("new followup node registration using key: %s", newRegID)
 h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister)
 return &tailcfg.RegisterResponse{
@@ -344,8 +341,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
     if errors.Is(err, gorm.ErrRecordNotFound) {
         return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
     }
-    var perr types.PAKError
-    if errors.As(err, &perr) {
+    if perr, ok := errors.AsType[types.PAKError](err); ok {
         return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
     }
@@ -355,7 +352,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
     // If node is not valid, it means an ephemeral node was deleted during logout
     if !node.Valid() {
         h.Change(changed)
-        return nil, nil
+        return nil, nil //nolint:nilnil // intentional: no node to return when ephemeral deleted
     }
     // This is a bit of a back and forth, but we have a bit of a chicken and egg
@@ -397,8 +394,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
     Caller().
     Interface("reg.resp", resp).
     Interface("reg.req", req).
-    Str("node.name", node.Hostname()).
-    Uint64("node.id", node.ID().Uint64()).
+    EmbedObject(node).
     Msg("RegisterResponse")
 return resp, nil
@@ -435,6 +431,7 @@ func (h *Headscale) handleRegisterInteractive(
     Str("generated.hostname", hostname).
     Msg("Received registration request with empty hostname, generated default")
 }
 hostinfo.Hostname = hostname
 nodeToRegister := types.NewRegisterNode(
@@ -443,7 +440,7 @@ func (h *Headscale) handleRegisterInteractive(
     MachineKey: machineKey,
     NodeKey:    req.NodeKey,
     Hostinfo:   hostinfo,
-    LastSeen:   ptr.To(time.Now()),
+    LastSeen:   new(time.Now()),
 },
 )
@@ -456,7 +453,7 @@ func (h *Headscale) handleRegisterInteractive(
     nodeToRegister,
 )
-log.Info().Msgf("Starting node registration using key: %s", registrationId)
+log.Info().Msgf("starting node registration using key: %s", registrationId)
 return &tailcfg.RegisterResponse{
     AuthURL: h.authProvider.AuthURL(registrationId),

(File diff suppressed because it is too large)


@@ -40,6 +40,7 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
     "v1.88": 125,
     "v1.90": 130,
     "v1.92": 131,
+    "v1.94": 131,
 }
 var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
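For context, this map pins each Tailscale minor release to the protocol capability version it ships; v1.94 reuses 131, so two releases can map to the same CapabilityVersion. A hypothetical lookup helper (the function name is illustrative; only the tailscaleToCapVer map comes from the diff):

    // capVerForRelease returns the capability version for a release
    // string such as "v1.94"; ok is false for unknown releases.
    func capVerForRelease(release string) (tailcfg.CapabilityVersion, bool) {
        v, ok := tailscaleToCapVer[release]
        return v, ok
    }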


@@ -9,10 +9,9 @@ var tailscaleLatestMajorMinorTests = []struct {
     stripV   bool
     expected []string
 }{
-    {3, false, []string{"v1.88", "v1.90", "v1.92"}},
-    {2, true, []string{"1.90", "1.92"}},
+    {3, false, []string{"v1.90", "v1.92", "v1.94"}},
+    {2, true, []string{"1.92", "1.94"}},
     {10, true, []string{
-        "1.74",
         "1.76",
         "1.78",
         "1.80",
@@ -22,6 +21,7 @@ var tailscaleLatestMajorMinorTests = []struct {
         "1.88",
         "1.90",
         "1.92",
+        "1.94",
     }},
     {0, false, nil},
 }


@@ -77,8 +77,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
     Expiration: expiration,
 }
-if err := hsdb.DB.Save(&key).Error; err != nil {
-    return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
+if err := hsdb.DB.Save(&key).Error; err != nil { //nolint:noinlineerr
+    return "", nil, fmt.Errorf("saving API key to database: %w", err)
 }
 return keyStr, &key, nil
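The //nolint:noinlineerr annotations in this changeset mark the few places that keep the inline `if err := ...; err != nil` form; everywhere else the statement is split so the assignment and the check sit on separate lines. The two shapes, for reference:

    // Inline form (now needs //nolint:noinlineerr where it remains):
    if err := hsdb.DB.Save(&key).Error; err != nil {
        return err
    }

    // Out-of-line form preferred throughout this diff:
    err := hsdb.DB.Save(&key).Error
    if err != nil {
        return err
    }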
@@ -87,7 +87,9 @@ func (hsdb *HSDatabase) CreateAPIKey(
 // ListAPIKeys returns the list of ApiKeys for a user.
 func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
     keys := []types.APIKey{}
-    if err := hsdb.DB.Find(&keys).Error; err != nil {
+    err := hsdb.DB.Find(&keys).Error
+    if err != nil {
         return nil, err
     }
@@ -126,7 +128,8 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
 // ExpireAPIKey marks a ApiKey as expired.
 func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
-    if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
+    err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error
+    if err != nil {
         return err
     }


@@ -24,7 +24,6 @@ import (
     "gorm.io/gorm"
     "gorm.io/gorm/logger"
     "gorm.io/gorm/schema"
-    "tailscale.com/net/tsaddr"
     "zgo.at/zcache/v2"
 )
@@ -53,6 +52,8 @@ type HSDatabase struct {
 // NewHeadscaleDatabase creates a new database connection and runs migrations.
 // It accepts the full configuration to allow migrations access to policy settings.
+//
+//nolint:gocyclo // complex database initialization with many migrations
 func NewHeadscaleDatabase(
     cfg *types.Config,
     regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
@@ -62,6 +63,11 @@ func NewHeadscaleDatabase(
         return nil, err
     }
+    err = checkVersionUpgradePath(dbConn)
+    if err != nil {
+        return nil, fmt.Errorf("version check: %w", err)
+    }
     migrations := gormigrate.New(
         dbConn,
         gormigrate.DefaultOptions,
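checkVersionUpgradePath runs before gormigrate is even constructed, so a binary whose version is not a valid next step refuses to touch the schema at all. One plausible shape for the minor-version comparison at its core (a sketch that assumes versions are parsed into minor numbers elsewhere; the real function is part of this changeset but not shown in this hunk):

    // checkMinorStep allows staying on the same minor or moving up exactly
    // one minor; everything else is rejected before migrations run.
    func checkMinorStep(storedMinor, currentMinor int) error {
        switch {
        case currentMinor == storedMinor:
            return nil // patch-level change, either direction
        case currentMinor == storedMinor+1:
            return nil // single minor upgrade
        case currentMinor < storedMinor:
            return fmt.Errorf("downgrade from minor %d to %d is not supported", storedMinor, currentMinor)
        default:
            return fmt.Errorf("upgrade from minor %d to %d skips releases; upgrade one minor at a time", storedMinor, currentMinor)
        }
    }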
@@ -76,7 +82,7 @@ func NewHeadscaleDatabase(
     ID: "202501221827",
     Migrate: func(tx *gorm.DB) error {
         // Remove any invalid routes associated with a node that does not exist.
-        if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
+        if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { //nolint:staticcheck // SA1019: Route kept for migrations
             err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
             if err != nil {
                 return err
@@ -84,14 +90,14 @@ func NewHeadscaleDatabase(
         }
         // Remove any invalid routes without a node_id.
-        if tx.Migrator().HasTable(&types.Route{}) {
+        if tx.Migrator().HasTable(&types.Route{}) { //nolint:staticcheck // SA1019: Route kept for migrations
             err := tx.Exec("delete from routes where node_id is null").Error
             if err != nil {
                 return err
             }
         }
-        err := tx.AutoMigrate(&types.Route{})
+        err := tx.AutoMigrate(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
         if err != nil {
             return fmt.Errorf("automigrating types.Route: %w", err)
         }
@@ -109,6 +115,7 @@ func NewHeadscaleDatabase(
         if err != nil {
             return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
         }
         err = tx.AutoMigrate(&types.Node{})
         if err != nil {
             return fmt.Errorf("automigrating types.Node: %w", err)
@@ -155,7 +162,8 @@ AND auth_key_id NOT IN (
         nodeRoutes := map[uint64][]netip.Prefix{}
-        var routes []types.Route
+        var routes []types.Route //nolint:staticcheck // SA1019: Route kept for migrations
         err = tx.Find(&routes).Error
         if err != nil {
             return fmt.Errorf("fetching routes: %w", err)
@@ -168,10 +176,10 @@ AND auth_key_id NOT IN (
         }
         for nodeID, routes := range nodeRoutes {
-            tsaddr.SortPrefixes(routes)
+            slices.SortFunc(routes, netip.Prefix.Compare)
             routes = slices.Compact(routes)
-            data, err := json.Marshal(routes)
+            data, _ := json.Marshal(routes)
             err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
             if err != nil {
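Swapping tsaddr.SortPrefixes for slices.SortFunc(routes, netip.Prefix.Compare) removes the tailscale import from this migration: netip.Prefix.Compare already has the cmp-style signature SortFunc expects, and sorting first is what makes the following slices.Compact call deduplicate, since Compact only drops adjacent equal elements. Self-contained illustration:

    package main

    import (
        "fmt"
        "net/netip"
        "slices"
    )

    func main() {
        routes := []netip.Prefix{
            netip.MustParsePrefix("10.0.0.0/8"),
            netip.MustParsePrefix("0.0.0.0/0"),
            netip.MustParsePrefix("10.0.0.0/8"),
        }
        // Sort with the stdlib comparator, then drop adjacent duplicates.
        slices.SortFunc(routes, netip.Prefix.Compare)
        routes = slices.Compact(routes)
        fmt.Println(routes) // [0.0.0.0/0 10.0.0.0/8]
    }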
@@ -180,7 +188,7 @@ AND auth_key_id NOT IN (
         }
         // Drop the old table.
-        _ = tx.Migrator().DropTable(&types.Route{})
+        _ = tx.Migrator().DropTable(&types.Route{}) //nolint:staticcheck // SA1019: Route kept for migrations
         return nil
     },
@@ -245,21 +253,24 @@ AND auth_key_id NOT IN (
     Migrate: func(tx *gorm.DB) error {
         // Only run on SQLite
         if cfg.Database.Type != types.DatabaseSqlite {
-            log.Info().Msg("Skipping schema migration on non-SQLite database")
+            log.Info().Msg("skipping schema migration on non-SQLite database")
             return nil
         }
-        log.Info().Msg("Starting schema recreation with table renaming")
+        log.Info().Msg("starting schema recreation with table renaming")
         // Rename existing tables to _old versions
         tablesToRename := []string{"users", "pre_auth_keys", "api_keys", "nodes", "policies"}
         // Check if routes table exists and drop it (should have been migrated already)
         var routesExists bool
         err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
         if err == nil && routesExists {
-            log.Info().Msg("Dropping leftover routes table")
-            if err := tx.Exec("DROP TABLE routes").Error; err != nil {
+            log.Info().Msg("dropping leftover routes table")
+            err := tx.Exec("DROP TABLE routes").Error
+            if err != nil {
                 return fmt.Errorf("dropping routes table: %w", err)
             }
         }
@@ -281,6 +292,7 @@ AND auth_key_id NOT IN (
         for _, table := range tablesToRename {
             // Check if table exists before renaming
             var exists bool
             err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
             if err != nil {
                 return fmt.Errorf("checking if table %s exists: %w", table, err)
@@ -291,7 +303,8 @@ AND auth_key_id NOT IN (
             _ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
             // Rename current table to _old
-            if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
+            err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
+            if err != nil {
                 return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
             }
         }
@@ -365,7 +378,8 @@ AND auth_key_id NOT IN (
         }
         for _, createSQL := range tableCreationSQL {
-            if err := tx.Exec(createSQL).Error; err != nil {
+            err := tx.Exec(createSQL).Error
+            if err != nil {
                 return fmt.Errorf("creating new table: %w", err)
             }
         }
@@ -394,7 +408,8 @@ AND auth_key_id NOT IN (
         }
         for _, copySQL := range dataCopySQL {
-            if err := tx.Exec(copySQL).Error; err != nil {
+            err := tx.Exec(copySQL).Error
+            if err != nil {
                 return fmt.Errorf("copying data: %w", err)
             }
         }
@@ -417,19 +432,21 @@ AND auth_key_id NOT IN (
         }
         for _, indexSQL := range indexes {
-            if err := tx.Exec(indexSQL).Error; err != nil {
+            err := tx.Exec(indexSQL).Error
+            if err != nil {
                 return fmt.Errorf("creating index: %w", err)
             }
         }
         // Drop old tables only after everything succeeds
         for _, table := range tablesToRename {
-            if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
-                log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded")
+            err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
+            if err != nil {
+                log.Warn().Str("table", table+"_old").Err(err).Msg("failed to drop old table, but migration succeeded")
             }
         }
-        log.Info().Msg("Schema recreation completed successfully")
+        log.Info().Msg("schema recreation completed successfully")
         return nil
     },
@@ -595,12 +612,12 @@ AND auth_key_id NOT IN (
     // 1. Load policy from file or database based on configuration
     policyData, err := PolicyBytes(tx, cfg)
     if err != nil {
-        log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
+        log.Warn().Err(err).Msg("failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
        return nil
     }
     if len(policyData) == 0 {
-        log.Info().Msg("No policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
+        log.Info().Msg("no policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
         return nil
     }
@@ -618,7 +635,7 @@ AND auth_key_id NOT IN (
     // 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.)
     polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice())
     if err != nil {
-        log.Warn().Err(err).Msg("Failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
+        log.Warn().Err(err).Msg("failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
         return nil
     }
@@ -652,8 +669,7 @@ AND auth_key_id NOT IN (
     if len(validatedTags) == 0 {
         if len(rejectedTags) > 0 {
             log.Debug().
-                Uint64("node.id", uint64(node.ID)).
-                Str("node.name", node.Hostname).
+                EmbedObject(node).
                 Strs("rejected_tags", rejectedTags).
                 Msg("RequestTags rejected during migration (not authorized)")
         }
@@ -676,8 +692,7 @@ AND auth_key_id NOT IN (
     }
     log.Info().
-        Uint64("node.id", uint64(node.ID)).
-        Str("node.name", node.Hostname).
+        EmbedObject(node).
         Strs("validated_tags", validatedTags).
         Strs("rejected_tags", rejectedTags).
         Strs("existing_tags", existingTags).
@@ -750,6 +765,20 @@ AND auth_key_id NOT IN (
     return nil, fmt.Errorf("migration failed: %w", err)
 }
+// Store the current version in the database after migrations succeed.
+// Dev builds skip this to preserve the stored version for the next
+// real versioned binary.
+currentVersion := types.GetVersionInfo().Version
+if !isDev(currentVersion) {
+    err = setDatabaseVersion(dbConn, currentVersion)
+    if err != nil {
+        return nil, fmt.Errorf(
+            "storing database version: %w",
+            err,
+        )
+    }
+}
 // Validate that the schema ends up in the expected state.
 // This is currently only done on sqlite as squibble does not
 // support Postgres and we use our sqlite schema as our source of
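The added block persists the running version only once every migration has succeeded, and dev builds deliberately leave the stored value alone so the next tagged release still sees the last real version. A hedged sketch of what the two helpers could look like (the dev heuristic and the database_versions SQL here are assumptions; the real implementations live elsewhere in this changeset):

    // isDev reports whether the build is an untagged development version;
    // the real heuristic may differ (e.g. a "dev" marker or commit suffix).
    func isDev(version string) bool {
        return version == "dev" || strings.Contains(version, "-")
    }

    // setDatabaseVersion upserts the single row recording which release
    // last ran migrations against this database.
    func setDatabaseVersion(db *gorm.DB, version string) error {
        return db.Exec(
            "INSERT INTO database_versions (id, version) VALUES (1, ?) "+
                "ON CONFLICT (id) DO UPDATE SET version = excluded.version",
            version,
        ).Error
    }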
@@ -762,6 +791,7 @@ AND auth_key_id NOT IN (
     // or else it blocks...
     sqlConn.SetMaxIdleConns(maxIdleConns)
     sqlConn.SetMaxOpenConns(maxOpenConns)
     defer sqlConn.SetMaxIdleConns(1)
     defer sqlConn.SetMaxOpenConns(1)
@@ -779,7 +809,7 @@ AND auth_key_id NOT IN (
         },
     }
-    if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
+    if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { //nolint:noinlineerr
         return nil, fmt.Errorf("validating schema: %w", err)
     }
 }
@@ -805,6 +835,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
     switch cfg.Type {
     case types.DatabaseSqlite:
         dir := filepath.Dir(cfg.Sqlite.Path)
         err := util.EnsureDir(dir)
         if err != nil {
             return nil, fmt.Errorf("creating directory for sqlite: %w", err)
@@ -858,7 +889,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
         Str("path", dbString).
         Msg("Opening database")
-    if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil {
+    if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { //nolint:noinlineerr
         if !sslEnabled {
             dbString += " sslmode=disable"
         }
@@ -913,7 +944,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
     // Get the current foreign key status
     var fkOriginallyEnabled int
-    if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
+    if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { //nolint:noinlineerr
         return fmt.Errorf("checking foreign key status: %w", err)
     }
@@ -937,33 +968,36 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
     }
     for _, migrationID := range migrationIDs {
-        log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration")
+        log.Trace().Caller().Str("migration_id", migrationID).Msg("running migration")
         needsFKDisabled := migrationsRequiringFKDisabled[migrationID]
         if needsFKDisabled {
             // Disable foreign keys for this migration
-            if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
+            err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
+            if err != nil {
                 return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
             }
         } else {
             // Ensure foreign keys are enabled for this migration
-            if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
+            err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
+            if err != nil {
                 return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
             }
         }
         // Run up to this specific migration (will only run the next pending migration)
-        if err := migrations.MigrateTo(migrationID); err != nil {
+        err := migrations.MigrateTo(migrationID)
+        if err != nil {
             return fmt.Errorf("running migration %s: %w", migrationID, err)
         }
     }
-    if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
+    if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { //nolint:noinlineerr
         return fmt.Errorf("restoring foreign keys: %w", err)
     }
     // Run the rest of the migrations
-    if err := migrations.Migrate(); err != nil {
+    if err := migrations.Migrate(); err != nil { //nolint:noinlineerr
         return err
     }
@@ -981,16 +1015,22 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
     if err != nil {
         return err
     }
-    defer rows.Close()
     for rows.Next() {
         var violation constraintViolation
-        if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil {
+        err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex)
+        if err != nil {
            return err
        }
        violatedConstraints = append(violatedConstraints, violation)
    }
+    _ = rows.Close()
+    if err := rows.Err(); err != nil { //nolint:noinlineerr
+        return err
+    }
    if len(violatedConstraints) > 0 {
        for _, violation := range violatedConstraints {
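The iteration fix above is worth spelling out: database/sql can end a rows loop early because of a driver error, and rows.Next() returning false looks identical to a clean end of results, so replacing defer rows.Close() with an explicit close followed by rows.Err() is what makes the failure visible before the collected violations are trusted. The canonical shape:

    rows, err := db.Query("PRAGMA foreign_key_check") // illustrative query
    if err != nil {
        return err
    }

    for rows.Next() {
        // ... Scan each row and collect results ...
    }

    _ = rows.Close()

    // Distinguishes "no more rows" from "iteration aborted by an error".
    if err := rows.Err(); err != nil {
        return err
    }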
@@ -1005,7 +1045,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
         }
     } else {
         // PostgreSQL can run all migrations in one block - no foreign key issues
-        if err := migrations.Migrate(); err != nil {
+        err := migrations.Migrate()
+        if err != nil {
             return err
         }
     }
@@ -1016,6 +1057,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
 func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
     ctx, cancel := context.WithTimeout(ctx, time.Second)
     defer cancel()
     sqlDB, err := hsdb.DB.DB()
     if err != nil {
         return err
@@ -1031,7 +1073,7 @@ func (hsdb *HSDatabase) Close() error {
     }
     if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
-        db.Exec("VACUUM")
+        db.Exec("VACUUM") //nolint:errcheck,noctx
     }
     return db.Close()
@@ -1040,12 +1082,14 @@ func (hsdb *HSDatabase) Close() error {
 func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error {
     rx := hsdb.DB.Begin()
     defer rx.Rollback()
     return fn(rx)
 }
 func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
     rx := db.Begin()
     defer rx.Rollback()
     ret, err := fn(rx)
     if err != nil {
         var no T
@@ -1058,7 +1102,9 @@ func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) {
 func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
     tx := hsdb.DB.Begin()
     defer tx.Rollback()
-    if err := fn(tx); err != nil {
+    err := fn(tx)
+    if err != nil {
         return err
     }
@@ -1068,6 +1114,7 @@ func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error {
 func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
     tx := db.Begin()
     defer tx.Rollback()
     ret, err := fn(tx)
     if err != nil {
         var no T
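These generic Read/Write helpers wrap a closure in a transaction with defer tx.Rollback() as the safety net; rolling back after a successful commit is a no-op, so the deferred call only matters on the error paths. A usage sketch under that assumption:

    // Runs ListNodes inside a read transaction; the helper begins the
    // transaction, defers the rollback, and returns the closure's result.
    nodes, err := Read(db, func(rx *gorm.DB) (types.Nodes, error) {
        return ListNodes(rx)
    })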


@@ -1,6 +1,7 @@
 package db
 import (
+    "context"
     "database/sql"
     "os"
     "os/exec"
@@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
     // Verify api_keys data preservation
     var apiKeyCount int
     err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
     require.NoError(t, err)
     assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
@@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
         return err
     }
-    _, err = db.Exec(string(schemaContent))
+    _, err = db.ExecContext(context.Background(), string(schemaContent))
     return err
 }
@@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
 func requireConstraintFailed(t *testing.T, err error) {
     t.Helper()
     require.Error(t, err)
     if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
         require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
     }
@@ -198,7 +201,7 @@ func TestConstraints(t *testing.T) {
     }{
         {
             name: "no-duplicate-username-if-no-oidc",
-            run: func(t *testing.T, db *gorm.DB) {
+            run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
                 _, err := CreateUser(db, types.User{Name: "user1"})
                 require.NoError(t, err)
                 _, err = CreateUser(db, types.User{Name: "user1"})
@@ -207,7 +210,7 @@ func TestConstraints(t *testing.T) {
         },
         {
             name: "no-oidc-duplicate-username-and-id",
-            run: func(t *testing.T, db *gorm.DB) {
+            run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
                 user := types.User{
                     Model: gorm.Model{ID: 1},
                     Name:  "user1",
@@ -229,7 +232,7 @@ func TestConstraints(t *testing.T) {
         },
         {
             name: "no-oidc-duplicate-id",
-            run: func(t *testing.T, db *gorm.DB) {
+            run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
                 user := types.User{
                     Model: gorm.Model{ID: 1},
                     Name:  "user1",
@@ -251,7 +254,7 @@ func TestConstraints(t *testing.T) {
         },
         {
             name: "allow-duplicate-username-cli-then-oidc",
-            run: func(t *testing.T, db *gorm.DB) {
+            run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
                 _, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username
                 require.NoError(t, err)
@@ -266,7 +269,7 @@ func TestConstraints(t *testing.T) {
         },
         {
             name: "allow-duplicate-username-oidc-then-cli",
-            run: func(t *testing.T, db *gorm.DB) {
+            run: func(t *testing.T, db *gorm.DB) { //nolint:thelper
                 user := types.User{
                     Name: "user1",
                     ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
@@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
     }
     // Construct the pg_restore command
-    cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
+    cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
     // Set the output streams
     cmd.Stdout = os.Stdout
@@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
 // skip already-applied migrations and only run new ones.
 func TestSQLiteAllTestdataMigrations(t *testing.T) {
     t.Parallel()
     schemas, err := os.ReadDir("testdata/sqlite")
     require.NoError(t, err)


@@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
     t.Logf("Initial number of goroutines: %d", initialGoroutines)
     // Basic deletion tracking mechanism
-    var deletedIDs []types.NodeID
-    var deleteMutex sync.Mutex
-    var deletionWg sync.WaitGroup
+    var (
+        deletedIDs  []types.NodeID
+        deleteMutex sync.Mutex
+        deletionWg  sync.WaitGroup
+    )
     deleteFunc := func(nodeID types.NodeID) {
         deleteMutex.Lock()
         deletedIDs = append(deletedIDs, nodeID)
         deleteMutex.Unlock()
         deletionWg.Done()
     }
@@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
     go gc.Start()
     // Schedule several nodes for deletion with short expiry
-    const expiry = fifty
-    const numNodes = 100
+    const (
+        expiry   = fifty
+        numNodes = 100
+    )
     // Set up wait group for expected deletions
     deletionWg.Add(numNodes)
     for i := 1; i <= numNodes; i++ {
-        gc.Schedule(types.NodeID(i), expiry)
+        gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // safe conversion in test
     }
     // Wait for all scheduled deletions to complete
@@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
     // Schedule and immediately cancel to test that part of the code
     for i := numNodes + 1; i <= numNodes*2; i++ {
-        nodeID := types.NodeID(i)
+        nodeID := types.NodeID(i) //nolint:gosec // safe conversion in test
         gc.Schedule(nodeID, time.Hour)
         gc.Cancel(nodeID)
     }
@@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
 // and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
 func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
     // Deletion tracking mechanism
-    var deletedIDs []types.NodeID
-    var deleteMutex sync.Mutex
+    var (
+        deletedIDs  []types.NodeID
+        deleteMutex sync.Mutex
+    )
     deletionNotifier := make(chan types.NodeID, 1)
     deleteFunc := func(nodeID types.NodeID) {
         deleteMutex.Lock()
         deletedIDs = append(deletedIDs, nodeID)
         deleteMutex.Unlock()
         deletionNotifier <- nodeID
@@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
     // Start GC
     gc := NewEphemeralGarbageCollector(deleteFunc)
     go gc.Start()
     defer gc.Close()
-    const shortExpiry = fifty
-    const longExpiry = 1 * time.Hour
+    const (
+        shortExpiry = fifty
+        longExpiry  = 1 * time.Hour
+    )
     nodeID := types.NodeID(1)
@@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
 // and verifies that the node is deleted only once.
 func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
     // Deletion tracking mechanism
-    var deletedIDs []types.NodeID
-    var deleteMutex sync.Mutex
+    var (
+        deletedIDs  []types.NodeID
+        deleteMutex sync.Mutex
+    )
     deletionNotifier := make(chan types.NodeID, 1)
     deleteFunc := func(nodeID types.NodeID) {
         deleteMutex.Lock()
         deletedIDs = append(deletedIDs, nodeID)
         deleteMutex.Unlock()
         deletionNotifier <- nodeID
     }
     // Start the GC
     gc := NewEphemeralGarbageCollector(deleteFunc)
     go gc.Start()
     defer gc.Close()
     nodeID := types.NodeID(1)
     const expiry = fifty
     // Schedule node for deletion
@@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
 // It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
 func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
     // Deletion tracking
-    var deletedIDs []types.NodeID
-    var deleteMutex sync.Mutex
+    var (
+        deletedIDs  []types.NodeID
+        deleteMutex sync.Mutex
+    )
     deletionNotifier := make(chan types.NodeID, 1)
     deleteFunc := func(nodeID types.NodeID) {
         deleteMutex.Lock()
         deletedIDs = append(deletedIDs, nodeID)
         deleteMutex.Unlock()
         deletionNotifier <- nodeID
@@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
     t.Logf("Initial number of goroutines: %d", initialGoroutines)
     // Deletion tracking
-    var deletedIDs []types.NodeID
-    var deleteMutex sync.Mutex
+    var (
+        deletedIDs  []types.NodeID
+        deleteMutex sync.Mutex
+    )
     nodeDeleted := make(chan struct{})
     deleteFunc := func(nodeID types.NodeID) {
         deleteMutex.Lock()
         deletedIDs = append(deletedIDs, nodeID)
         deleteMutex.Unlock()
         close(nodeDeleted) // Signal that deletion happened
     }
@@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
     // Use a WaitGroup to ensure the GC has started
     var startWg sync.WaitGroup
     startWg.Add(1)
     go func() {
         startWg.Done() // Signal that the goroutine has started
         gc.Start()
     }()
     startWg.Wait() // Wait for the GC to start
     // Close GC right away
@@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
     // Check no node was deleted
     deleteMutex.Lock()
     nodesDeleted := len(deletedIDs)
     deleteMutex.Unlock()
     assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
@@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
     t.Logf("Initial number of goroutines: %d", initialGoroutines)
     // Deletion tracking mechanism
-    var deletedIDs []types.NodeID
-    var deleteMutex sync.Mutex
+    var (
+        deletedIDs  []types.NodeID
+        deleteMutex sync.Mutex
+    )
     deleteFunc := func(nodeID types.NodeID) {
         deleteMutex.Lock()
         deletedIDs = append(deletedIDs, nodeID)
         deleteMutex.Unlock()
     }
@@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
     go gc.Start()
     // Number of concurrent scheduling goroutines
-    const numSchedulers = 10
-    const nodesPerScheduler = 50
+    const (
+        numSchedulers     = 10
+        nodesPerScheduler = 50
+    )
     const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
@@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
     case <-stopScheduling:
         return
     default:
-        nodeID := types.NodeID(baseNodeID + j + 1)
+        nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // safe conversion in test
         gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
         atomic.AddInt64(&scheduledCount, 1)
         // Yield to other goroutines to introduce variability
// Yield to other goroutines to introduce variability // Yield to other goroutines to introduce variability


@@ -17,7 +17,11 @@ import (
     "tailscale.com/net/tsaddr"
 )
-var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
+var (
+    errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip")
+    errGeneratedIPNotInPrefix  = errors.New("generated ip not in prefix")
+    errIPAllocatorNil          = errors.New("ip allocator was nil")
+)
 // IPAllocator is a singleton responsible for allocating
 // IP addresses for nodes and making sure the same
@@ -62,8 +66,10 @@ func NewIPAllocator(
     strategy: strategy,
 }
-var v4s []sql.NullString
-var v6s []sql.NullString
+var (
+    v4s []sql.NullString
+    v6s []sql.NullString
+)
 if db != nil {
     err := db.Read(func(rx *gorm.DB) error {
@@ -135,15 +141,18 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
     i.mu.Lock()
     defer i.mu.Unlock()
-    var err error
-    var ret4 *netip.Addr
-    var ret6 *netip.Addr
+    var (
+        err  error
+        ret4 *netip.Addr
+        ret6 *netip.Addr
+    )
     if i.prefix4 != nil {
         ret4, err = i.next(i.prev4, i.prefix4)
         if err != nil {
             return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
         }
         i.prev4 = *ret4
     }
@@ -152,6 +161,7 @@ func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
         if err != nil {
             return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
         }
         i.prev6 = *ret6
     }
@@ -168,8 +178,10 @@ func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.
 }
 func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) {
-    var err error
-    var ip netip.Addr
+    var (
+        err error
+        ip  netip.Addr
+    )
     switch i.strategy {
     case types.IPAllocationStrategySequential:
@@ -243,7 +255,8 @@ func randomNext(pfx netip.Prefix) (netip.Addr, error) {
     if !pfx.Contains(ip) {
         return netip.Addr{}, fmt.Errorf(
-            "generated ip(%s) not in prefix(%s)",
+            "%w: ip(%s) not in prefix(%s)",
+            errGeneratedIPNotInPrefix,
             ip.String(),
             pfx.String(),
         )
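Grouping the allocator errors into sentinels and wrapping with %w (rather than formatting a one-off string) lets callers branch on the failure class without matching message text:

    _, err := randomNext(pfx)
    if errors.Is(err, errGeneratedIPNotInPrefix) {
        // Illustrative handling: retry or fall back; the wrapped message
        // still carries the offending ip and prefix for logging.
    }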
@@ -268,11 +281,14 @@ func isTailscaleReservedIP(ip netip.Addr) bool {
 // If a prefix type has been removed (IPv4 or IPv6), it
 // will remove the IPs in that family from the node.
 func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
-    var err error
-    var ret []string
+    var (
+        err error
+        ret []string
+    )
     err = db.Write(func(tx *gorm.DB) error {
         if i == nil {
-            return errors.New("backfilling IPs: ip allocator was nil")
+            return fmt.Errorf("backfilling IPs: %w", errIPAllocatorNil)
         }
         log.Trace().Caller().Msgf("starting to backfill IPs")
@@ -283,18 +299,19 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
         }
         for _, node := range nodes {
-            log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database")
+            log.Trace().Caller().EmbedObject(node).Msg("ip backfill check started because node found in database")
             changed := false
             // IPv4 prefix is set, but node ip is missing, alloc
             if i.prefix4 != nil && node.IPv4 == nil {
                 ret4, err := i.nextLocked(i.prev4, i.prefix4)
                 if err != nil {
-                    return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err)
+                    return fmt.Errorf("allocating IPv4 for node(%d): %w", node.ID, err)
                 }
                 node.IPv4 = ret4
                 changed = true
                 ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname))
             }
@@ -302,11 +319,12 @@ func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) {
             if i.prefix6 != nil && node.IPv6 == nil {
                 ret6, err := i.nextLocked(i.prev6, i.prefix6)
                 if err != nil {
-                    return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err)
+                    return fmt.Errorf("allocating IPv6 for node(%d): %w", node.ID, err)
                 }
                 node.IPv6 = ret6
                 changed = true
                 ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname))
             }


@@ -13,7 +13,6 @@ import (
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
     "tailscale.com/net/tsaddr"
-    "tailscale.com/types/ptr"
 )
 var mpp = func(pref string) *netip.Prefix {
@@ -21,9 +20,7 @@ var mpp = func(pref string) *netip.Prefix {
     return &p
 }
-var na = func(pref string) netip.Addr {
-    return netip.MustParseAddr(pref)
-}
+var na = netip.MustParseAddr
 var nap = func(pref string) *netip.Addr {
     n := na(pref)
@@ -158,8 +155,10 @@ func TestIPAllocatorSequential(t *testing.T) {
     types.IPAllocationStrategySequential,
 )
-var got4s []netip.Addr
-var got6s []netip.Addr
+var (
+    got4s []netip.Addr
+    got6s []netip.Addr
+)
 for range tt.getCount {
     got4, got6, err := alloc.Next()
@@ -175,6 +174,7 @@ func TestIPAllocatorSequential(t *testing.T) {
         got6s = append(got6s, *got6)
     }
 }
 if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
     t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
 }
@@ -288,6 +288,7 @@ func TestBackfillIPAddresses(t *testing.T) {
 fullNodeP := func(i int) *types.Node {
     v4 := fmt.Sprintf("100.64.0.%d", i)
     v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
     return &types.Node{
         IPv4: nap(v4),
         IPv6: nap(v6),
@@ -484,12 +485,13 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
 func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
     db, err := newSQLiteTestDB()
     require.NoError(t, err)
     defer db.Close()
     alloc, err := NewIPAllocator(
         db,
-        ptr.To(tsaddr.CGNATRange()),
-        ptr.To(tsaddr.TailscaleULARange()),
+        new(tsaddr.CGNATRange()),
+        new(tsaddr.TailscaleULARange()),
         types.IPAllocationStrategySequential,
     )
     if err != nil {
@@ -497,17 +499,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
     }
     // Validate that we do not give out 100.100.100.100
-    nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
+    nextQuad100, err := alloc.next(na("100.100.100.99"), new(tsaddr.CGNATRange()))
     require.NoError(t, err)
     assert.Equal(t, na("100.100.100.101"), *nextQuad100)
     // Validate that we do not give out fd7a:115c:a1e0::53
-    nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
+    nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), new(tsaddr.TailscaleULARange()))
     require.NoError(t, err)
     assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
     // Validate that we do not give out fd7a:115c:a1e0::53
-    nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
+    nextChrome, err := alloc.next(na("100.115.91.255"), new(tsaddr.CGNATRange()))
     t.Logf("chrome: %s", nextChrome.String())
     require.NoError(t, err)
     assert.Equal(t, na("100.115.94.0"), *nextChrome)


@@ -16,18 +16,24 @@ import (
     "github.com/juanfont/headscale/hscontrol/types"
     "github.com/juanfont/headscale/hscontrol/util"
+    "github.com/juanfont/headscale/hscontrol/util/zlog/zf"
     "github.com/rs/zerolog/log"
     "gorm.io/gorm"
     "tailscale.com/net/tsaddr"
     "tailscale.com/types/key"
-    "tailscale.com/types/ptr"
 )
 const (
     NodeGivenNameHashLength = 8
     NodeGivenNameTrimSize   = 2
+    // defaultTestNodePrefix is the default hostname prefix for nodes created in tests.
+    defaultTestNodePrefix = "testnode"
 )
+// ErrNodeNameNotUnique is returned when a node name is not unique.
+var ErrNodeNameNotUnique = errors.New("node name is not unique")
 var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
 var (
@@ -51,12 +57,14 @@ func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID)
 // If at least one peer ID is given, only these peer nodes will be returned.
 func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
     nodes := types.Nodes{}
-    if err := tx.
+    err := tx.
         Preload("AuthKey").
         Preload("AuthKey.User").
         Preload("User").
         Where("id <> ?", nodeID).
-        Where(peerIDs).Find(&nodes).Error; err != nil {
+        Where(peerIDs).Find(&nodes).Error
+    if err != nil {
         return types.Nodes{}, err
     }
@@ -75,11 +83,13 @@ func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error)
 // or for the given nodes if at least one node ID is given as parameter.
 func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
     nodes := types.Nodes{}
-    if err := tx.
+    err := tx.
         Preload("AuthKey").
         Preload("AuthKey.User").
         Preload("User").
-        Where(nodeIDs).Find(&nodes).Error; err != nil {
+        Where(nodeIDs).Find(&nodes).Error
+    if err != nil {
         return nil, err
     }
@@ -89,7 +99,9 @@ func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) {
 func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
     return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
         nodes := types.Nodes{}
-        if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil {
+        err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error
+        if err != nil {
            return nil, err
        }
@@ -207,6 +219,7 @@ func SetTags(
     slices.Sort(tags)
     tags = slices.Compact(tags)
     b, err := json.Marshal(tags)
     if err != nil {
         return err
@@ -220,7 +233,7 @@ func SetTags(
     return nil
 }
-// SetTags takes a Node struct pointer and update the forced tags.
+// SetApprovedRoutes takes a Node struct pointer and updates the approved routes.
 func SetApprovedRoutes(
     tx *gorm.DB,
     nodeID types.NodeID,
@@ -228,7 +241,8 @@ func SetApprovedRoutes(
 ) error {
     if len(routes) == 0 {
         // if no routes are provided, we remove all
-        if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
+        err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error
+        if err != nil {
            return fmt.Errorf("removing approved routes: %w", err)
        }
@@ -251,7 +265,7 @@ func SetApprovedRoutes(
         return err
     }
-    if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
+    if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { //nolint:noinlineerr
         return fmt.Errorf("updating approved routes: %w", err)
     }
@@ -277,22 +291,25 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
 func RenameNode(tx *gorm.DB,
     nodeID types.NodeID, newName string,
 ) error {
-    if err := util.ValidateHostname(newName); err != nil {
+    err := util.ValidateHostname(newName)
+    if err != nil {
         return fmt.Errorf("renaming node: %w", err)
     }
     // Check if the new name is unique
     var count int64
-    if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
-        return fmt.Errorf("failed to check name uniqueness: %w", err)
+    err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error
+    if err != nil {
+        return fmt.Errorf("checking name uniqueness: %w", err)
    }
    if count > 0 {
-        return errors.New("name is not unique")
+        return ErrNodeNameNotUnique
    }
-    if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
-        return fmt.Errorf("failed to rename node in the database: %w", err)
+    if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { //nolint:noinlineerr
+        return fmt.Errorf("renaming node in database: %w", err)
    }
    return nil
@@ -323,7 +340,8 @@ func DeleteNode(tx *gorm.DB,
node *types.Node, node *types.Node,
) error { ) error {
// Unscoped causes the node to be fully removed from the database. // Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil { err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error
if err != nil {
return err return err
} }
@@ -337,9 +355,11 @@ func (hsdb *HSDatabase) DeleteEphemeralNode(
nodeID types.NodeID, nodeID types.NodeID,
) error { ) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil { err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error
if err != nil {
return err return err
} }
return nil return nil
}) })
} }
@@ -352,19 +372,19 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
} }
logEvent := log.Debug(). logEvent := log.Debug().
Str("node", node.Hostname). Str(zf.NodeHostname, node.Hostname).
Str("machine_key", node.MachineKey.ShortString()). Str(zf.MachineKey, node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()) Str(zf.NodeKey, node.NodeKey.ShortString())
if node.User != nil { if node.User != nil {
logEvent = logEvent.Str("user", node.User.Username()) logEvent = logEvent.Str(zf.UserName, node.User.Username())
} else if node.UserID != nil { } else if node.UserID != nil {
logEvent = logEvent.Uint("user_id", *node.UserID) logEvent = logEvent.Uint(zf.UserID, *node.UserID)
} else { } else {
logEvent = logEvent.Str("user", "none") logEvent = logEvent.Str(zf.UserName, "none")
} }
logEvent.Msg("Registering test node") logEvent.Msg("registering test node")
// If the a new node is registered with the same machine key, to the same user, // If the a new node is registered with the same machine key, to the same user,
// update the existing node. // update the existing node.
@@ -379,6 +399,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
if ipv4 == nil { if ipv4 == nil {
ipv4 = oldNode.IPv4 ipv4 = oldNode.IPv4
} }
if ipv6 == nil { if ipv6 == nil {
ipv6 = oldNode.IPv6 ipv6 = oldNode.IPv6
} }
@@ -388,16 +409,17 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
// so we store the node.Expire and node.Nodekey that has been set when // so we store the node.Expire and node.Nodekey that has been set when
// adding it to the registrationCache // adding it to the registrationCache
if node.IPv4 != nil || node.IPv6 != nil { if node.IPv4 != nil || node.IPv6 != nil {
if err := tx.Save(&node).Error; err != nil { err := tx.Save(&node).Error
return nil, fmt.Errorf("failed register existing node in the database: %w", err) if err != nil {
return nil, fmt.Errorf("registering existing node in database: %w", err)
} }
log.Trace(). log.Trace().
Caller(). Caller().
Str("node", node.Hostname). Str(zf.NodeHostname, node.Hostname).
Str("machine_key", node.MachineKey.ShortString()). Str(zf.MachineKey, node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()). Str(zf.NodeKey, node.NodeKey.ShortString()).
Str("user", node.User.Username()). Str(zf.UserName, node.User.Username()).
Msg("Test node authorized again") Msg("Test node authorized again")
return &node, nil return &node, nil
@@ -407,29 +429,30 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
node.IPv6 = ipv6 node.IPv6 = ipv6
var err error var err error
node.Hostname, err = util.NormaliseHostname(node.Hostname) node.Hostname, err = util.NormaliseHostname(node.Hostname)
if err != nil { if err != nil {
newHostname := util.InvalidString() newHostname := util.InvalidString()
log.Info().Err(err).Str("invalid-hostname", node.Hostname).Str("new-hostname", newHostname).Msgf("Invalid hostname, replacing") log.Info().Err(err).Str(zf.InvalidHostname, node.Hostname).Str(zf.NewHostname, newHostname).Msgf("invalid hostname, replacing")
node.Hostname = newHostname node.Hostname = newHostname
} }
if node.GivenName == "" { if node.GivenName == "" {
givenName, err := EnsureUniqueGivenName(tx, node.Hostname) givenName, err := EnsureUniqueGivenName(tx, node.Hostname)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to ensure unique given name: %w", err) return nil, fmt.Errorf("ensuring unique given name: %w", err)
} }
node.GivenName = givenName node.GivenName = givenName
} }
if err := tx.Save(&node).Error; err != nil { if err := tx.Save(&node).Error; err != nil { //nolint:noinlineerr
return nil, fmt.Errorf("failed register(save) node in the database: %w", err) return nil, fmt.Errorf("saving node to database: %w", err)
} }
log.Trace(). log.Trace().
Caller(). Caller().
Str("node", node.Hostname). Str(zf.NodeHostname, node.Hostname).
Msg("Test node registered with the database") Msg("Test node registered with the database")
return &node, nil return &node, nil
@@ -491,8 +514,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
func isUniqueName(tx *gorm.DB, name string) (bool, error) { func isUniqueName(tx *gorm.DB, name string) (bool, error) {
nodes := types.Nodes{} nodes := types.Nodes{}
if err := tx.
Where("given_name = ?", name).Find(&nodes).Error; err != nil { err := tx.
Where("given_name = ?", name).Find(&nodes).Error
if err != nil {
return false, err return false, err
} }
@@ -646,7 +671,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
panic("CreateNodeForTest requires a valid user") panic("CreateNodeForTest requires a valid user")
} }
nodeName := "testnode" nodeName := defaultTestNodePrefix
if len(hostname) > 0 && hostname[0] != "" { if len(hostname) > 0 && hostname[0] != "" {
nodeName = hostname[0] nodeName = hostname[0]
} }
@@ -657,6 +682,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err)) panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
} }
pakID := pak.ID
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
discoKey := key.NewDisco() discoKey := key.NewDisco()
@@ -668,7 +694,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
Hostname: nodeName, Hostname: nodeName,
UserID: &user.ID, UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID), AuthKeyID: &pakID,
} }
err = hsdb.DB.Save(node).Error err = hsdb.DB.Save(node).Error
@@ -694,9 +720,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
} }
var registeredNode *types.Node var registeredNode *types.Node
err = hsdb.DB.Transaction(func(tx *gorm.DB) error { err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
var err error var err error
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
return err return err
}) })
if err != nil { if err != nil {
@@ -715,7 +744,7 @@ func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostname
panic("CreateNodesForTest requires a valid user") panic("CreateNodesForTest requires a valid user")
} }
prefix := "testnode" prefix := defaultTestNodePrefix
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
prefix = hostnamePrefix[0] prefix = hostnamePrefix[0]
} }
@@ -738,7 +767,7 @@ func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int
panic("CreateRegisteredNodesForTest requires a valid user") panic("CreateRegisteredNodesForTest requires a valid user")
} }
prefix := "testnode" prefix := defaultTestNodePrefix
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
prefix = hostnamePrefix[0] prefix = hostnamePrefix[0]
} }

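Most hunks in this file follow one mechanical pattern: the inline form `if err := ...; err != nil` is unwrapped into a separate assignment plus check, and the few spots that keep the inline form carry a //nolint:noinlineerr directive. A runnable sketch of the two shapes, with save() as a hypothetical failing call:

package main

import (
	"errors"
	"fmt"
)

func save() error { return errors.New("boom") }

func inline() error {
	// Inline form: assignment and check share a line; this is what
	// the noinlineerr linter flags.
	if err := save(); err != nil {
		return err
	}

	return nil
}

func unwrapped() error {
	// Unwrapped form used by the hunks above: the assignment gets its
	// own statement, and the check reads on its own line.
	err := save()
	if err != nil {
		return fmt.Errorf("saving: %w", err)
	}

	return nil
}

func main() {
	fmt.Println(inline(), unwrapped())
}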
View File

@@ -22,7 +22,6 @@ import (
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
-	"tailscale.com/types/ptr"
 )

 func TestGetNode(t *testing.T) {
@@ -102,6 +101,8 @@ func TestExpireNode(t *testing.T) {
 	pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
 	require.NoError(t, err)

+	pakID := pak.ID
+
 	_, err = db.getNode(types.UserID(user.ID), "testnode")
 	require.Error(t, err)
@@ -115,7 +116,7 @@ func TestExpireNode(t *testing.T) {
 		Hostname:       "testnode",
 		UserID:         &user.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      ptr.To(pak.ID),
+		AuthKeyID:      &pakID,
 		Expiry:         &time.Time{},
 	}
 	db.DB.Save(node)
@@ -146,6 +147,8 @@ func TestSetTags(t *testing.T) {
 	pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
 	require.NoError(t, err)

+	pakID := pak.ID
+
 	_, err = db.getNode(types.UserID(user.ID), "testnode")
 	require.Error(t, err)
@@ -159,7 +162,7 @@ func TestSetTags(t *testing.T) {
 		Hostname:       "testnode",
 		UserID:         &user.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      ptr.To(pak.ID),
+		AuthKeyID:      &pakID,
 	}
 	trx := db.DB.Save(node)
@@ -187,6 +190,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
 		suppliedName string
 		randomSuffix bool
 	}
+
 	tests := []struct {
 		name string
 		args args
@@ -443,7 +447,7 @@ func TestAutoApproveRoutes(t *testing.T) {
 		Hostinfo: &tailcfg.Hostinfo{
 			RoutableIPs: tt.routes,
 		},
-		IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
+		IPv4: new(netip.MustParseAddr("100.64.0.1")),
 	}

 	err = adb.DB.Save(&node).Error
@@ -460,17 +464,17 @@ func TestAutoApproveRoutes(t *testing.T) {
 			RoutableIPs: tt.routes,
 		},
 		Tags: []string{"tag:exit"},
-		IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
+		IPv4: new(netip.MustParseAddr("100.64.0.2")),
 	}

 	err = adb.DB.Save(&nodeTagged).Error
 	require.NoError(t, err)

 	users, err := adb.ListUsers()
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	nodes, err := adb.ListNodes()
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	pm, err := pmf(users, nodes.ViewSlice())
 	require.NoError(t, err)
@@ -498,6 +502,7 @@ func TestAutoApproveRoutes(t *testing.T) {
 	if len(expectedRoutes1) == 0 {
 		expectedRoutes1 = nil
 	}
+
 	if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
 		t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
 	}
@@ -509,6 +514,7 @@ func TestAutoApproveRoutes(t *testing.T) {
 	if len(expectedRoutes2) == 0 {
 		expectedRoutes2 = nil
 	}
+
 	if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
 		t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
 	}
@@ -520,6 +526,7 @@ func TestAutoApproveRoutes(t *testing.T) {
 func TestEphemeralGarbageCollectorOrder(t *testing.T) {
 	want := []types.NodeID{1, 3}
 	got := []types.NodeID{}
+
 	var mu sync.Mutex

 	deletionCount := make(chan struct{}, 10)
@@ -527,6 +534,7 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
 	e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
 		mu.Lock()
 		defer mu.Unlock()
+
 		got = append(got, ni)
 		deletionCount <- struct{}{}
@@ -576,8 +584,10 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
 }

 func TestEphemeralGarbageCollectorLoads(t *testing.T) {
-	var got []types.NodeID
-	var mu sync.Mutex
+	var (
+		got []types.NodeID
+		mu  sync.Mutex
+	)

 	want := 1000
@@ -589,6 +599,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
 		// Yield to other goroutines to introduce variability
 		runtime.Gosched()
+
 		got = append(got, ni)
 		atomic.AddInt64(&deletedCount, 1)
@@ -616,9 +627,12 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
 	}
 }

-func generateRandomNumber(t *testing.T, max int64) int64 {
+//nolint:unused
+func generateRandomNumber(t *testing.T, maxVal int64) int64 {
 	t.Helper()
-	maxB := big.NewInt(max)
+
+	maxB := big.NewInt(maxVal)
+
 	n, err := rand.Int(rand.Reader, maxB)
 	if err != nil {
 		t.Fatalf("getting random number: %s", err)
@@ -642,6 +656,9 @@ func TestListEphemeralNodes(t *testing.T) {
 	pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
 	require.NoError(t, err)

+	pakID := pak.ID
+	pakEphID := pakEph.ID
+
 	node := types.Node{
 		ID:         0,
 		MachineKey: key.NewMachine().Public(),
@@ -649,7 +666,7 @@ func TestListEphemeralNodes(t *testing.T) {
 		Hostname:       "test",
 		UserID:         &user.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      ptr.To(pak.ID),
+		AuthKeyID:      &pakID,
 	}

 	nodeEph := types.Node{
@@ -659,7 +676,7 @@ func TestListEphemeralNodes(t *testing.T) {
 		Hostname:       "ephemeral",
 		UserID:         &user.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      ptr.To(pakEph.ID),
+		AuthKeyID:      &pakEphID,
 	}

 	err = db.DB.Save(&node).Error
@@ -722,7 +739,7 @@ func TestNodeNaming(t *testing.T) {
 	nodeInvalidHostname := types.Node{
 		MachineKey:     key.NewMachine().Public(),
 		NodeKey:        key.NewNode().Public(),
-		Hostname:       "我的电脑",
+		Hostname:       "我的电脑", //nolint:gosmopolitan // intentional i18n test data
 		UserID:         &user2.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
 	}
@@ -746,12 +763,15 @@ func TestNodeNaming(t *testing.T) {
 		if err != nil {
 			return err
 		}

 		_, err = RegisterNodeForTest(tx, node2, nil, nil)
 		if err != nil {
 			return err
 		}

-		_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
-		_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
+		_, _ = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil)
+		_, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil)

 		return err
 	})
 	require.NoError(t, err)
@@ -810,25 +830,25 @@ func TestNodeNaming(t *testing.T) {
 	err = db.Write(func(tx *gorm.DB) error {
 		return RenameNode(tx, nodes[0].ID, "test")
 	})
-	assert.ErrorContains(t, err, "name is not unique")
+	require.ErrorContains(t, err, "name is not unique")

 	// Rename invalid chars
 	err = db.Write(func(tx *gorm.DB) error {
-		return RenameNode(tx, nodes[2].ID, "我的电脑")
+		return RenameNode(tx, nodes[2].ID, "我的电脑") //nolint:gosmopolitan // intentional i18n test data
 	})
-	assert.ErrorContains(t, err, "invalid characters")
+	require.ErrorContains(t, err, "invalid characters")

 	// Rename too short
 	err = db.Write(func(tx *gorm.DB) error {
 		return RenameNode(tx, nodes[3].ID, "a")
 	})
-	assert.ErrorContains(t, err, "at least 2 characters")
+	require.ErrorContains(t, err, "at least 2 characters")

 	// Rename with emoji
 	err = db.Write(func(tx *gorm.DB) error {
 		return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
 	})
-	assert.ErrorContains(t, err, "invalid characters")
+	require.ErrorContains(t, err, "invalid characters")

 	// Rename with only emoji
 	err = db.Write(func(tx *gorm.DB) error {
@@ -896,12 +916,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
 		},
 		{
 			name:    "chinese_chars_with_dash_rejected",
-			newName: "server-北京-01",
+			newName: "server-北京-01", //nolint:gosmopolitan // intentional i18n test data
 			wantErr: "invalid characters",
 		},
 		{
 			name:    "chinese_only_rejected",
-			newName: "我的电脑",
+			newName: "我的电脑", //nolint:gosmopolitan // intentional i18n test data
 			wantErr: "invalid characters",
 		},
 		{
@@ -911,7 +931,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
 		},
 		{
 			name:    "mixed_chinese_emoji_rejected",
-			newName: "测试💻机器",
+			newName: "测试💻机器", //nolint:gosmopolitan // intentional i18n test data
 			wantErr: "invalid characters",
 		},
 		{
@@ -1000,6 +1020,7 @@ func TestListPeers(t *testing.T) {
 	if err != nil {
 		return err
 	}
+
 	_, err = RegisterNodeForTest(tx, node2, nil, nil)
 	return err
@@ -1085,6 +1106,7 @@ func TestListNodes(t *testing.T) {
 	if err != nil {
 		return err
 	}
+
 	_, err = RegisterNodeForTest(tx, node2, nil, nil)
 	return err

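Several hunks above also swap assert.NoError for require.NoError right before the returned values are used. The difference matters for nil-guard checks: assert records the failure and keeps going, require stops the test. A small sketch, with fetchNodes as a hypothetical fallible call:

package db

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// fetchNodes is a hypothetical stand-in for a call that can fail.
func fetchNodes() ([]string, error) { return []string{"node1"}, nil }

func TestFetchNodes(t *testing.T) {
	nodes, err := fetchNodes()

	// assert.NoError marks the test failed but execution continues,
	// so a later dereference of a nil result would still run.
	assert.NoError(t, err)

	// require.NoError calls t.Fatal and stops the test here on error,
	// guarding every line after it.
	require.NoError(t, err)
	require.Len(t, nodes, 1)
}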
View File

@@ -17,7 +17,8 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
 		Data: policy,
 	}

-	if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil {
+	err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error
+	if err != nil {
 		return nil, err
 	}

View File

@@ -138,8 +138,8 @@ func CreatePreAuthKey(
 		Hash: hash, // Store hash
 	}

-	if err := tx.Save(&key).Error; err != nil {
-		return nil, fmt.Errorf("failed to create key in the database: %w", err)
+	if err := tx.Save(&key).Error; err != nil { //nolint:noinlineerr
+		return nil, fmt.Errorf("creating key in database: %w", err)
 	}

 	return &types.PreAuthKeyNew{
@@ -155,9 +155,7 @@ func CreatePreAuthKey(
 }

 func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
-		return ListPreAuthKeys(rx)
-	})
+	return Read(hsdb.DB, ListPreAuthKeys)
 }

 // ListPreAuthKeys returns all PreAuthKeys in the database.
@@ -296,7 +294,7 @@ func DestroyPreAuthKey(tx *gorm.DB, id uint64) error {
 		Where("auth_key_id = ?", id).
 		Update("auth_key_id", nil).Error
 	if err != nil {
-		return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err)
+		return fmt.Errorf("clearing auth_key_id on nodes: %w", err)
 	}

 	// Then delete the pre-auth key
@@ -325,14 +323,15 @@ func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error {
 func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
 	err := tx.Model(k).Update("used", true).Error
 	if err != nil {
-		return fmt.Errorf("failed to update key used status in the database: %w", err)
+		return fmt.Errorf("updating key used status in database: %w", err)
 	}

 	k.Used = true

 	return nil
 }

-// MarkExpirePreAuthKey marks a PreAuthKey as expired.
+// ExpirePreAuthKey marks a PreAuthKey as expired.
 func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
 	now := time.Now()
 	return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error

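The ListPreAuthKeys change above is worth a second look: because the package-level function already has the exact signature the generic Read helper expects, the wrapping closure was redundant and the function value can be passed directly. A minimal sketch with stand-in types (read, listKeys, and db here are illustrative, not the real helpers):

package main

import "fmt"

type db struct{}

// read mimics a generic read-transaction helper: it hands the handle to fn.
func read[T any](d *db, fn func(*db) (T, error)) (T, error) {
	return fn(d)
}

func listKeys(d *db) ([]string, error) { return []string{"key1"}, nil }

func main() {
	// Wrapping in a closure is redundant...
	a, _ := read(&db{}, func(d *db) ([]string, error) { return listKeys(d) })

	// ...when the function already matches the expected signature.
	b, _ := read(&db{}, listKeys)

	fmt.Println(a, b)
}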
View File

@@ -11,7 +11,6 @@ import (
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"tailscale.com/types/ptr"
 )

 func TestCreatePreAuthKey(t *testing.T) {
@@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) {
 			test: func(t *testing.T, db *HSDatabase) {
 				t.Helper()

-				_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
+				_, err := db.CreatePreAuthKey(new(types.UserID(12345)), true, false, nil, nil)
 				assert.Error(t, err)
 			},
 		},
@@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
 		Hostname:       "testest",
 		UserID:         &user.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      ptr.To(key.ID),
+		AuthKeyID:      new(key.ID),
 	}
 	db.DB.Save(&node)

View File

@@ -104,3 +104,9 @@ CREATE TABLE policies(
 	deleted_at datetime
 );
 CREATE INDEX idx_policies_deleted_at ON policies(deleted_at);
+
+CREATE TABLE database_versions(
+	id integer PRIMARY KEY,
+	version text NOT NULL,
+	updated_at datetime
+);

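The datetime column above is SQLite syntax; PostgreSQL expects timestamp, which is why this compare creates the same table at runtime through GORM AutoMigrate rather than raw SQL (see the version-check file later in this diff). A minimal sketch of the dialect-neutral form:

package main

import (
	"time"

	"gorm.io/gorm"
)

// DatabaseVersion mirrors the database_versions table added above.
type DatabaseVersion struct {
	ID        uint   `gorm:"primaryKey"`
	Version   string `gorm:"not null"`
	UpdatedAt time.Time
}

// AutoMigrate lets the dialect pick the column type:
// datetime on SQLite, timestamp on PostgreSQL.
func ensureTable(db *gorm.DB) error {
	return db.AutoMigrate(&DatabaseVersion{})
}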
View File

@@ -362,7 +362,8 @@ func (c *Config) Validate() error {
 // ToURL builds a properly encoded SQLite connection string using _pragma parameters
 // compatible with modernc.org/sqlite driver.
 func (c *Config) ToURL() (string, error) {
-	if err := c.Validate(); err != nil {
+	err := c.Validate()
+	if err != nil {
 		return "", fmt.Errorf("invalid config: %w", err)
 	}

@@ -372,18 +373,23 @@ func (c *Config) ToURL() (string, error) {
 	if c.BusyTimeout > 0 {
 		pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
 	}
+
 	if c.JournalMode != "" {
 		pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
 	}
+
 	if c.AutoVacuum != "" {
 		pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
 	}
+
 	if c.WALAutocheckpoint >= 0 {
 		pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
 	}
+
 	if c.Synchronous != "" {
 		pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
 	}
+
 	if c.ForeignKeys {
 		pragmas = append(pragmas, "foreign_keys=ON")
 	}

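For a sense of what ToURL produces: modernc.org/sqlite reads pragmas as repeated _pragma=name(value) query parameters, so a sketch of the final assembly might look like the following (the joining logic and the file name are assumptions, not the exact implementation):

package main

import (
	"fmt"
	"net/url"
	"strings"
)

func main() {
	pragmas := []string{
		"busy_timeout=10000",
		"journal_mode=WAL",
		"foreign_keys=ON",
	}

	// Rewrite name=value into the driver's _pragma=name(value) form.
	params := make([]string, 0, len(pragmas))
	for _, p := range pragmas {
		name, value, _ := strings.Cut(p, "=")
		params = append(params, "_pragma="+url.QueryEscape(name+"("+value+")"))
	}

	fmt.Println("file:headscale.db?" + strings.Join(params, "&"))
}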
View File

@@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
 				t.Errorf("Config.ToURL() error = %v", err)
 				return
 			}
+
 			if got != tt.want {
 				t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
 			}
@@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
 		Path:        "",
 		BusyTimeout: -1,
 	}
+
 	_, err := config.ToURL()
 	if err == nil {
 		t.Error("Config.ToURL() with invalid config should return error")

View File

@@ -1,6 +1,7 @@
 package sqliteconfig

 import (
+	"context"
 	"database/sql"
 	"path/filepath"
 	"strings"
@@ -101,7 +102,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
 	defer db.Close()

 	// Test connection
-	if err := db.Ping(); err != nil {
+	ctx := context.Background()
+
+	err = db.PingContext(ctx)
+	if err != nil {
 		t.Fatalf("Failed to ping database: %v", err)
 	}

@@ -109,8 +113,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
 	for pragma, expectedValue := range tt.expected {
 		t.Run("pragma_"+pragma, func(t *testing.T) {
 			var actualValue any
+
 			query := "PRAGMA " + pragma
-			err := db.QueryRow(query).Scan(&actualValue)
+
+			err := db.QueryRowContext(ctx, query).Scan(&actualValue)
 			if err != nil {
 				t.Fatalf("Failed to query %s: %v", query, err)
 			}
@@ -163,6 +169,8 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
 	}
 	defer db.Close()

+	ctx := context.Background()
+
 	// Create test tables with foreign key relationship
 	schema := `
 		CREATE TABLE parent (
@@ -178,23 +186,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
 		);
 	`

-	if _, err := db.Exec(schema); err != nil {
+	_, err = db.ExecContext(ctx, schema)
+	if err != nil {
 		t.Fatalf("Failed to create schema: %v", err)
 	}

 	// Insert parent record
-	if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
+	_, err = db.ExecContext(ctx, "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')")
+	if err != nil {
 		t.Fatalf("Failed to insert parent: %v", err)
 	}

 	// Test 1: Valid foreign key should work
-	_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
+	_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
 	if err != nil {
 		t.Fatalf("Valid foreign key insert failed: %v", err)
 	}

 	// Test 2: Invalid foreign key should fail
-	_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
+	_, err = db.ExecContext(ctx, "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
 	if err == nil {
 		t.Error("Expected foreign key constraint violation, but insert succeeded")
 	} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
@@ -204,7 +214,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
 	}

 	// Test 3: Deleting referenced parent should fail
-	_, err = db.Exec("DELETE FROM parent WHERE id = 1")
+	_, err = db.ExecContext(ctx, "DELETE FROM parent WHERE id = 1")
 	if err == nil {
 		t.Error("Expected foreign key constraint violation when deleting referenced parent")
 	} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
@@ -249,7 +259,8 @@ func TestJournalModeValidation(t *testing.T) {
 	defer db.Close()

 	var actualMode string
-	err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
+
+	err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
 	if err != nil {
 		t.Fatalf("Failed to query journal_mode: %v", err)
 	}

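The recurring change in this test file is swapping db.Ping, db.Exec, and db.QueryRow for their Context-taking variants, so cancellation and deadlines can reach the driver. A self-contained sketch of the pattern against an in-memory database:

package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "modernc.org/sqlite"
)

func main() {
	db, err := sql.Open("sqlite", "file::memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// A context with a deadline can interrupt the call; the plain
	// Ping/QueryRow forms cannot be cancelled.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	if err := db.PingContext(ctx); err != nil {
		log.Fatal(err)
	}

	var mode string
	if err := db.QueryRowContext(ctx, "PRAGMA journal_mode").Scan(&mode); err != nil {
		log.Fatal(err)
	}

	log.Println("journal_mode:", mode)
}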
View File

@@ -53,16 +53,19 @@ func newPostgresDBForTest(t *testing.T) *url.URL {
 	t.Helper()

 	ctx := t.Context()
+
 	srv, err := postgrestest.Start(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}

 	t.Cleanup(srv.Cleanup)
+
 	u, err := srv.CreateDatabase(ctx)
 	if err != nil {
 		t.Fatal(err)
 	}

 	t.Logf("created local postgres: %s", u)
+
 	pu, _ := url.Parse(u)

View File

@@ -3,12 +3,19 @@ package db
 import (
 	"context"
 	"encoding"
+	"errors"
 	"fmt"
 	"reflect"

 	"gorm.io/gorm/schema"
 )

+var (
+	errUnmarshalTextValue = errors.New("unmarshalling text value")
+	errUnsupportedType    = errors.New("unsupported type")
+	errTextMarshalerOnly  = errors.New("only encoding.TextMarshaler is supported")
+)
+
 // Got from https://github.com/xdg-go/strum/blob/main/types.go
 var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
@@ -24,7 +31,7 @@ func maybeInstantiatePtr(rv reflect.Value) {
 }

 func decodingError(name string, err error) error {
-	return fmt.Errorf("error decoding to %s: %w", name, err)
+	return fmt.Errorf("decoding to %s: %w", name, err)
 }

 // TextSerialiser implements the Serialiser interface for fields that
@@ -42,22 +49,26 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
 	if dbValue != nil {
 		var bytes []byte

 		switch v := dbValue.(type) {
 		case []byte:
 			bytes = v
 		case string:
 			bytes = []byte(v)
 		default:
-			return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
+			return fmt.Errorf("%w: %#v", errUnmarshalTextValue, dbValue)
 		}

 		if isTextUnmarshaler(fieldValue) {
 			maybeInstantiatePtr(fieldValue)
 			f := fieldValue.MethodByName("UnmarshalText")
 			args := []reflect.Value{reflect.ValueOf(bytes)}

 			ret := f.Call(args)
 			if !ret[0].IsNil() {
-				return decodingError(field.Name, ret[0].Interface().(error))
+				if err, ok := ret[0].Interface().(error); ok {
+					return decodingError(field.Name, err)
+				}
 			}

 			// If the underlying field is to a pointer type, we need to
@@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
 			return nil
 		} else {
-			return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
+			return fmt.Errorf("%w: %T", errUnsupportedType, fieldValue.Interface())
 		}
 	}
@@ -87,8 +98,9 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
 		// always comparable, particularly when reflection is involved:
 		// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
 		if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
-			return nil, nil
+			return nil, nil //nolint:nilnil // intentional: nil value for GORM serializer
 		}

 		b, err := v.MarshalText()
 		if err != nil {
 			return nil, err
@@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
 		return string(b), nil
 	default:
-		return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
+		return nil, fmt.Errorf("%w, got %T", errTextMarshalerOnly, v)
 	}
 }

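For context on how a serializer like TextSerialiser is typically wired up: GORM serializers are registered under a name and then opted into per field with a struct tag. The registration name "text" and the Route model below are illustrative, not taken from this diff:

package db

import (
	"net/netip"

	"gorm.io/gorm/schema"
)

func init() {
	// Register the serializer once at package init.
	schema.RegisterSerializer("text", TextSerialiser{})
}

// netip.Prefix implements encoding.TextMarshaler/TextUnmarshaler, so a
// tagged field round-trips through a plain text column.
type Route struct {
	ID     uint
	Prefix netip.Prefix `gorm:"serializer:text"`
}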
View File

@@ -12,9 +12,11 @@ import (
 )

 var (
 	ErrUserExists            = errors.New("user already exists")
 	ErrUserNotFound          = errors.New("user not found")
 	ErrUserStillHasNodes     = errors.New("user not empty: node(s) found")
+	ErrUserWhereInvalidCount = errors.New("expect 0 or 1 where User structs")
+	ErrUserNotUnique         = errors.New("expected exactly one user")
 )

 func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
@@ -26,10 +28,13 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
 // CreateUser creates a new User. Returns error if could not be created
 // or another user already exists.
 func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
-	if err := util.ValidateHostname(user.Name); err != nil {
+	err := util.ValidateHostname(user.Name)
+	if err != nil {
 		return nil, err
 	}

-	if err := tx.Create(&user).Error; err != nil {
+	err = tx.Create(&user).Error
+	if err != nil {
 		return nil, fmt.Errorf("creating user: %w", err)
 	}

@@ -54,6 +59,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
 	if err != nil {
 		return err
 	}
+
 	if len(nodes) > 0 {
 		return ErrUserStillHasNodes
 	}
@@ -62,6 +68,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
 	if err != nil {
 		return err
 	}
+
 	for _, key := range keys {
 		err = DestroyPreAuthKey(tx, key.ID)
 		if err != nil {
@@ -88,11 +95,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
 // not exist or if another User exists with the new name.
 func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
 	var err error
+
 	oldUser, err := GetUserByID(tx, uid)
 	if err != nil {
 		return err
 	}
+
-	if err = util.ValidateHostname(newName); err != nil {
+	if err = util.ValidateHostname(newName); err != nil { //nolint:noinlineerr
 		return err
 	}

@@ -151,7 +160,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
 // ListUsers gets all the existing users.
 func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
 	if len(where) > 1 {
-		return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
+		return nil, fmt.Errorf("%w, got %d", ErrUserWhereInvalidCount, len(where))
 	}

 	var user *types.User
@@ -160,7 +169,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
 	}

 	users := []types.User{}
-	if err := tx.Where(user).Find(&users).Error; err != nil {
+
+	err := tx.Where(user).Find(&users).Error
+	if err != nil {
 		return nil, err
 	}

@@ -180,7 +191,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
 	}

 	if len(users) != 1 {
-		return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
+		return nil, fmt.Errorf("%w, found %d", ErrUserNotUnique, len(users))
 	}

 	return &users[0], nil
View File
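The error changes above convert one-off fmt.Errorf strings into wrapped sentinel errors: the message text stays the same, but callers gain an identity to branch on. A small sketch:

package main

import (
	"errors"
	"fmt"
)

var ErrUserNotUnique = errors.New("expected exactly one user")

func getUser(found int) error {
	if found != 1 {
		// %w wraps the sentinel, so the message reads as before...
		return fmt.Errorf("%w, found %d", ErrUserNotUnique, found)
	}

	return nil
}

func main() {
	err := getUser(2)

	// ...while callers can now match on the error identity.
	fmt.Println(errors.Is(err, ErrUserNotUnique)) // true
}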

@@ -8,7 +8,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"gorm.io/gorm"
-	"tailscale.com/types/ptr"
 )

 func TestCreateAndDestroyUser(t *testing.T) {
@@ -74,12 +73,14 @@ func TestDestroyUserErrors(t *testing.T) {
 	pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
 	require.NoError(t, err)

+	pakID := pak.ID
+
 	node := types.Node{
 		ID:             0,
 		Hostname:       "testnode",
 		UserID:         &user.ID,
 		RegisterMethod: util.RegisterMethodAuthKey,
-		AuthKeyID:      ptr.To(pak.ID),
+		AuthKeyID:      &pakID,
 	}
 	trx := db.DB.Save(&node)
 	require.NoError(t, trx.Error)

View File

@@ -0,0 +1,251 @@
package db
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
)
var errVersionUpgrade = errors.New("version upgrade not supported")
var errVersionDowngrade = errors.New("version downgrade not supported")
var errVersionMajorChange = errors.New("major version change not supported")
var errVersionParse = errors.New("cannot parse version")
var errVersionFormat = errors.New(
"version does not follow semver major.minor.patch format",
)
// DatabaseVersion tracks the headscale version that last
// successfully started against this database.
// It is a single-row table (ID is always 1).
type DatabaseVersion struct {
ID uint `gorm:"primaryKey"`
Version string `gorm:"not null"`
UpdatedAt time.Time
}
// semver holds parsed major.minor.patch components.
type semver struct {
Major int
Minor int
Patch int
}
func (s semver) String() string {
return fmt.Sprintf("v%d.%d.%d", s.Major, s.Minor, s.Patch)
}
// parseVersion parses a version string like "v0.25.0", "0.25.1",
// "v0.25.0-beta.1", or "v0.25.0-rc1+build123" into its major, minor,
// patch components. Pre-release and build metadata suffixes are stripped.
func parseVersion(s string) (semver, error) {
if s == "" || s == "dev" {
return semver{}, fmt.Errorf("%q: %w", s, errVersionParse)
}
v := strings.TrimPrefix(s, "v")
// Strip pre-release suffix (everything after first '-')
// and build metadata (everything after first '+').
if idx := strings.IndexAny(v, "-+"); idx != -1 {
v = v[:idx]
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
return semver{}, fmt.Errorf("%q: %w", s, errVersionFormat)
}
major, err := strconv.Atoi(parts[0])
if err != nil {
return semver{}, fmt.Errorf("invalid major version in %q: %w", s, err)
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return semver{}, fmt.Errorf("invalid minor version in %q: %w", s, err)
}
patch, err := strconv.Atoi(parts[2])
if err != nil {
return semver{}, fmt.Errorf("invalid patch version in %q: %w", s, err)
}
return semver{Major: major, Minor: minor, Patch: patch}, nil
}
// ensureDatabaseVersionTable creates the database_versions table if it
// does not already exist. Uses GORM AutoMigrate to handle dialect
// differences between SQLite (datetime) and PostgreSQL (timestamp).
// This runs before gormigrate migrations.
func ensureDatabaseVersionTable(db *gorm.DB) error {
err := db.AutoMigrate(&DatabaseVersion{})
if err != nil {
return fmt.Errorf("creating database version table: %w", err)
}
return nil
}
// getDatabaseVersion reads the stored version from the database.
// Returns an empty string if no version has been stored yet.
func getDatabaseVersion(db *gorm.DB) (string, error) {
var version string
result := db.Raw("SELECT version FROM database_versions WHERE id = 1").Scan(&version)
if result.Error != nil {
return "", fmt.Errorf("reading database version: %w", result.Error)
}
if result.RowsAffected == 0 {
return "", nil
}
return version, nil
}
// setDatabaseVersion upserts the version row in the database.
func setDatabaseVersion(db *gorm.DB, version string) error {
now := time.Now().UTC()
// Try update first, then insert if no rows affected.
result := db.Exec(
"UPDATE database_versions SET version = ?, updated_at = ? WHERE id = 1",
version, now,
)
if result.Error != nil {
return fmt.Errorf("updating database version: %w", result.Error)
}
if result.RowsAffected == 0 {
err := db.Exec(
"INSERT INTO database_versions (id, version, updated_at) VALUES (1, ?, ?)",
version, now,
).Error
if err != nil {
return fmt.Errorf("inserting database version: %w", err)
}
}
return nil
}
// isDev reports whether a version string represents a development build
// that should skip version checking.
func isDev(version string) bool {
return version == "" || version == "dev" || version == "(devel)"
}
// checkVersionUpgradePath verifies that the running headscale version
// is compatible with the version that last used this database.
//
// Rules:
// - If the running binary has no version ("dev" or empty), warn and skip.
// - If no version is stored in the database, allow (first run with this feature).
// - If the stored version is "dev", allow (previous run was unversioned).
// - Same minor version: always allowed (patch changes in either direction).
// - Single minor version upgrade (stored.minor+1 == current.minor): allowed.
// - Multi-minor upgrade or any minor downgrade: blocked with a fatal error.
func checkVersionUpgradePath(db *gorm.DB) error {
err := ensureDatabaseVersionTable(db)
if err != nil {
return err
}
currentVersion := types.GetVersionInfo().Version
// Running binary has no real version — skip the check but
// preserve whatever version is already stored.
if isDev(currentVersion) {
storedVersion, err := getDatabaseVersion(db)
if err != nil {
return err
}
if storedVersion != "" && !isDev(storedVersion) {
log.Warn().
Str("database_version", storedVersion).
Msg("running a development build of headscale without a version number, " +
"database version check is skipped, the stored database version is preserved")
}
return nil
}
storedVersion, err := getDatabaseVersion(db)
if err != nil {
return err
}
// No stored version — first run with this feature. Allow startup;
// the version will be stored after migrations succeed.
if storedVersion == "" {
return nil
}
// Previous run was an unversioned build — no meaningful comparison.
if isDev(storedVersion) {
return nil
}
current, err := parseVersion(currentVersion)
if err != nil {
return fmt.Errorf("parsing current version: %w", err)
}
stored, err := parseVersion(storedVersion)
if err != nil {
return fmt.Errorf("parsing stored database version: %w", err)
}
if current.Major != stored.Major {
return fmt.Errorf(
"headscale version %s cannot be used with a database last used by %s: %w",
currentVersion, storedVersion, errVersionMajorChange,
)
}
minorDiff := current.Minor - stored.Minor
switch {
case minorDiff == 0:
// Same minor version — patch changes are always fine.
return nil
case minorDiff == 1:
// Single minor version upgrade — allowed.
return nil
case minorDiff > 1:
// Multi-minor upgrade — blocked.
return fmt.Errorf(
"headscale version %s cannot be used with a database last used by %s, "+
"upgrading more than one minor version at a time is not supported, "+
"please upgrade to the latest v%d.%d.x release first, then to %s, "+
"release page: https://github.com/juanfont/headscale/releases: %w",
currentVersion, storedVersion,
stored.Major, stored.Minor+1,
current.String(),
errVersionUpgrade,
)
default:
// minorDiff < 0 — any minor downgrade is blocked.
return fmt.Errorf(
"headscale version %s cannot be used with a database last used by %s, "+
"downgrading to a previous minor version is not supported, "+
"release page: https://github.com/juanfont/headscale/releases: %w",
currentVersion, storedVersion,
errVersionDowngrade,
)
}
}

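To make the upgrade-path rules concrete, a short walk-through using the parseVersion and semver helpers defined above (the version pairs are illustrative):

// stored v0.27.2, current v0.27.5 -> allowed: same minor, patch change
// stored v0.27.3, current v0.28.1 -> allowed: exactly one minor ahead
// stored v0.25.0, current v0.28.0 -> blocked: step through v0.26.x, then v0.27.x
// stored v0.28.0, current v0.27.0 -> blocked: minor downgrade
// stored v0.28.0, current v1.0.0  -> blocked: major version change
func exampleUpgradePath() {
	stored, _ := parseVersion("v0.25.0")
	current, _ := parseVersion("v0.28.0")

	// minorDiff is what drives the switch in checkVersionUpgradePath.
	fmt.Println(current.Minor - stored.Minor) // 3: more than one minor, blocked
}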
View File

@@ -0,0 +1,318 @@
package db
import (
"fmt"
"testing"
"github.com/glebarez/sqlite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func TestParseVersion(t *testing.T) {
tests := []struct {
input string
want semver
wantErr bool
}{
{input: "v0.25.0", want: semver{0, 25, 0}},
{input: "0.25.0", want: semver{0, 25, 0}},
{input: "v0.25.1", want: semver{0, 25, 1}},
{input: "v1.0.0", want: semver{1, 0, 0}},
{input: "v0.28.3", want: semver{0, 28, 3}},
// Pre-release suffixes stripped
{input: "v0.25.0-beta.1", want: semver{0, 25, 0}},
{input: "v0.25.0-rc1", want: semver{0, 25, 0}},
// Build metadata stripped
{input: "v0.25.0+build123", want: semver{0, 25, 0}},
{input: "v0.25.0-beta.1+build123", want: semver{0, 25, 0}},
// Invalid inputs
{input: "", wantErr: true},
{input: "dev", wantErr: true},
{input: "vfoo.bar.baz", wantErr: true},
{input: "v1.2", wantErr: true},
{input: "v1", wantErr: true},
{input: "not-a-version", wantErr: true},
{input: "v1.2.3.4", wantErr: true},
{input: "(devel)", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got, err := parseVersion(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
func TestSemverString(t *testing.T) {
s := semver{0, 28, 3}
assert.Equal(t, "v0.28.3", s.String())
}
func TestIsDev(t *testing.T) {
assert.True(t, isDev(""))
assert.True(t, isDev("dev"))
assert.True(t, isDev("(devel)"))
assert.False(t, isDev("v0.28.0"))
assert.False(t, isDev("0.28.0"))
}
// versionTestDB creates an in-memory SQLite database with the
// database_versions table already bootstrapped.
func versionTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
require.NoError(t, err)
err = ensureDatabaseVersionTable(db)
require.NoError(t, err)
return db
}
func TestSetAndGetDatabaseVersion(t *testing.T) {
db := versionTestDB(t)
// Initially empty
v, err := getDatabaseVersion(db)
require.NoError(t, err)
assert.Empty(t, v)
// Set a version
err = setDatabaseVersion(db, "v0.27.0")
require.NoError(t, err)
v, err = getDatabaseVersion(db)
require.NoError(t, err)
assert.Equal(t, "v0.27.0", v)
// Update the version (upsert)
err = setDatabaseVersion(db, "v0.28.0")
require.NoError(t, err)
v, err = getDatabaseVersion(db)
require.NoError(t, err)
assert.Equal(t, "v0.28.0", v)
}
func TestEnsureDatabaseVersionTableIdempotent(t *testing.T) {
db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
require.NoError(t, err)
// Call twice — should not error
err = ensureDatabaseVersionTable(db)
require.NoError(t, err)
err = ensureDatabaseVersionTable(db)
require.NoError(t, err)
}
// TestCheckVersionUpgradePathDirect tests the version comparison logic
// by directly seeding the database, bypassing types.GetVersionInfo()
// (which returns "dev" in test environments and cannot be overridden).
func TestCheckVersionUpgradePathDirect(t *testing.T) {
tests := []struct {
name string
storedVersion string // empty means no row stored
currentVersion string
wantErr bool
errContains string
}{
// Fresh database (no stored version)
{
name: "fresh db allows any version",
storedVersion: "",
currentVersion: "v0.28.0",
},
// Stored is dev
{
name: "real version over dev db",
storedVersion: "dev",
currentVersion: "v0.28.0",
},
{
name: "devel version in db",
storedVersion: "(devel)",
currentVersion: "v0.28.0",
},
// Same version
{
name: "same version",
storedVersion: "v0.27.0",
currentVersion: "v0.27.0",
},
// Patch changes within same minor
{
name: "patch upgrade",
storedVersion: "v0.27.0",
currentVersion: "v0.27.3",
},
{
name: "patch downgrade within same minor",
storedVersion: "v0.27.3",
currentVersion: "v0.27.0",
},
// Single minor upgrade
{
name: "single minor upgrade",
storedVersion: "v0.27.0",
currentVersion: "v0.28.0",
},
{
name: "single minor upgrade with different patches",
storedVersion: "v0.27.3",
currentVersion: "v0.28.1",
},
// Multi-minor upgrade (blocked)
{
name: "two minor versions ahead",
storedVersion: "v0.25.0",
currentVersion: "v0.27.0",
wantErr: true,
errContains: "latest v0.26.x",
},
{
name: "three minor versions ahead",
storedVersion: "v0.25.0",
currentVersion: "v0.28.0",
wantErr: true,
errContains: "latest v0.26.x",
},
// Minor downgrades (blocked)
{
name: "single minor downgrade",
storedVersion: "v0.28.0",
currentVersion: "v0.27.0",
wantErr: true,
errContains: "downgrading",
},
{
name: "multi minor downgrade",
storedVersion: "v0.28.0",
currentVersion: "v0.25.0",
wantErr: true,
errContains: "downgrading",
},
// Major version mismatch
{
name: "major version upgrade",
storedVersion: "v0.28.0",
currentVersion: "v1.0.0",
wantErr: true,
errContains: "major version",
},
{
name: "major version downgrade",
storedVersion: "v1.0.0",
currentVersion: "v0.28.0",
wantErr: true,
errContains: "major version",
},
// Pre-release versions
{
name: "pre-release single minor upgrade",
storedVersion: "v0.27.0",
currentVersion: "v0.28.0-beta.1",
},
{
name: "pre-release multi minor upgrade blocked",
storedVersion: "v0.25.0",
currentVersion: "v0.27.0-rc1",
wantErr: true,
errContains: "latest v0.26.x",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := versionTestDB(t)
// Seed the stored version if provided
if tt.storedVersion != "" {
err := setDatabaseVersion(db, tt.storedVersion)
require.NoError(t, err)
}
err := checkVersionUpgradePathFromVersions(db, tt.currentVersion)
if tt.wantErr {
require.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
// checkVersionUpgradePathFromVersions is a test helper that runs the
// version comparison logic with a specific currentVersion string,
// bypassing types.GetVersionInfo(). It replicates the logic from
// checkVersionUpgradePath but accepts the version as a parameter.
func checkVersionUpgradePathFromVersions(db *gorm.DB, currentVersion string) error {
if isDev(currentVersion) {
return nil
}
storedVersion, err := getDatabaseVersion(db)
if err != nil {
return err
}
if storedVersion == "" {
return nil
}
if isDev(storedVersion) {
return nil
}
current, err := parseVersion(currentVersion)
if err != nil {
return err
}
stored, err := parseVersion(storedVersion)
if err != nil {
return err
}
if current.Major != stored.Major {
return errVersionMajorChange
}
minorDiff := current.Minor - stored.Minor
switch {
case minorDiff == 0:
return nil
case minorDiff == 1:
return nil
case minorDiff > 1:
return fmt.Errorf(
"please upgrade to the latest v%d.%d.x release first: %w",
stored.Major, stored.Minor+1,
errVersionUpgrade,
)
default:
return fmt.Errorf("downgrading: %w", errVersionDowngrade)
}
}

View File

@@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 if wantsJSON {
 overview := h.state.DebugOverviewJSON()
 overviewJSON, err := json.MarshalIndent(overview, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(overviewJSON)
+_, _ = w.Write(overviewJSON)
 } else {
 // Default to text/plain for backward compatibility
 overview := h.state.DebugOverview()
 w.Header().Set("Content-Type", "text/plain")
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(overview))
+_, _ = w.Write([]byte(overview))
 }
 }))
 // Configuration endpoint
 debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 config := h.state.DebugConfig()
 configJSON, err := json.MarshalIndent(config, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(configJSON)
+_, _ = w.Write(configJSON)
 }))
 // Policy endpoint
@@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 } else {
 w.Header().Set("Content-Type", "text/plain")
 }
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(policy))
+_, _ = w.Write([]byte(policy))
 }))
 // Filter rules endpoint
@@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 httpError(w, err)
 return
 }
 filterJSON, err := json.MarshalIndent(filter, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(filterJSON)
+_, _ = w.Write(filterJSON)
 }))
 // SSH policies endpoint
 debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 sshPolicies := h.state.DebugSSHPolicies()
 sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(sshJSON)
+_, _ = w.Write(sshJSON)
 }))
 // DERP map endpoint
@@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 if wantsJSON {
 derpInfo := h.state.DebugDERPJSON()
 derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(derpJSON)
+_, _ = w.Write(derpJSON)
 } else {
 // Default to text/plain for backward compatibility
 derpInfo := h.state.DebugDERPMap()
 w.Header().Set("Content-Type", "text/plain")
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(derpInfo))
+_, _ = w.Write([]byte(derpInfo))
 }
 }))
@@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 if wantsJSON {
 nodeStoreNodes := h.state.DebugNodeStoreJSON()
 nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(nodeStoreJSON)
+_, _ = w.Write(nodeStoreJSON)
 } else {
 // Default to text/plain for backward compatibility
 nodeStoreInfo := h.state.DebugNodeStore()
 w.Header().Set("Content-Type", "text/plain")
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(nodeStoreInfo))
+_, _ = w.Write([]byte(nodeStoreInfo))
 }
 }))
 // Registration cache endpoint
 debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 cacheInfo := h.state.DebugRegistrationCache()
 cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(cacheJSON)
+_, _ = w.Write(cacheJSON)
 }))
 // Routes endpoint
@@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 if wantsJSON {
 routes := h.state.DebugRoutes()
 routesJSON, err := json.MarshalIndent(routes, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(routesJSON)
+_, _ = w.Write(routesJSON)
 } else {
 // Default to text/plain for backward compatibility
 routes := h.state.DebugRoutesString()
 w.Header().Set("Content-Type", "text/plain")
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(routes))
+_, _ = w.Write([]byte(routes))
 }
 }))
@@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 if wantsJSON {
 policyManagerInfo := h.state.DebugPolicyManagerJSON()
 policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
 if err != nil {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(policyManagerJSON)
+_, _ = w.Write(policyManagerJSON)
 } else {
 // Default to text/plain for backward compatibility
 policyManagerInfo := h.state.DebugPolicyManager()
 w.Header().Set("Content-Type", "text/plain")
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(policyManagerInfo))
+_, _ = w.Write([]byte(policyManagerInfo))
 }
 }))
@@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 if res == nil {
 w.WriteHeader(http.StatusOK)
-w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
+_, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
 return
 }
@@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 httpError(w, err)
 return
 }
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(resJSON)
+_, _ = w.Write(resJSON)
 }))
 // Batcher endpoint
@@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 w.Header().Set("Content-Type", "application/json")
 w.WriteHeader(http.StatusOK)
-w.Write(batcherJSON)
+_, _ = w.Write(batcherJSON)
 } else {
 // Default to text/plain for backward compatibility
 batcherInfo := h.debugBatcher()
 w.Header().Set("Content-Type", "text/plain")
 w.WriteHeader(http.StatusOK)
-w.Write([]byte(batcherInfo))
+_, _ = w.Write([]byte(batcherInfo))
 }
 }))
@@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
 activeConnections: info.ActiveConnections,
 })
 totalNodes++
+
 if info.Connected {
 connectedCount++
 }
@@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
 activeConnections: 0,
 })
 totalNodes++
+
 if connected {
 connectedCount++
 }
+
 return true
 })
 }
@@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
 ActiveConnections: 0,
 }
 info.TotalNodes++
+
 return true
 })
 }
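The recurring `w.Write(...)` to `_, _ = w.Write(...)` change above silences the unchecked-error lint on http.ResponseWriter.Write. A minimal sketch of the pattern; the helper is illustrative and not part of the diff:

package main

import "net/http"

// writeJSON shows the pattern applied to every debug endpoint above: the
// two return values of Write are explicitly discarded rather than silently
// dropped, which satisfies errcheck-style linters.
func writeJSON(w http.ResponseWriter, body []byte) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(body) // best effort; the client may already be gone
}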

View File

@@ -28,11 +28,14 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
 return nil, err
 }
 defer derpFile.Close()
+
 var derpMap tailcfg.DERPMap
+
 b, err := io.ReadAll(derpFile)
 if err != nil {
 return nil, err
 }
+
 err = yaml.Unmarshal(b, &derpMap)
 return &derpMap, err
@@ -57,12 +60,14 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
 }
 defer resp.Body.Close()
+
 body, err := io.ReadAll(resp.Body)
 if err != nil {
 return nil, err
 }
+
 var derpMap tailcfg.DERPMap
 err = json.Unmarshal(body, &derpMap)
 return &derpMap, err
@@ -134,6 +139,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
 for id := range dm.Regions {
 ids = append(ids, id)
 }
+
 slices.Sort(ids)
 for _, id := range ids {
@@ -160,16 +166,18 @@ func derpRandom() *rand.Rand {
 derpRandomOnce.Do(func() {
 seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
-rnd := rand.New(rand.NewSource(0))
-rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
+rnd := rand.New(rand.NewSource(0)) //nolint:gosec // weak random is fine for DERP scrambling
+rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) //nolint:gosec // safe conversion
 derpRandomInst = rnd
 })
+
 return derpRandomInst
 }
+
 func resetDerpRandomForTesting() {
 derpRandomMu.Lock()
 defer derpRandomMu.Unlock()
 derpRandomOnce = sync.Once{}
 derpRandomInst = nil
 }

View File

@@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
 for _, tt := range tests {
 t.Run(tt.name, func(t *testing.T) {
 viper.Set("dns.base_domain", tt.baseDomain)
+
 defer viper.Reset()
+
 resetDerpRandomForTesting()
 testMap := tt.derpMap.View().AsStruct()

View File

@@ -54,7 +54,7 @@ func NewDERPServer(
 derpKey key.NodePrivate,
 cfg *types.DERPConfig,
 ) (*DERPServer, error) {
-log.Trace().Caller().Msg("Creating new embedded DERP server")
+log.Trace().Caller().Msg("creating new embedded DERP server")
 server := derpserver.New(derpKey, util.TSLogfWrapper()) // nolint // zerolinter complains
 if cfg.ServerVerifyClients {
@@ -75,9 +75,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
 if err != nil {
 return tailcfg.DERPRegion{}, err
 }
-var host string
-var port int
-var portStr string
+var (
+host string
+port int
+portStr string
+)
+
 // Extract hostname and port from URL
 host, portStr, err = net.SplitHostPort(serverURL.Host)
@@ -98,13 +101,13 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
 // If debug flag is set, resolve hostname to IP address
 if debugUseDERPIP {
-ips, err := net.LookupIP(host)
+ips, err := new(net.Resolver).LookupIPAddr(context.Background(), host)
 if err != nil {
-log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host)
+log.Error().Caller().Err(err).Msgf("failed to resolve DERP hostname %s to IP, using hostname", host)
 } else if len(ips) > 0 {
 // Use the first IP address
-ipStr := ips[0].String()
-log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr)
+ipStr := ips[0].IP.String()
+log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: resolved %s to %s", host, ipStr)
 host = ipStr
 }
 }
@@ -130,14 +133,16 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
 if err != nil {
 return tailcfg.DERPRegion{}, err
 }
+
 portSTUN, err := strconv.Atoi(portSTUNStr)
 if err != nil {
 return tailcfg.DERPRegion{}, err
 }
+
 localDERPregion.Nodes[0].STUNPort = portSTUN
-log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
-log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0])
+log.Info().Caller().Msgf("derp region: %+v", localDERPregion)
+log.Info().Caller().Msgf("derp nodes[0]: %+v", localDERPregion.Nodes[0])
 return localDERPregion, nil
 }
@@ -155,8 +160,10 @@ func (d *DERPServer) DERPHandler(
 Caller().
 Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
 }
+
 writer.Header().Set("Content-Type", "text/plain")
 writer.WriteHeader(http.StatusUpgradeRequired)
+
 _, err := writer.Write([]byte("DERP requires connection upgrade"))
 if err != nil {
 log.Error().
@@ -206,6 +213,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
 return
 }
 defer websocketConn.Close(websocket.StatusInternalError, "closing")
+
 if websocketConn.Subprotocol() != "derp" {
 websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
@@ -222,9 +230,10 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
 hijacker, ok := writer.(http.Hijacker)
 if !ok {
-log.Error().Caller().Msg("DERP requires Hijacker interface from Gin")
+log.Error().Caller().Msg("derp requires Hijacker interface from Gin")
 writer.Header().Set("Content-Type", "text/plain")
 writer.WriteHeader(http.StatusInternalServerError)
+
 _, err := writer.Write([]byte("HTTP does not support general TCP support"))
 if err != nil {
 log.Error().
@@ -238,9 +247,10 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
 netConn, conn, err := hijacker.Hijack()
 if err != nil {
-log.Error().Caller().Err(err).Msgf("Hijack failed")
+log.Error().Caller().Err(err).Msgf("hijack failed")
 writer.Header().Set("Content-Type", "text/plain")
 writer.WriteHeader(http.StatusInternalServerError)
+
 _, err = writer.Write([]byte("HTTP does not support general TCP support"))
 if err != nil {
 log.Error().
@@ -251,7 +261,8 @@ func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) {
 return
 }
-log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
+log.Trace().Caller().Msgf("hijacked connection from %v", req.RemoteAddr)
+
 if !fastStart {
 pubKey := d.key.Public()
@@ -280,6 +291,7 @@ func DERPProbeHandler(
 writer.WriteHeader(http.StatusOK)
 default:
 writer.WriteHeader(http.StatusMethodNotAllowed)
+
 _, err := writer.Write([]byte("bogus probe method"))
 if err != nil {
 log.Error().
@@ -309,9 +321,11 @@ func DERPBootstrapDNSHandler(
 resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
 defer cancel()
+
 var resolver net.Resolver
+
-for _, region := range derpMap.Regions().All() {
-for _, node := range region.Nodes().All() { // we don't care if we override some nodes
+for _, region := range derpMap.Regions().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
+for _, node := range region.Nodes().All() { //nolint:unqueryvet // not SQLBoiler, tailcfg iterator
 addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
 if err != nil {
 log.Trace().
@@ -321,11 +335,14 @@ func DERPBootstrapDNSHandler(
 continue
 }
+
 dnsEntries[node.HostName()] = addrs
 }
 }
+
 writer.Header().Set("Content-Type", "application/json")
 writer.WriteHeader(http.StatusOK)
+
 err := json.NewEncoder(writer).Encode(dnsEntries)
 if err != nil {
 log.Error().
@@ -338,33 +355,37 @@
 // ServeSTUN starts a STUN server on the configured addr.
 func (d *DERPServer) ServeSTUN() {
-packetConn, err := net.ListenPacket("udp", d.cfg.STUNAddr)
+packetConn, err := new(net.ListenConfig).ListenPacket(context.Background(), "udp", d.cfg.STUNAddr)
 if err != nil {
 log.Fatal().Msgf("failed to open STUN listener: %v", err)
 }
-log.Info().Msgf("STUN server started at %s", packetConn.LocalAddr())
+
+log.Info().Msgf("stun server started at %s", packetConn.LocalAddr())
+
 udpConn, ok := packetConn.(*net.UDPConn)
 if !ok {
-log.Fatal().Msg("STUN listener is not a UDP listener")
+log.Fatal().Msg("stun listener is not a UDP listener")
 }
+
 serverSTUNListener(context.Background(), udpConn)
 }
 func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
-var buf [64 << 10]byte
 var (
+buf [64 << 10]byte
 bytesRead int
 udpAddr *net.UDPAddr
 err error
 )
 for {
 bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:])
 if err != nil {
 if ctx.Err() != nil {
 return
 }
-log.Error().Caller().Err(err).Msgf("STUN ReadFrom")
+
+log.Error().Caller().Err(err).Msgf("stun ReadFrom")
 // Rate limit error logging - wait before retrying, but respect context cancellation
 select {
@@ -375,25 +396,29 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
 continue
 }
-log.Trace().Caller().Msgf("STUN request from %v", udpAddr)
+
+log.Trace().Caller().Msgf("stun request from %v", udpAddr)
+
 pkt := buf[:bytesRead]
 if !stun.Is(pkt) {
-log.Trace().Caller().Msgf("UDP packet is not STUN")
+log.Trace().Caller().Msgf("udp packet is not stun")
 continue
 }
+
 txid, err := stun.ParseBindingRequest(pkt)
 if err != nil {
-log.Trace().Caller().Err(err).Msgf("STUN parse error")
+log.Trace().Caller().Err(err).Msgf("stun parse error")
 continue
 }
+
 addr, _ := netip.AddrFromSlice(udpAddr.IP)
-res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port)))
+res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port))) //nolint:gosec // port is always <=65535
 _, err = packetConn.WriteTo(res, udpAddr)
 if err != nil {
-log.Trace().Caller().Err(err).Msgf("Issue writing to UDP")
+log.Trace().Caller().Err(err).Msgf("issue writing to UDP")
 continue
 }
@@ -412,8 +437,10 @@ type DERPVerifyTransport struct {
 func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 buf := new(bytes.Buffer)
+
-if err := t.handleVerifyRequest(req, buf); err != nil {
-log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ")
+err := t.handleVerifyRequest(req, buf)
+if err != nil {
+log.Error().Caller().Err(err).Msg("failed to handle client verify request")
 return nil, err
 }
View File

@@ -4,6 +4,7 @@ import (
 "context"
 "crypto/sha256"
 "encoding/json"
+"errors"
 "fmt"
 "os"
 "sync"
@@ -15,6 +16,9 @@ import (
 "tailscale.com/util/set"
 )
+// ErrPathIsDirectory is returned when a directory path is provided where a file is expected.
+var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")
+
 type ExtraRecordsMan struct {
 mu sync.RWMutex
 records set.Set[tailcfg.DNSRecord]
@@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
 }
 if fi.IsDir() {
-return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
+return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
 }
 records, hash, err := readExtraRecordsFromPath(path)
@@ -85,19 +89,22 @@ func (e *ExtraRecordsMan) Run() {
 log.Error().Caller().Msgf("file watcher event channel closing")
 return
 }
+
 switch event.Op {
 case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
 log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
+
 if event.Name != e.path {
 continue
 }
+
 e.updateRecords()
 // If a file is removed or renamed, fsnotify will loose track of it
 // and not watch it. We will therefore attempt to re-add it with a backoff.
 case fsnotify.Remove, fsnotify.Rename:
 _, err := backoff.Retry(context.Background(), func() (struct{}, error) {
-if _, err := os.Stat(e.path); err != nil {
+if _, err := os.Stat(e.path); err != nil { //nolint:noinlineerr
 return struct{}{}, err
 }
@@ -123,6 +130,7 @@ func (e *ExtraRecordsMan) Run() {
 log.Error().Caller().Msgf("file watcher error channel closing")
 return
 }
+
 log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
 }
 }
@@ -165,6 +173,7 @@ func (e *ExtraRecordsMan) updateRecords() {
 e.hashes[e.path] = newHash
 log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
+
 e.updateCh <- e.records.Slice()
 }
@@ -183,6 +192,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
 }
 var records []tailcfg.DNSRecord
+
 err = json.Unmarshal(b, &records)
 if err != nil {
 return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
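The ErrPathIsDirectory change above is the standard sentinel-plus-%w pattern: the static error is declared once and wrapped with the dynamic detail, so callers can match the failure class with errors.Is. A short sketch; the caller here is hypothetical:

package main

import (
	"errors"
	"fmt"
)

var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")

func checkPath(path string) error {
	// Wrapping with %w preserves the sentinel for errors.Is.
	return fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
}

func main() {
	err := checkPath("/etc/headscale")
	fmt.Println(errors.Is(err, ErrPathIsDirectory)) // true
}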

View File

@@ -29,6 +29,7 @@ import (
 "github.com/juanfont/headscale/hscontrol/state"
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/hscontrol/util"
+"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
 )
 type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
@@ -54,7 +55,7 @@ func (api headscaleV1APIServer) CreateUser(
 }
 user, policyChanged, err := api.h.state.CreateUser(newUser)
 if err != nil {
-return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
+return nil, status.Errorf(codes.Internal, "creating user: %s", err)
 }
 // CreateUser returns a policy change response if the user creation affected policy.
@@ -235,16 +236,16 @@ func (api headscaleV1APIServer) RegisterNode(
 // Generate ephemeral registration key for tracking this registration flow in logs
 registrationKey, err := util.GenerateRegistrationKey()
 if err != nil {
-log.Warn().Err(err).Msg("Failed to generate registration key")
+log.Warn().Err(err).Msg("failed to generate registration key")
 registrationKey = "" // Continue without key if generation fails
 }
 log.Trace().
 Caller().
-Str("user", request.GetUser()).
-Str("registration_id", request.GetKey()).
-Str("registration_key", registrationKey).
-Msg("Registering node")
+Str(zf.UserName, request.GetUser()).
+Str(zf.RegistrationID, request.GetKey()).
+Str(zf.RegistrationKey, registrationKey).
+Msg("registering node")
 registrationId, err := types.RegistrationIDFromString(request.GetKey())
 if err != nil {
@@ -264,17 +265,16 @@ func (api headscaleV1APIServer) RegisterNode(
 )
 if err != nil {
 log.Error().
-Str("registration_key", registrationKey).
+Str(zf.RegistrationKey, registrationKey).
 Err(err).
-Msg("Failed to register node")
+Msg("failed to register node")
 return nil, err
 }
 log.Info().
-Str("registration_key", registrationKey).
-Str("node_id", fmt.Sprintf("%d", node.ID())).
-Str("hostname", node.Hostname()).
-Msg("Node registered successfully")
+Str(zf.RegistrationKey, registrationKey).
+EmbedObject(node).
+Msg("node registered successfully")
 // This is a bit of a back and forth, but we have a bit of a chicken and egg
 // dependency here.
@@ -355,9 +355,9 @@ func (api headscaleV1APIServer) SetTags(
 log.Trace().
 Caller().
-Str("node", node.Hostname()).
+EmbedObject(node).
 Strs("tags", request.GetTags()).
-Msg("Changing tags of node")
+Msg("changing tags of node")
 return &v1.SetTagsResponse{Node: node.Proto()}, nil
 }
@@ -368,7 +368,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 ) (*v1.SetApprovedRoutesResponse, error) {
 log.Debug().
 Caller().
-Uint64("node.id", request.GetNodeId()).
+Uint64(zf.NodeID, request.GetNodeId()).
 Strs("requestedRoutes", request.GetRoutes()).
 Msg("gRPC SetApprovedRoutes called")
@@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 newApproved = append(newApproved, prefix)
 }
 }
-tsaddr.SortPrefixes(newApproved)
+slices.SortFunc(newApproved, netip.Prefix.Compare)
 newApproved = slices.Compact(newApproved)
 node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
@@ -406,7 +406,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 log.Debug().
 Caller().
-Uint64("node.id", node.ID().Uint64()).
+EmbedObject(node).
 Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
 Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)).
 Strs("finalSubnetRoutes", proto.SubnetRoutes).
@@ -423,7 +423,7 @@ func validateTag(tag string) error {
 return errors.New("tag should be lowercase")
 }
 if len(strings.Fields(tag)) > 1 {
-return errors.New("tag should not contains space")
+return errors.New("tags must not contain spaces")
 }
 return nil
 }
@@ -466,8 +466,8 @@ func (api headscaleV1APIServer) ExpireNode(
 log.Trace().
 Caller().
-Str("node", node.Hostname()).
-Time("expiry", *node.AsStruct().Expiry).
+EmbedObject(node).
+Time(zf.ExpiresAt, *node.AsStruct().Expiry).
 Msg("node expired")
 return &v1.ExpireNodeResponse{Node: node.Proto()}, nil
@@ -487,8 +487,8 @@ func (api headscaleV1APIServer) RenameNode(
 log.Trace().
 Caller().
-Str("node", node.Hostname()).
-Str("new_name", request.GetNewName()).
+EmbedObject(node).
+Str(zf.NewName, request.GetNewName()).
 Msg("node renamed")
 return &v1.RenameNodeResponse{Node: node.Proto()}, nil
@@ -546,7 +546,7 @@ func (api headscaleV1APIServer) BackfillNodeIPs(
 ctx context.Context,
 request *v1.BackfillNodeIPsRequest,
 ) (*v1.BackfillNodeIPsResponse, error) {
-log.Trace().Caller().Msg("Backfill called")
+log.Trace().Caller().Msg("backfill called")
 if !request.Confirmed {
 return nil, errors.New("not confirmed, aborting")
@@ -817,13 +817,13 @@ func (api headscaleV1APIServer) Health(
 response := &v1.HealthResponse{}
 if err := api.h.state.PingDB(ctx); err != nil {
-healthErr = fmt.Errorf("database ping failed: %w", err)
+healthErr = fmt.Errorf("pinging database: %w", err)
 } else {
 response.DatabaseConnectivity = true
 }
 if healthErr != nil {
-log.Error().Err(healthErr).Msg("Health check failed")
+log.Error().Err(healthErr).Msg("health check failed")
 }
 return response, healthErr
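Several hunks above replace individual Str("node", ...) fields with EmbedObject(node). That works because the node view implements zerolog's LogObjectMarshaler, letting one call emit all standard node fields at once. A sketch with a stand-in type; the field names are illustrative, and headscale's own marshaller may emit different keys:

package main

import (
	"os"

	"github.com/rs/zerolog"
)

type node struct {
	id       uint64
	hostname string
}

// MarshalZerologObject makes node satisfy zerolog.LogObjectMarshaler.
func (n node) MarshalZerologObject(e *zerolog.Event) {
	e.Uint64("node.id", n.id).Str("node.name", n.hostname)
}

func main() {
	logger := zerolog.New(os.Stdout)
	logger.Info().EmbedObject(node{1, "host-a"}).Msg("node registered successfully")
}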

View File

@@ -17,6 +17,7 @@ func Test_validateTag(t *testing.T) {
 type args struct {
 tag string
 }
+
 tests := []struct {
 name string
 args args
@@ -45,7 +46,8 @@ func Test_validateTag(t *testing.T) {
 }
 for _, tt := range tests {
 t.Run(tt.name, func(t *testing.T) {
-if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
+err := validateTag(tt.args.tag)
+if (err != nil) != tt.wantErr {
 t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
 }
 })

View File

@@ -20,7 +20,7 @@ import (
 )
 const (
-// The CapabilityVersion is used by Tailscale clients to indicate
+// NoiseCapabilityVersion is used by Tailscale clients to indicate
 // their codebase version. Tailscale clients can communicate over TS2021
 // from CapabilityVersion 28, but we only have good support for it
 // since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
@@ -36,8 +36,7 @@ const (
 // httpError logs an error and sends an HTTP error response with the given.
 func httpError(w http.ResponseWriter, err error) {
-var herr HTTPError
-if errors.As(err, &herr) {
+if herr, ok := errors.AsType[HTTPError](err); ok {
 http.Error(w, herr.Msg, herr.Code)
 log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
 } else {
@@ -56,7 +55,7 @@ type HTTPError struct {
 func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
 func (e HTTPError) Unwrap() error { return e.Err }
-// Error returns an HTTPError containing the given information.
+// NewHTTPError returns an HTTPError containing the given information.
 func NewHTTPError(code int, msg string, err error) HTTPError {
 return HTTPError{Code: code, Msg: msg, Err: err}
 }
@@ -64,7 +63,7 @@ func NewHTTPError(code int, msg string, err error) HTTPError {
 var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil)
 var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
-"machines registered with CLI does not support expire",
+"machines registered with CLI do not support expiry",
 )
 func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) {
@@ -76,7 +75,7 @@ func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error
 clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
 if err != nil {
-return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err))
+return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("parsing capability version: %w", err))
 }
 return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
@@ -88,12 +87,12 @@ func (h *Headscale) handleVerifyRequest(
 ) error {
 body, err := io.ReadAll(req.Body)
 if err != nil {
-return fmt.Errorf("cannot read request body: %w", err)
+return fmt.Errorf("reading request body: %w", err)
 }
 var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
-if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
-return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
+if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { //nolint:noinlineerr
+return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("parsing DERP client request: %w", err))
 }
 nodes := h.state.ListNodes()
@@ -155,7 +154,11 @@ func (h *Headscale) KeyHandler(
 }
 writer.Header().Set("Content-Type", "application/json")
-json.NewEncoder(writer).Encode(resp)
+
+err := json.NewEncoder(writer).Encode(resp)
+if err != nil {
+log.Error().Err(err).Msg("failed to encode public key response")
+}
 return
 }
@@ -180,8 +183,12 @@ func (h *Headscale) HealthHandler(
 res.Status = "fail"
 }
-json.NewEncoder(writer).Encode(res)
+encErr := json.NewEncoder(writer).Encode(res)
+if encErr != nil {
+log.Error().Err(encErr).Msg("failed to encode health response")
+}
 }
 err := h.state.PingDB(req.Context())
 if err != nil {
 respond(err)
@@ -218,6 +225,7 @@ func (h *Headscale) VersionHandler(
 writer.WriteHeader(http.StatusOK)
+
 versionInfo := types.GetVersionInfo()
 err := json.NewEncoder(writer).Encode(versionInfo)
 if err != nil {
 log.Error().
@@ -244,7 +252,7 @@ func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
 registrationId.String())
 }
-// RegisterWebAPI shows a simple message in the browser to point to the CLI
+// RegisterHandler shows a simple message in the browser to point to the CLI
 // Listens in /register/:registration_id.
 //
 // This is not part of the Tailscale control API, as we could send whatever URL
@@ -267,7 +275,11 @@ func (a *AuthProviderWeb) RegisterHandler(
 writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 writer.WriteHeader(http.StatusOK)
-writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
+
+_, err = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
+if err != nil {
+log.Error().Err(err).Msg("failed to write register response")
+}
 }
 func FaviconHandler(writer http.ResponseWriter, req *http.Request) {
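The httpError rewrite above uses errors.AsType, the generic form of errors.As added in Go 1.26: it returns the matched value and an ok flag, removing the separate var herr HTTPError declaration. A minimal sketch, assuming the Go 1.26 signature used in the diff:

package main

import (
	"errors"
	"fmt"
)

type HTTPError struct {
	Code int
	Msg  string
}

func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s", e.Code, e.Msg) }

func main() {
	err := fmt.Errorf("handler failed: %w", HTTPError{Code: 400, Msg: "bad request"})
	// AsType unwraps the chain and returns the typed error plus an ok flag.
	if herr, ok := errors.AsType[HTTPError](err); ok {
		fmt.Println(herr.Code, herr.Msg) // 400 bad request
	}
}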

View File

@@ -8,6 +8,7 @@ import (
 "github.com/juanfont/headscale/hscontrol/state"
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/hscontrol/types/change"
+"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
 "github.com/prometheus/client_golang/prometheus"
 "github.com/prometheus/client_golang/prometheus/promauto"
 "github.com/puzpuzpuz/xsync/v4"
@@ -15,6 +16,14 @@ import (
 "tailscale.com/tailcfg"
 )
+// Mapper errors.
+var (
+ErrInvalidNodeID = errors.New("invalid nodeID")
+ErrMapperNil = errors.New("mapper is nil")
+ErrNodeConnectionNil = errors.New("nodeConnection is nil")
+ErrNodeNotFoundMapper = errors.New("node not found")
+)
+
 var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
 Namespace: "headscale",
 Name: "mapresponse_generated_total",
@@ -80,11 +89,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
 }
 if nodeID == 0 {
-return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
+return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
 }
 if mapper == nil {
-return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
+return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
 }
 // Handle self-only responses
@@ -135,12 +144,12 @@
 // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
 func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
 if nc == nil {
-return errors.New("nodeConnection is nil")
+return ErrNodeConnectionNil
 }
 nodeID := nc.nodeID()
-log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
+log.Debug().Caller().Uint64(zf.NodeID, nodeID.Uint64()).Str(zf.Reason, r.Reason).Msg("node change processing started")
 data, err := generateMapResponse(nc, mapper, r)
 if err != nil {

View File

@@ -2,6 +2,7 @@ package mapper
 import (
 "crypto/rand"
+"encoding/hex"
 "errors"
 "fmt"
 "sync"
@@ -10,13 +11,20 @@ import (
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/hscontrol/types/change"
+"github.com/juanfont/headscale/hscontrol/util/zlog/zf"
 "github.com/puzpuzpuz/xsync/v4"
+"github.com/rs/zerolog"
 "github.com/rs/zerolog/log"
 "tailscale.com/tailcfg"
-"tailscale.com/types/ptr"
 )
-var errConnectionClosed = errors.New("connection channel already closed")
+// LockFreeBatcher errors.
+var (
+errConnectionClosed = errors.New("connection channel already closed")
+ErrInitialMapSendTimeout = errors.New("sending initial map: timeout")
+ErrBatcherShuttingDown = errors.New("batcher shutting down")
+ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")
+)
 // LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
 type LockFreeBatcher struct {
@@ -48,6 +56,7 @@
 // and notifies other nodes that this node has come online.
 func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error {
 addNodeStart := time.Now()
+nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
 // Generate connection ID
 connID := generateConnectionID()
@@ -76,9 +85,10 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
 // Use the worker pool for controlled concurrency instead of direct generation
 initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id))
 if err != nil {
-log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
+nlog.Error().Err(err).Msg("initial map generation failed")
 nodeConn.removeConnectionByChannel(c)
-return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
+
+return fmt.Errorf("generating initial map for node %d: %w", id, err)
 }
 // Use a blocking send with timeout for initial map since the channel should be ready
@@ -86,12 +96,13 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
 select {
 case c <- initialMap:
 // Success
-case <-time.After(5 * time.Second):
-log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
-log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
-Msg("Initial map send timed out because channel was blocked or receiver not ready")
+case <-time.After(5 * time.Second): //nolint:mnd
+nlog.Error().Err(ErrInitialMapSendTimeout).Msg("initial map send timeout")
+nlog.Debug().Caller().Dur("timeout.duration", 5*time.Second). //nolint:mnd
+Msg("initial map send timed out because channel was blocked or receiver not ready")
 nodeConn.removeConnectionByChannel(c)
-return fmt.Errorf("failed to send initial map to node %d: timeout", id)
+
+return fmt.Errorf("%w for node %d", ErrInitialMapSendTimeout, id)
 }
 // Update connection status
@@ -100,9 +111,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
 // Node will automatically receive updates through the normal flow
 // The initial full map already contains all current state
-log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)).
+nlog.Debug().Caller().Dur(zf.TotalDuration, time.Since(addNodeStart)).
 Int("active.connections", nodeConn.getActiveConnectionCount()).
-Msg("Node connection established in batcher because AddNode completed successfully")
+Msg("node connection established in batcher")
 return nil
 }
@@ -112,31 +123,34 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
 // and keeps the node entry alive for rapid reconnections instead of aggressive deletion.
 // Reports if the node still has active connections after removal.
 func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool {
+nlog := log.With().Uint64(zf.NodeID, id.Uint64()).Logger()
+
 nodeConn, exists := b.nodes.Load(id)
 if !exists {
-log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher")
+nlog.Debug().Caller().Msg("removeNode called for non-existent node")
 return false
 }
 // Remove specific connection
 removed := nodeConn.removeConnectionByChannel(c)
 if !removed {
-log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid")
+nlog.Debug().Caller().Msg("removeNode: channel not found, connection already removed or invalid")
 return false
 }
 // Check if node has any remaining active connections
 if nodeConn.hasActiveConnections() {
-log.Debug().Caller().Uint64("node.id", id.Uint64()).
+nlog.Debug().Caller().
 Int("active.connections", nodeConn.getActiveConnectionCount()).
-Msg("Node connection removed but keeping online because other connections remain")
+Msg("node connection removed but keeping online, other connections remain")
 return true // Node still has active connections
 }
 // No active connections - keep the node entry alive for rapid reconnections
 // The node will get a fresh full map when it reconnects
-log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
-b.connected.Store(id, ptr.To(time.Now()))
+nlog.Debug().Caller().Msg("node disconnected from batcher, keeping entry for rapid reconnection")
+b.connected.Store(id, new(time.Now()))
+
 return false
 }
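The ptr.To(time.Now()) to new(time.Now()) change above leans on Go 1.26's extended new, which, per the commit messages, accepts an expression and returns a pointer to a copy of its value, making the tailscale.com/types/ptr helper redundant. A sketch, assuming Go 1.26 semantics:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Go 1.26: new(expr) allocates and initialises in one step,
	// equivalent to ptr.To(time.Now()) from tailscale.com/types/ptr.
	ts := new(time.Now()) // ts has type *time.Time
	fmt.Println(ts.IsZero()) // false
}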
@@ -196,11 +210,13 @@ func (b *LockFreeBatcher) doWork() {
} }
func (b *LockFreeBatcher) worker(workerID int) { func (b *LockFreeBatcher) worker(workerID int) {
wlog := log.With().Int(zf.WorkerID, workerID).Logger()
for { for {
select { select {
case w, ok := <-b.workCh: case w, ok := <-b.workCh:
if !ok { if !ok {
log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID) wlog.Debug().Msg("worker channel closing, shutting down")
return return
} }
@@ -212,29 +228,29 @@ func (b *LockFreeBatcher) worker(workerID int) {
// This is used for synchronous map generation. // This is used for synchronous map generation.
if w.resultCh != nil { if w.resultCh != nil {
var result workResult var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists { if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error var err error
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
result.err = err result.err = err
if result.err != nil { if result.err != nil {
b.workErrors.Add(1) b.workErrors.Add(1)
log.Error().Err(result.err). wlog.Error().Err(result.err).
Int("worker.id", workerID). Uint64(zf.NodeID, w.nodeID.Uint64()).
Uint64("node.id", w.nodeID.Uint64()). Str(zf.Reason, w.c.Reason).
Str("reason", w.c.Reason).
Msg("failed to generate map response for synchronous work") Msg("failed to generate map response for synchronous work")
} else if result.mapResponse != nil { } else if result.mapResponse != nil {
// Update peer tracking for synchronous responses too // Update peer tracking for synchronous responses too
nc.updateSentPeers(result.mapResponse) nc.updateSentPeers(result.mapResponse)
} }
} else { } else {
result.err = fmt.Errorf("node %d not found", w.nodeID) result.err = fmt.Errorf("%w: %d", ErrNodeNotFoundMapper, w.nodeID)
b.workErrors.Add(1) b.workErrors.Add(1)
log.Error().Err(result.err). wlog.Error().Err(result.err).
Int("worker.id", workerID). Uint64(zf.NodeID, w.nodeID.Uint64()).
Uint64("node.id", w.nodeID.Uint64()).
Msg("node not found for synchronous work") Msg("node not found for synchronous work")
} }
@@ -257,15 +273,14 @@ func (b *LockFreeBatcher) worker(workerID int) {
err := nc.change(w.c) err := nc.change(w.c)
if err != nil { if err != nil {
b.workErrors.Add(1) b.workErrors.Add(1)
log.Error().Err(err). wlog.Error().Err(err).
Int("worker.id", workerID). Uint64(zf.NodeID, w.nodeID.Uint64()).
Uint64("node.id", w.nodeID.Uint64()). Str(zf.Reason, w.c.Reason).
Str("reason", w.c.Reason).
Msg("failed to apply change") Msg("failed to apply change")
} }
} }
case <-b.done: case <-b.done:
log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker") wlog.Debug().Msg("batcher shutting down, exiting worker")
return return
} }
} }
@@ -310,8 +325,8 @@ func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
if _, existed := b.nodes.LoadAndDelete(removedID); existed { if _, existed := b.nodes.LoadAndDelete(removedID); existed {
b.totalNodes.Add(-1) b.totalNodes.Add(-1)
log.Debug(). log.Debug().
Uint64("node.id", removedID.Uint64()). Uint64(zf.NodeID, removedID.Uint64()).
Msg("Removed deleted node from batcher") Msg("removed deleted node from batcher")
} }
b.connected.Delete(removedID) b.connected.Delete(removedID)
@@ -398,14 +413,15 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
} }
} }
} }
return true return true
}) })
// Clean up the identified nodes // Clean up the identified nodes
for _, nodeID := range nodesToCleanup { for _, nodeID := range nodesToCleanup {
log.Info().Uint64("node.id", nodeID.Uint64()). log.Info().Uint64(zf.NodeID, nodeID.Uint64()).
Dur("offline_duration", cleanupThreshold). Dur("offline_duration", cleanupThreshold).
Msg("Cleaning up node that has been offline for too long") Msg("cleaning up node that has been offline for too long")
b.nodes.Delete(nodeID) b.nodes.Delete(nodeID)
b.connected.Delete(nodeID) b.connected.Delete(nodeID)
@@ -413,8 +429,8 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
} }
if len(nodesToCleanup) > 0 { if len(nodesToCleanup) > 0 {
log.Info().Int("cleaned_nodes", len(nodesToCleanup)). log.Info().Int(zf.CleanedNodes, len(nodesToCleanup)).
Msg("Completed cleanup of long-offline nodes") Msg("completed cleanup of long-offline nodes")
} }
} }
@@ -450,6 +466,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
if nodeConn.hasActiveConnections() { if nodeConn.hasActiveConnections() {
ret.Store(id, true) ret.Store(id, true)
} }
return true return true
}) })
@@ -465,6 +482,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret.Store(id, false) ret.Store(id, false)
} }
} }
return true return true
}) })
@@ -484,7 +502,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
case result := <-resultCh: case result := <-resultCh:
return result.mapResponse, result.err return result.mapResponse, result.err
case <-b.done: case <-b.done:
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id) return nil, fmt.Errorf("%w while generating map response for node %d", ErrBatcherShuttingDown, id)
} }
} }
@@ -502,6 +520,7 @@ type connectionEntry struct {
 type multiChannelNodeConn struct {
     id     types.NodeID
     mapper *mapper
+    log    zerolog.Logger
     mutex       sync.RWMutex
     connections []*connectionEntry
@@ -518,8 +537,9 @@ type multiChannelNodeConn struct {
 // generateConnectionID generates a unique connection identifier.
 func generateConnectionID() string {
     bytes := make([]byte, 8)
-    rand.Read(bytes)
-    return fmt.Sprintf("%x", bytes)
+    _, _ = rand.Read(bytes)
+
+    return hex.EncodeToString(bytes)
 }
 // newMultiChannelNodeConn creates a new multi-channel node connection.
@@ -528,6 +548,7 @@ func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeC
         id:            id,
         mapper:        mapper,
         lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
+        log:           log.With().Uint64(zf.NodeID, id.Uint64()).Logger(),
     }
 }
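The rest of this file follows from the two changes above: a hex-encoded random connection ID, and a per-node zerolog child logger that stamps node.id on every event so call sites stop repeating Uint64("node.id", ...). A compilable sketch of both patterns (zfNodeID stands in for the zf constants package, which is assumed to be a set of shared log-field names):

```go
package main

import (
	"crypto/rand"
	"encoding/hex"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

// Assumed stand-in for the zf field-name constants referenced in the diff.
const zfNodeID = "node.id"

// generateConnectionID returns 8 random bytes as a 16-char hex string.
// hex.EncodeToString says what it does more directly than fmt.Sprintf("%x", ...).
func generateConnectionID() string {
	bytes := make([]byte, 8)
	_, _ = rand.Read(bytes) // crypto/rand.Read does not fail on supported platforms

	return hex.EncodeToString(bytes)
}

// newNodeLogger derives a child logger carrying node.id on every event.
func newNodeLogger(id uint64) zerolog.Logger {
	return log.With().Uint64(zfNodeID, id).Logger()
}

func main() {
	nlog := newNodeLogger(42)
	nlog.Debug().Str("conn.id", generateConnectionID()).Msg("connection added")
}
```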
@@ -546,18 +567,21 @@ func (mc *multiChannelNodeConn) close() {
 // addConnection adds a new connection.
 func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
     mutexWaitStart := time.Now()
-    log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
+
+    mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
         Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
     mc.mutex.Lock()
     mutexWaitDur := time.Since(mutexWaitStart)
     defer mc.mutex.Unlock()
     mc.connections = append(mc.connections, entry)
-    log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
+
+    mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", entry.c)).Str(zf.ConnID, entry.id).
         Int("total_connections", len(mc.connections)).
         Dur("mutex_wait_time", mutexWaitDur).
-        Msg("Successfully added connection after mutex wait")
+        Msg("successfully added connection after mutex wait")
 }
 // removeConnectionByChannel removes a connection by matching channel pointer.
@@ -569,12 +593,14 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
         if entry.c == c {
             // Remove this connection
             mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
-            log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
+            mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
                 Int("remaining_connections", len(mc.connections)).
-                Msg("Successfully removed connection")
+                Msg("successfully removed connection")
             return true
         }
     }
+
     return false
 }
@@ -606,36 +632,41 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
     if len(mc.connections) == 0 {
         // During rapid reconnection, nodes may temporarily have no active connections
         // This is not an error - the node will receive a full map when it reconnects
-        log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
+        mc.log.Debug().Caller().
             Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
         return nil // Return success instead of error
     }
-    log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
+    mc.log.Debug().Caller().
         Int("total_connections", len(mc.connections)).
         Msg("send: broadcasting to all connections")
     var lastErr error
     successCount := 0
     var failedConnections []int // Track failed connections for removal
     // Send to all connections
     for i, conn := range mc.connections {
-        log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
-            Str("conn.id", conn.id).Int("connection_index", i).
+        mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
+            Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
             Msg("send: attempting to send to connection")
-        if err := conn.send(data); err != nil {
+        err := conn.send(data)
+        if err != nil {
             lastErr = err
             failedConnections = append(failedConnections, i)
-            log.Warn().Err(err).
-                Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
-                Str("conn.id", conn.id).Int("connection_index", i).
+            mc.log.Warn().Err(err).Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
+                Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
                 Msg("send: connection send failed")
         } else {
             successCount++
-            log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
-                Str("conn.id", conn.id).Int("connection_index", i).
+
+            mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", conn.c)).
+                Str(zf.ConnID, conn.id).Int(zf.ConnectionIndex, i).
                 Msg("send: successfully sent to connection")
         }
     }
@@ -643,15 +674,15 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
     // Remove failed connections (in reverse order to maintain indices)
     for i := len(failedConnections) - 1; i >= 0; i-- {
         idx := failedConnections[i]
-        log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
-            Str("conn.id", mc.connections[idx].id).
+        mc.log.Debug().Caller().
+            Str(zf.ConnID, mc.connections[idx].id).
             Msg("send: removing failed connection")
         mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
     }
     mc.updateCount.Add(1)
-    log.Debug().Uint64("node.id", mc.id.Uint64()).
+    mc.log.Debug().
         Int("successful_sends", successCount).
         Int("failed_connections", len(failedConnections)).
         Int("remaining_connections", len(mc.connections)).
@@ -688,7 +719,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
     case <-time.After(50 * time.Millisecond):
         // Connection is likely stale - client isn't reading from channel
         // This catches the case where Docker containers are killed but channels remain open
-        return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
+        return fmt.Errorf("connection %s: %w", entry.id, ErrConnectionSendTimeout)
     }
 }
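The connectionEntry.send change keeps the existing select-with-timeout shape and only swaps the error text for a wrapped sentinel. For reference, a self-contained sketch of that stale-connection guard (names and the 50 ms budget mirror the diff; the generic helper is illustrative):

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var ErrConnectionSendTimeout = errors.New("timeout sending to channel (likely stale connection)")

// trySend attempts a channel send but gives up after 50ms instead of
// blocking forever on a client that has stopped reading.
func trySend[T any](id string, ch chan<- T, v T) error {
	select {
	case ch <- v:
		return nil
	case <-time.After(50 * time.Millisecond):
		// A stale peer holds the channel open but never drains it;
		// wrapping the sentinel lets callers match with errors.Is.
		return fmt.Errorf("connection %s: %w", id, ErrConnectionSendTimeout)
	}
}

func main() {
	ch := make(chan int) // unbuffered and never read: forces the timeout path
	err := trySend("conn-1", ch, 7)
	fmt.Println(errors.Is(err, ErrConnectionSendTimeout)) // true
}
```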
@@ -798,6 +829,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
             Connected:         connected,
             ActiveConnections: activeConnCount,
         }
+
         return true
     })
@@ -812,6 +844,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
             ActiveConnections: 0,
         }
     }
+
     return true
 })
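The wsl fixes in the two Debug hunks only add a blank line before the callback's return true. The callback itself follows the usual range-until-false contract; the stdlib sync.Map analogue looks like this (xsync's generic map follows the same convention):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var m sync.Map
	m.Store(uint64(1), true)
	m.Store(uint64(2), false)

	connected := 0
	// Range keeps iterating while the callback returns true;
	// returning false would stop the walk early.
	m.Range(func(key, value any) bool {
		if value.(bool) {
			connected++
		}

		return true
	})
	fmt.Println("connected:", connected)
}
```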

View File

@@ -35,6 +35,7 @@ type batcherTestCase struct {
 // that would normally be sent by poll.go in production.
 type testBatcherWrapper struct {
     Batcher
+
     state *state.State
 }
@@ -80,12 +81,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
     }
     // Finally remove from the real batcher
-    removed := t.Batcher.RemoveNode(id, c)
-    if !removed {
-        return false
-    }
-
-    return true
+    return t.Batcher.RemoveNode(id, c)
 }
 // wrapBatcherForTest wraps a batcher with test-specific behavior.
@@ -129,8 +125,6 @@ const (
     SMALL_BUFFER_SIZE = 3
     TINY_BUFFER_SIZE  = 1 // For maximum contention
     LARGE_BUFFER_SIZE = 200
-
-    reservedResponseHeaderSize = 4
 )
 // TestData contains all test entities created for a test scenario.
@@ -241,8 +235,8 @@ func setupBatcherWithTestData(
     }
     derpMap, err := derp.GetDERPMap(cfg.DERP)
-    assert.NoError(t, err)
-    assert.NotNil(t, derpMap)
+    require.NoError(t, err)
+    require.NotNil(t, derpMap)
     state.SetDERPMap(derpMap)
@@ -319,6 +313,8 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
 }
 // getStats returns a copy of the statistics for a node.
+//
+//nolint:unused
 func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
     ut.mu.RLock()
     defer ut.mu.RUnlock()
@@ -386,16 +382,14 @@ type UpdateInfo struct {
 }
 // parseUpdateAndAnalyze parses an update and returns detailed information.
-func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
-    info := UpdateInfo{
+func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
+    return UpdateInfo{
         PeerCount:  len(resp.Peers),
         PatchCount: len(resp.PeersChangedPatch),
         IsFull:     len(resp.Peers) > 0,
         IsPatch:    len(resp.PeersChangedPatch) > 0,
         IsDERP:     resp.DERPMap != nil,
     }
-
-    return info, nil
 }
 // start begins consuming updates from the node's channel and tracking stats.
@@ -417,7 +411,8 @@ func (n *node) start() {
     atomic.AddInt64(&n.updateCount, 1)
     // Parse update and track detailed stats
-    if info, err := parseUpdateAndAnalyze(data); err == nil {
+    info := parseUpdateAndAnalyze(data)
+    {
         // Track update types
         if info.IsFull {
             atomic.AddInt64(&n.fullCount, 1)
@@ -548,7 +543,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
     testNode.start()
     // Connect the node to the batcher
-    batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
+    _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
     // Wait for connection to be established
     assert.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -657,7 +652,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
     for i := range allNodes {
         node := &allNodes[i]
-        batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
+        _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
         // Issue full update after each join to ensure connectivity
         batcher.AddWork(change.FullUpdate())
@@ -676,6 +671,7 @@
     assert.EventuallyWithT(t, func(c *assert.CollectT) {
         connectedCount := 0
+
         for i := range allNodes {
             node := &allNodes[i]
@@ -693,6 +689,7 @@
     }, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
     t.Logf("✅ All nodes achieved full connectivity!")
+
     totalTime := time.Since(startTime)
     // Disconnect all nodes
@@ -820,11 +817,11 @@ func TestBatcherBasicOperations(t *testing.T) {
     defer cleanup()
     batcher := testData.Batcher
-    tn := testData.Nodes[0]
-    tn2 := testData.Nodes[1]
+    tn := &testData.Nodes[0]
+    tn2 := &testData.Nodes[1]
     // Test AddNode with real node ID
-    batcher.AddNode(tn.n.ID, tn.ch, 100)
+    _ = batcher.AddNode(tn.n.ID, tn.ch, 100)
     if !batcher.IsConnected(tn.n.ID) {
         t.Error("Node should be connected after AddNode")
@@ -842,10 +839,10 @@
     }
     // Drain any initial messages from first node
-    drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
+    drainChannelTimeout(tn.ch, 100*time.Millisecond)
     // Add the second node and verify update message
-    batcher.AddNode(tn2.n.ID, tn2.ch, 100)
+    _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
     assert.True(t, batcher.IsConnected(tn2.n.ID))
     // First node should get an update that second node has connected.
@@ -911,18 +908,14 @@
     }
 }
-func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
-    count := 0
+func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, timeout time.Duration) {
     timer := time.NewTimer(timeout)
     defer timer.Stop()
     for {
         select {
-        case data := <-ch:
-            count++
-            // Optional: add debug output if needed
-            _ = data
+        case <-ch:
+            // Drain message
         case <-timer.C:
             return
         }
@@ -1050,7 +1043,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
     testNodes := testData.Nodes
     ch := make(chan *tailcfg.MapResponse, 10)
-    batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
+    _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
     // Track update content for validation
     var receivedUpdates []*tailcfg.MapResponse
@@ -1131,6 +1124,8 @@
 // even when real node updates are being processed, ensuring no race conditions
 // occur during channel replacement with actual workload.
 func XTestBatcherChannelClosingRace(t *testing.T) {
+    t.Helper()
+
     for _, batcherFunc := range allBatcherFunctions {
         t.Run(batcherFunc.name, func(t *testing.T) {
             // Create test environment with real database and nodes
@@ -1138,7 +1133,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
             defer cleanup()
             batcher := testData.Batcher
-            testNode := testData.Nodes[0]
+            testNode := &testData.Nodes[0]
             var (
                 channelIssues int
@@ -1154,7 +1149,7 @@
             ch1 := make(chan *tailcfg.MapResponse, 1)
             wg.Go(func() {
-                batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
+                _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
             })
             // Add real work during connection chaos
@@ -1167,7 +1162,8 @@
             wg.Go(func() {
                 runtime.Gosched() // Yield to introduce timing variability
-                batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
+
+                _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
             })
             // Remove second connection
@@ -1231,7 +1227,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
             defer cleanup()
             batcher := testData.Batcher
-            testNode := testData.Nodes[0]
+            testNode := &testData.Nodes[0]
             var (
                 panics int
@@ -1258,7 +1254,7 @@
                 ch := make(chan *tailcfg.MapResponse, 5)
                 // Add node and immediately queue real work
-                batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
+                _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
                 batcher.AddWork(change.DERPMap())
                 // Consumer goroutine to validate data and detect channel issues
@@ -1308,6 +1304,7 @@
                 for range i % 3 {
                     runtime.Gosched() // Introduce timing variability
                 }
+
                 batcher.RemoveNode(testNode.n.ID, ch)
                 // Yield to allow workers to process and close channels
@@ -1350,6 +1347,8 @@
 // real node data. The test validates that stable clients continue to function
 // normally and receive proper updates despite the connection churn from other clients,
 // ensuring system stability under concurrent load.
+//
+//nolint:gocyclo // complex concurrent test scenario
 func TestBatcherConcurrentClients(t *testing.T) {
     if testing.Short() {
         t.Skip("Skipping concurrent client test in short mode")
@@ -1377,10 +1376,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
     stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable
     stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse)
-    for _, node := range stableNodes {
+    for i := range stableNodes {
+        node := &stableNodes[i]
         ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
         stableChannels[node.n.ID] = ch
-        batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
+        _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
         // Monitor updates for each stable client
         go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@@ -1391,6 +1391,7 @@
                 // Channel was closed, exit gracefully
                 return
             }
+
             if valid, reason := validateUpdateContent(data); valid {
                 tracker.recordUpdate(
                     nodeID,
@@ -1427,7 +1428,9 @@
     // Connection churn cycles - rapidly connect/disconnect to test concurrency safety
     for i := range numCycles {
-        for _, node := range churningNodes {
+        for j := range churningNodes {
+            node := &churningNodes[j]
+
             wg.Add(2)
             // Connect churning node
@@ -1448,10 +1451,12 @@
                 ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
+
                 churningChannelsMutex.Lock()
                 churningChannels[nodeID] = ch
                 churningChannelsMutex.Unlock()
-                batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
+
+                _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
                 // Consume updates to prevent blocking
                 go func() {
@@ -1462,6 +1467,7 @@
                     // Channel was closed, exit gracefully
                     return
                 }
+
                 if valid, _ := validateUpdateContent(data); valid {
                     tracker.recordUpdate(
                         nodeID,
@@ -1494,6 +1500,7 @@
                 for range i % 5 {
                     runtime.Gosched() // Introduce timing variability
                 }
+
                 churningChannelsMutex.Lock()
                 ch, exists := churningChannels[nodeID]
@@ -1519,7 +1526,7 @@
         if i%7 == 0 && len(allNodes) > 0 {
             // Node-specific changes using real nodes
-            node := allNodes[i%len(allNodes)]
+            node := &allNodes[i%len(allNodes)]
             // Use a valid expiry time for testing since test nodes don't have expiry set
             testExpiry := time.Now().Add(24 * time.Hour)
             batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
@@ -1567,7 +1574,8 @@
     t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls",
         expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork)
-    for _, node := range stableNodes {
+    for i := range stableNodes {
+        node := &stableNodes[i]
         if stats, exists := allStats[node.n.ID]; exists {
             stableUpdateCount += stats.TotalUpdates
             t.Logf("Stable node %d: %d updates",
@@ -1580,7 +1588,8 @@
         }
     }
-    for _, node := range churningNodes {
+    for i := range churningNodes {
+        node := &churningNodes[i]
         if stats, exists := allStats[node.n.ID]; exists {
             churningUpdateCount += stats.TotalUpdates
         }
@@ -1605,7 +1614,8 @@
     }
     // Verify all stable clients are still functional
-    for _, node := range stableNodes {
+    for i := range stableNodes {
+        node := &stableNodes[i]
         if !batcher.IsConnected(node.n.ID) {
             t.Errorf("Stable node %d lost connection during racing", node.n.ID)
         }
@@ -1623,6 +1633,8 @@
 // It validates that the system remains stable with no deadlocks, panics, or
 // missed updates under sustained high load. The test uses real node data to
 // generate authentic update scenarios and tracks comprehensive statistics.
+//
+//nolint:gocyclo,thelper // complex scalability test scenario
 func XTestBatcherScalability(t *testing.T) {
     if testing.Short() {
         t.Skip("Skipping scalability test in short mode")
@@ -1651,7 +1663,7 @@
         description string
     }
-    var testCases []testCase
+    testCases := make([]testCase, 0, len(chaosTypes)*len(bufferSizes)*len(cycles)*len(nodes))
     // Generate all combinations of the test matrix
     for _, nodeCount := range nodes {
@@ -1762,7 +1774,8 @@
     for i := range testNodes {
         node := &testNodes[i]
-        batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
+        _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
+
         connectedNodesMutex.Lock()
         connectedNodes[node.n.ID] = true
@@ -1824,7 +1837,8 @@
     }
     // Connection/disconnection cycles for subset of nodes
-    for i, node := range chaosNodes {
+    for i := range chaosNodes {
+        node := &chaosNodes[i]
         // Only add work if this is connection chaos or mixed
         if tc.chaosType == "connection" || tc.chaosType == "mixed" {
             wg.Add(2)
@@ -1878,6 +1892,7 @@
                 channel,
                 tailcfg.CapabilityVersion(100),
             )
+
             connectedNodesMutex.Lock()
             connectedNodes[nodeID] = true
@@ -2138,8 +2153,9 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
     t.Logf("Created %d nodes in database", len(allNodes))
     // Connect nodes one at a time and wait for each to be connected
-    for i, node := range allNodes {
-        batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
+    for i := range allNodes {
+        node := &allNodes[i]
+        _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
         t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
         // Wait for node to be connected
@@ -2157,7 +2173,8 @@
     }, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
     // Check how many peers each node should see
-    for i, node := range allNodes {
+    for i := range allNodes {
+        node := &allNodes[i]
         peers := testData.State.ListPeers(node.n.ID)
         t.Logf("Node %d should see %d peers from state", i, peers.Len())
     }
@@ -2286,7 +2303,10 @@ func TestBatcherRapidReconnection(t *testing.T) {
     // Phase 1: Connect all nodes initially
     t.Logf("Phase 1: Connecting all nodes...")
-    for i, node := range allNodes {
+
+    for i := range allNodes {
+        node := &allNodes[i]
+
         err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
         if err != nil {
             t.Fatalf("Failed to add node %d: %v", i, err)
@@ -2302,16 +2322,21 @@
     // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
     t.Logf("Phase 2: Rapid disconnect all nodes...")
-    for i, node := range allNodes {
+
+    for i := range allNodes {
+        node := &allNodes[i]
         removed := batcher.RemoveNode(node.n.ID, node.ch)
         t.Logf("Node %d RemoveNode result: %t", i, removed)
     }
     // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
     t.Logf("Phase 3: Rapid reconnect with new channels...")
+
     newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
-    for i, node := range allNodes {
+
+    for i := range allNodes {
+        node := &allNodes[i]
         newChannels[i] = make(chan *tailcfg.MapResponse, 10)
+
         err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
         if err != nil {
             t.Errorf("Failed to reconnect node %d: %v", i, err)
@@ -2334,7 +2359,8 @@
     debugInfo := debugBatcher.Debug()
     disconnectedCount := 0
-    for i, node := range allNodes {
+    for i := range allNodes {
+        node := &allNodes[i]
         if info, exists := debugInfo[node.n.ID]; exists {
             t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info)
@@ -2342,11 +2368,13 @@
             if infoMap, ok := info.(map[string]any); ok {
                 if connected, ok := infoMap["connected"].(bool); ok && !connected {
                     disconnectedCount++
+
                     t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
                 }
             }
         } else {
             disconnectedCount++
+
             t.Logf("Node %d missing from debug info entirely", i)
         }
@@ -2381,6 +2409,7 @@
         case update := <-newChannels[i]:
             if update != nil {
                 receivedCount++
+
                 t.Logf("Node %d received update successfully", i)
             }
         case <-timeout:
@@ -2399,6 +2428,7 @@
     }
 }
+//nolint:gocyclo // complex multi-connection test scenario
 func TestBatcherMultiConnection(t *testing.T) {
     for _, batcherFunc := range allBatcherFunctions {
         t.Run(batcherFunc.name, func(t *testing.T) {
@@ -2406,13 +2436,14 @@
             defer cleanup()
             batcher := testData.Batcher
-            node1 := testData.Nodes[0]
-            node2 := testData.Nodes[1]
+            node1 := &testData.Nodes[0]
+            node2 := &testData.Nodes[1]
             t.Logf("=== MULTI-CONNECTION TEST ===")
             // Phase 1: Connect first node with initial connection
             t.Logf("Phase 1: Connecting node 1 with first connection...")
+
             err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
             if err != nil {
                 t.Fatalf("Failed to add node1: %v", err)
@@ -2432,7 +2463,9 @@
             // Phase 2: Add second connection for node1 (multi-connection scenario)
             t.Logf("Phase 2: Adding second connection for node 1...")
+
             secondChannel := make(chan *tailcfg.MapResponse, 10)
+
             err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
             if err != nil {
                 t.Fatalf("Failed to add second connection for node1: %v", err)
@@ -2443,7 +2476,9 @@
             // Phase 3: Add third connection for node1
             t.Logf("Phase 3: Adding third connection for node 1...")
+
             thirdChannel := make(chan *tailcfg.MapResponse, 10)
+
             err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
             if err != nil {
                 t.Fatalf("Failed to add third connection for node1: %v", err)
@@ -2454,6 +2489,7 @@
             // Phase 4: Verify debug status shows correct connection count
             t.Logf("Phase 4: Verifying debug status shows multiple connections...")
+
             if debugBatcher, ok := batcher.(interface {
                 Debug() map[types.NodeID]any
             }); ok {
@@ -2461,6 +2497,7 @@
                 if info, exists := debugInfo[node1.n.ID]; exists {
                     t.Logf("Node1 debug info: %+v", info)
+
                     if infoMap, ok := info.(map[string]any); ok {
                         if activeConnections, ok := infoMap["active_connections"].(int); ok {
                             if activeConnections != 3 {
@@ -2469,6 +2506,7 @@
                                 t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
                             }
                         }
+
                         if connected, ok := infoMap["connected"].(bool); ok && !connected {
                             t.Errorf("Node1 should show as connected with 3 active connections")
                         }
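A pattern repeated throughout this test file: `for i := range s { node := &s[i] }` replaces `for _, node := range s`, so the loop works with a pointer to the slice element rather than a per-iteration copy, and `_ =` makes the deliberately ignored AddNode error explicit for errcheck. A minimal sketch of why the pointer form matters:

```go
package main

import "fmt"

type node struct {
	connected bool
}

func main() {
	nodes := make([]node, 3)

	// The range value is a copy: mutating it never touches the slice.
	for _, n := range nodes {
		n.connected = true
	}
	fmt.Println(nodes[0].connected) // false

	// Taking a pointer to the element mutates (and avoids copying) the
	// real entry - the form these test loops are converted to.
	for i := range nodes {
		n := &nodes[i]
		n.connected = true
	}
	fmt.Println(nodes[0].connected) // true
}
```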

View File

@@ -1,7 +1,6 @@
 package mapper
 import (
-    "errors"
     "net/netip"
     "sort"
     "time"
@@ -36,6 +35,7 @@ const (
 // NewMapResponseBuilder creates a new builder with basic fields set.
 func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
     now := time.Now()
+
     return &MapResponseBuilder{
         resp: &tailcfg.MapResponse{
             KeepAlive: false,
@@ -69,7 +69,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
 func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
     nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
     if !ok {
-        b.addError(errors.New("node not found"))
+        b.addError(ErrNodeNotFoundMapper)
         return b
     }
@@ -123,6 +123,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
     b.resp.Debug = &tailcfg.Debug{
         DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
     }
+
     return b
 }
@@ -130,7 +131,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
 func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
     node, ok := b.mapper.state.GetNodeByID(b.nodeID)
     if !ok {
-        b.addError(errors.New("node not found"))
+        b.addError(ErrNodeNotFoundMapper)
         return b
     }
@@ -149,7 +150,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
 func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
     node, ok := b.mapper.state.GetNodeByID(b.nodeID)
     if !ok {
-        b.addError(errors.New("node not found"))
+        b.addError(ErrNodeNotFoundMapper)
         return b
     }
@@ -162,7 +163,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
 func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
     node, ok := b.mapper.state.GetNodeByID(b.nodeID)
     if !ok {
-        b.addError(errors.New("node not found"))
+        b.addError(ErrNodeNotFoundMapper)
         return b
     }
@@ -175,7 +176,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
 func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
     node, ok := b.mapper.state.GetNodeByID(b.nodeID)
     if !ok {
-        b.addError(errors.New("node not found"))
+        b.addError(ErrNodeNotFoundMapper)
         return b
     }
@@ -229,7 +230,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
 func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
     node, ok := b.mapper.state.GetNodeByID(b.nodeID)
     if !ok {
-        return nil, errors.New("node not found")
+        return nil, ErrNodeNotFoundMapper
     }
     // Get unreduced matchers for peer relationship determination.
@@ -276,20 +277,22 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
 // WithPeersRemoved adds removed peer IDs.
 func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
-    var tailscaleIDs []tailcfg.NodeID
+    tailscaleIDs := make([]tailcfg.NodeID, 0, len(removedIDs))
     for _, id := range removedIDs {
         tailscaleIDs = append(tailscaleIDs, id.NodeID())
     }
+
     b.resp.PeersRemoved = tailscaleIDs
+
     return b
 }
-// Build finalizes the response and returns marshaled bytes
+// Build finalizes the response and returns marshaled bytes.
 func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
     if len(b.errs) > 0 {
         return nil, multierr.New(b.errs...)
     }
+
     if debugDumpMapResponsePath != "" {
         writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
     }
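The builder defers every lookup failure via addError and only surfaces them, joined, in Build. A compact sketch of that error-collecting builder shape, using the stdlib's errors.Join in place of headscale's multierr helper:

```go
package main

import (
	"errors"
	"fmt"
)

var ErrNodeNotFound = errors.New("node not found")

type builder struct {
	fields []string
	errs   []error
}

// Each With* step records failures instead of aborting, so calls chain freely.
func (b *builder) WithSelf(ok bool) *builder {
	if !ok {
		b.errs = append(b.errs, ErrNodeNotFound)
		return b
	}

	b.fields = append(b.fields, "self")

	return b
}

// Build surfaces everything that went wrong, joined into one error.
func (b *builder) Build() ([]string, error) {
	if len(b.errs) > 0 {
		return nil, errors.Join(b.errs...)
	}

	return b.fields, nil
}

func main() {
	_, err := (&builder{}).WithSelf(false).WithSelf(false).Build()
	fmt.Println(errors.Is(err, ErrNodeNotFound)) // true
}
```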

View File

@@ -339,8 +339,8 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
     // Build should return a multierr
     data, err := result.Build()
-    assert.Nil(t, data)
-    assert.Error(t, err)
+    require.Nil(t, data)
+    require.Error(t, err)
     // The error should contain information about multiple errors
     assert.Contains(t, err.Error(), "multiple errors")
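Swapping assert for require here matters because the next line dereferences err: testify's assert records a failure and keeps going, while require stops the test immediately, so err.Error() can never panic on a nil error. Roughly (buildSomething is a hypothetical stand-in for the builder under test):

```go
package example

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestBuildFails(t *testing.T) {
	data, err := buildSomething()

	// require.* calls t.FailNow() on failure, so execution stops here
	// if err is unexpectedly nil - the lines below can rely on it.
	require.Nil(t, data)
	require.Error(t, err)

	// assert.* merely records a failure and continues; safe now that
	// err is guaranteed non-nil.
	assert.Contains(t, err.Error(), "multiple errors")
}

// buildSomething is a stand-in for the builder under test.
func buildSomething() (any, error) {
	return nil, errors.New("multiple errors occurred")
}
```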

View File

@@ -24,7 +24,6 @@ import (
 const (
     nextDNSDoHPrefix     = "https://dns.nextdns.io"
-    mapperIDLength       = 8
     debugMapResponsePerm = 0o755
 )
@@ -50,6 +49,7 @@ type mapper struct {
     created time.Time
 }
+//nolint:unused
 type patch struct {
     timestamp time.Time
     change    *tailcfg.PeerChange
@@ -60,7 +60,6 @@ func newMapper(
     state *state.State,
 ) *mapper {
     // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
-
     return &mapper{
         state: state,
         cfg:   cfg,
@@ -76,23 +75,26 @@ func generateUserProfiles(
 ) []tailcfg.UserProfile {
     userMap := make(map[uint]*types.UserView)
     ids := make([]uint, 0, len(userMap))
+
     user := node.Owner()
     if !user.Valid() {
         log.Error().
-            Uint64("node.id", node.ID().Uint64()).
-            Str("node.name", node.Hostname()).
+            EmbedObject(node).
             Msg("node has no valid owner, skipping user profile generation")
+
         return nil
     }
+
     userID := user.Model().ID
     userMap[userID] = &user
     ids = append(ids, userID)
+
     for _, peer := range peers.All() {
         peerUser := peer.Owner()
         if !peerUser.Valid() {
             continue
         }
+
         peerUserID := peerUser.Model().ID
         userMap[peerUserID] = &peerUser
         ids = append(ids, peerUserID)
@@ -100,7 +102,9 @@ func generateUserProfiles(
     slices.Sort(ids)
     ids = slices.Compact(ids)
+
     var profiles []tailcfg.UserProfile
+
     for _, id := range ids {
         if userMap[id] != nil {
             profiles = append(profiles, userMap[id].TailscaleUserProfile())
@@ -150,6 +154,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
 }
 // fullMapResponse returns a MapResponse for the given node.
+//
+//nolint:unused
 func (m *mapper) fullMapResponse(
     nodeID types.NodeID,
     capVer tailcfg.CapabilityVersion,
@@ -317,6 +323,7 @@ func writeDebugMapResponse(
     perms := fs.FileMode(debugMapResponsePerm)
     mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
     err = os.MkdirAll(mPath, perms)
+
     if err != nil {
         panic(err)
@@ -329,7 +336,8 @@ func writeDebugMapResponse(
         fmt.Sprintf("%s-%s.json", now, t),
     )
-    log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
+    log.Trace().Msgf("writing MapResponse to %s", mapResponsePath)
+
     err = os.WriteFile(mapResponsePath, body, perms)
     if err != nil {
         panic(err)
@@ -338,7 +346,7 @@ func writeDebugMapResponse(
 func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
     if debugDumpMapResponsePath == "" {
-        return nil, nil
+        return nil, nil //nolint:nilnil // intentional: no data when debug path not set
     }
     return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
@@ -351,6 +359,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
     }
     result := make(map[types.NodeID][]tailcfg.MapResponse)
+
     for _, node := range nodes {
         if !node.IsDir() {
             continue
@@ -358,7 +367,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
         nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64)
         if err != nil {
-            log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name())
+            log.Error().Err(err).Msgf("parsing node ID from dir %s", node.Name())
             continue
         }
@@ -366,7 +375,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
         files, err := os.ReadDir(path.Join(dir, node.Name()))
         if err != nil {
-            log.Error().Err(err).Msgf("Reading dir %s", node.Name())
+            log.Error().Err(err).Msgf("reading dir %s", node.Name())
             continue
         }
@@ -381,14 +390,15 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
         body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name()))
         if err != nil {
-            log.Error().Err(err).Msgf("Reading file %s", file.Name())
+            log.Error().Err(err).Msgf("reading file %s", file.Name())
            continue
         }
+
         var resp tailcfg.MapResponse
         err = json.Unmarshal(body, &resp)
         if err != nil {
-            log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
+            log.Error().Err(err).Msgf("unmarshalling file %s", file.Name())
             continue
         }
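The generateUserProfiles hunk replaces two hand-written log fields with zerolog's EmbedObject(node), which works for any type implementing zerolog.LogObjectMarshaler. A minimal sketch of what the node type presumably provides (the MarshalZerologObject body here is illustrative, not headscale's actual implementation):

```go
package main

import (
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

type node struct {
	id       uint64
	hostname string
}

// MarshalZerologObject lets zerolog inline the node's identifying fields
// wherever EmbedObject(n) is used, keeping field names consistent.
func (n node) MarshalZerologObject(e *zerolog.Event) {
	e.Uint64("node.id", n.id).Str("node.name", n.hostname)
}

func main() {
	n := node{id: 7, hostname: "mini"}
	// Emits: {"level":"error","node.id":7,"node.name":"mini","message":"..."}
	log.Error().EmbedObject(n).Msg("node has no valid owner, skipping user profile generation")
}
```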

View File

@@ -3,18 +3,13 @@ package mapper
 import (
     "fmt"
     "net/netip"
-    "slices"
     "testing"
     "github.com/google/go-cmp/cmp"
     "github.com/google/go-cmp/cmp/cmpopts"
-    "github.com/juanfont/headscale/hscontrol/policy"
-    "github.com/juanfont/headscale/hscontrol/policy/matcher"
-    "github.com/juanfont/headscale/hscontrol/routes"
     "github.com/juanfont/headscale/hscontrol/types"
     "tailscale.com/tailcfg"
     "tailscale.com/types/dnstype"
-    "tailscale.com/types/ptr"
 )
 var iap = func(ipStr string) *netip.Addr {
@@ -51,7 +46,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
     mach := func(hostname, username string, userid uint) *types.Node {
         return &types.Node{
             Hostname: hostname,
-            UserID:   ptr.To(userid),
+            UserID:   new(userid),
             User: &types.User{
                 Name: username,
             },
@@ -81,90 +76,3 @@ func TestDNSConfigMapResponse(t *testing.T) {
         })
     }
 }
-
-// mockState is a mock implementation that provides the required methods.
-type mockState struct {
-    polMan  policy.PolicyManager
-    derpMap *tailcfg.DERPMap
-    primary *routes.PrimaryRoutes
-    nodes   types.Nodes
-    peers   types.Nodes
-}
-
-func (m *mockState) DERPMap() *tailcfg.DERPMap {
-    return m.derpMap
-}
-
-func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
-    if m.polMan == nil {
-        return tailcfg.FilterAllowAll, nil
-    }
-    return m.polMan.Filter()
-}
-
-func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
-    if m.polMan == nil {
-        return nil, nil
-    }
-    return m.polMan.SSHPolicy(node)
-}
-
-func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
-    if m.polMan == nil {
-        return false
-    }
-    return m.polMan.NodeCanHaveTag(node, tag)
-}
-
-func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
-    if m.primary == nil {
-        return nil
-    }
-    return m.primary.PrimaryRoutes(nodeID)
-}
-
-func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
-    if len(peerIDs) > 0 {
-        // Filter peers by the provided IDs
-        var filtered types.Nodes
-        for _, peer := range m.peers {
-            if slices.Contains(peerIDs, peer.ID) {
-                filtered = append(filtered, peer)
-            }
-        }
-        return filtered, nil
-    }
-    // Return all peers except the node itself
-    var filtered types.Nodes
-    for _, peer := range m.peers {
-        if peer.ID != nodeID {
-            filtered = append(filtered, peer)
-        }
-    }
-    return filtered, nil
-}
-
-func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
-    if len(nodeIDs) > 0 {
-        // Filter nodes by the provided IDs
-        var filtered types.Nodes
-        for _, node := range m.nodes {
-            if slices.Contains(nodeIDs, node.ID) {
-                filtered = append(filtered, node)
-            }
-        }
-        return filtered, nil
-    }
-    return m.nodes, nil
-}
-
-func Test_fullMapResponse(t *testing.T) {
-    t.Skip("Test needs to be refactored for new state-based architecture")
-    // TODO: Refactor this test to work with the new state-based mapper
-    // The test architecture needs to be updated to work with the state interface
-    // instead of the old direct dependency injection pattern
-}
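Both test files drop tailscale.com/types/ptr in favour of new(value). Per the commit messages in this series, that relies on Go 1.26 extending new to accept an expression; on earlier Go versions the equivalent is the generic helper that ptr.To provided. A sketch of both spellings (the new(uint(7)) form assumes a Go 1.26 toolchain):

```go
package main

import "fmt"

// ptrTo is what tailscale.com/types/ptr.To does: a generic
// pointer-to-value helper, needed before Go 1.26.
func ptrTo[T any](v T) *T {
	return &v
}

func main() {
	// Pre-1.26 spelling, via a helper:
	a := ptrTo(uint(7))

	// Go 1.26 spelling used in the diff: new applied to a value
	// allocates and initialises in one expression.
	b := new(uint(7))

	fmt.Println(*a, *b) // 7 7
}
```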

View File

@@ -13,12 +13,12 @@ import (
     "tailscale.com/net/tsaddr"
     "tailscale.com/tailcfg"
     "tailscale.com/types/key"
-    "tailscale.com/types/ptr"
 )
 func TestTailNode(t *testing.T) {
     mustNK := func(str string) key.NodePublic {
         var k key.NodePublic
         _ = k.UnmarshalText([]byte(str))
+
         return k
@@ -26,6 +26,7 @@ func TestTailNode(t *testing.T) {
     mustDK := func(str string) key.DiscoPublic {
         var k key.DiscoPublic
         _ = k.UnmarshalText([]byte(str))
+
         return k
@@ -33,6 +34,7 @@ func TestTailNode(t *testing.T) {
     mustMK := func(str string) key.MachinePublic {
         var k key.MachinePublic
         _ = k.UnmarshalText([]byte(str))
+
         return k
@@ -95,7 +97,7 @@ func TestTailNode(t *testing.T) {
     IPv4:      iap("100.64.0.1"),
     Hostname:  "mini",
     GivenName: "mini",
-    UserID:    ptr.To(uint(0)),
+    UserID:    new(uint(0)),
     User: &types.User{
         Name: "mini",
     },
@@ -137,8 +139,8 @@ func TestTailNode(t *testing.T) {
     Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
     AllowedIPs: []netip.Prefix{
         tsaddr.AllIPv4(),
-        netip.MustParsePrefix("192.168.0.0/24"),
         netip.MustParsePrefix("100.64.0.1/32"),
+        netip.MustParsePrefix("192.168.0.0/24"),
         tsaddr.AllIPv6(),
     },
     PrimaryRoutes: []netip.Prefix{
@@ -255,7 +257,7 @@ func TestNodeExpiry(t *testing.T) {
     },
     {
         name:         "localtime",
-        exp:          tp(time.Time{}.Local()),
+        exp:          tp(time.Time{}.Local()), //nolint:gosmopolitan
         wantTimeZero: true,
     },
 }
@@ -284,7 +286,9 @@ func TestNodeExpiry(t *testing.T) {
     if err != nil {
         t.Fatalf("nodeExpiry() error = %v", err)
     }
+
     var deseri tailcfg.Node
+
     err = json.Unmarshal(seri, &deseri)
     if err != nil {
         t.Fatalf("nodeExpiry() error = %v", err)

View File

@@ -71,6 +71,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
     rw := &respWriterProm{ResponseWriter: w}
     timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
     next.ServeHTTP(rw, r)
+
     timer.ObserveDuration()
     httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
@@ -79,6 +80,7 @@ func prometheusMiddleware(next http.Handler) http.Handler {
 type respWriterProm struct {
     http.ResponseWriter
+
     status      int
     written     int64
     wroteHeader bool
@@ -94,6 +96,7 @@ func (r *respWriterProm) Write(b []byte) (int, error) {
     if !r.wroteHeader {
         r.WriteHeader(http.StatusOK)
     }
+
     n, err := r.ResponseWriter.Write(b)
     r.written += int64(n)
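respWriterProm is the usual status-recording http.ResponseWriter wrapper: it remembers the first WriteHeader call (defaulting to 200 on the first Write) so the middleware can label its counter afterwards. A self-contained sketch of that wrapper:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// statusRecorder wraps http.ResponseWriter and captures the status code
// and byte count so middleware can observe them after the handler runs.
type statusRecorder struct {
	http.ResponseWriter

	status      int
	written     int64
	wroteHeader bool
}

func (r *statusRecorder) WriteHeader(code int) {
	r.status = code
	r.wroteHeader = true
	r.ResponseWriter.WriteHeader(code)
}

func (r *statusRecorder) Write(b []byte) (int, error) {
	// A handler that never calls WriteHeader implicitly sends 200.
	if !r.wroteHeader {
		r.WriteHeader(http.StatusOK)
	}

	n, err := r.ResponseWriter.Write(b)
	r.written += int64(n)

	return n, err
}

func main() {
	rec := httptest.NewRecorder()
	rw := &statusRecorder{ResponseWriter: rec}

	http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("ok"))
	}).ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/", nil))

	fmt.Println(rw.status, rw.written) // 200 2
}
```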

View File

@@ -19,6 +19,9 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
) )
// ErrUnsupportedClientVersion is returned when a client connects with an unsupported protocol version.
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
const ( const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade. // ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021" ts2021UpgradePath = "/ts2021"
@@ -51,7 +54,7 @@ func (h *Headscale) NoiseUpgradeHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", req.RemoteAddr) log.Trace().Caller().Msgf("noise upgrade handler for client %s", req.RemoteAddr)
upgrade := req.Header.Get("Upgrade") upgrade := req.Header.Get("Upgrade")
if upgrade == "" { if upgrade == "" {
@@ -60,7 +63,7 @@ func (h *Headscale) NoiseUpgradeHandler(
// be passed to Headscale. Let's give them a hint. // be passed to Headscale. Let's give them a hint.
log.Warn(). log.Warn().
Caller(). Caller().
Msg("No Upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.") Msg("no upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
http.Error(writer, "Internal error", http.StatusInternalServerError) http.Error(writer, "Internal error", http.StatusInternalServerError)
return return
@@ -79,7 +82,7 @@ func (h *Headscale) NoiseUpgradeHandler(
noiseServer.earlyNoise, noiseServer.earlyNoise,
) )
if err != nil { if err != nil {
httpError(writer, fmt.Errorf("noise upgrade failed: %w", err)) httpError(writer, fmt.Errorf("upgrading noise connection: %w", err))
return return
} }
@@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler(
 }

 func unsupportedClientError(version tailcfg.CapabilityVersion) error {
-	return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version)
+	return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version)
 }

 func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
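Note: wrapping the new sentinel with %w is what lets callers classify the failure without string matching. A minimal sketch of the pattern (the version parameter is simplified from tailcfg.CapabilityVersion):

package main

import (
	"errors"
	"fmt"
)

var ErrUnsupportedClientVersion = errors.New("unsupported client version")

// unsupportedClientError wraps the sentinel so errors.Is can match it
// anywhere up the call chain.
func unsupportedClientError(version int) error {
	return fmt.Errorf("%w: (%d)", ErrUnsupportedClientVersion, version)
}

func main() {
	err := unsupportedClientError(42)
	fmt.Println(errors.Is(err, ErrUnsupportedClientVersion)) // true
}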
@@ -137,17 +140,20 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
 	// an HTTP/2 settings frame, which isn't of type 'T')
 	var notH2Frame [5]byte
 	copy(notH2Frame[:], earlyPayloadMagic)
+
 	var lenBuf [4]byte
-	binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
+	binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) //nolint:gosec // JSON length is bounded
 	// These writes are all buffered by caller, so fine to do them
 	// separately:
-	if _, err := writer.Write(notH2Frame[:]); err != nil {
+	if _, err := writer.Write(notH2Frame[:]); err != nil { //nolint:noinlineerr
 		return err
 	}
-	if _, err := writer.Write(lenBuf[:]); err != nil {
+
+	if _, err := writer.Write(lenBuf[:]); err != nil { //nolint:noinlineerr
 		return err
 	}
-	if _, err := writer.Write(earlyJSON); err != nil {
+
+	if _, err := writer.Write(earlyJSON); err != nil { //nolint:noinlineerr
 		return err
 	}
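Note: the hunk above only adds lint directives and whitespace, but the framing it touches is easy to get wrong. A sketch of the wire format earlyNoise produces, written into a buffer (the magic value here is a stand-in for earlyPayloadMagic):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// writeEarlyPayload emits: a 5-byte magic (deliberately not a valid HTTP/2
// frame header), a 4-byte big-endian payload length, then the JSON payload.
func writeEarlyPayload(buf *bytes.Buffer, earlyJSON []byte) error {
	var notH2Frame [5]byte
	copy(notH2Frame[:], "EARLY") // stand-in magic

	var lenBuf [4]byte
	binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))

	if _, err := buf.Write(notH2Frame[:]); err != nil {
		return err
	}

	if _, err := buf.Write(lenBuf[:]); err != nil {
		return err
	}

	_, err := buf.Write(earlyJSON)

	return err
}

func main() {
	var buf bytes.Buffer
	_ = writeEarlyPayload(&buf, []byte(`{"capability":1}`))
	fmt.Println(buf.Len()) // 5 + 4 + payload length
}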
@@ -199,7 +205,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
 	body, _ := io.ReadAll(req.Body)

 	var mapRequest tailcfg.MapRequest
-	if err := json.Unmarshal(body, &mapRequest); err != nil {
+	if err := json.Unmarshal(body, &mapRequest); err != nil { //nolint:noinlineerr
 		httpError(writer, err)
 		return
 	}
@@ -218,7 +224,8 @@ func (ns *noiseServer) NoisePollNetMapHandler(
 	ns.nodeKey = nv.NodeKey()

 	sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
-	sess.tracef("a node sending a MapRequest with Noise protocol")
+	sess.log.Trace().Caller().Msg("a node sending a MapRequest with Noise protocol")
+
 	if !sess.isStreaming() {
 		sess.serve()
 	} else {
@@ -241,14 +248,16 @@ func (ns *noiseServer) NoiseRegistrationHandler(
 		return
 	}

-	registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
+	registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { //nolint:contextcheck
 		var resp *tailcfg.RegisterResponse
+
 		body, err := io.ReadAll(req.Body)
 		if err != nil {
 			return &tailcfg.RegisterRequest{}, regErr(err)
 		}
+
 		var regReq tailcfg.RegisterRequest
-		if err := json.Unmarshal(body, &regReq); err != nil {
+		if err := json.Unmarshal(body, &regReq); err != nil { //nolint:noinlineerr
 			return &regReq, regErr(err)
 		}
@@ -256,11 +265,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
 		resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer())
 		if err != nil {
-			var httpErr HTTPError
-			if errors.As(err, &httpErr) {
+			if httpErr, ok := errors.AsType[HTTPError](err); ok {
 				resp = &tailcfg.RegisterResponse{
 					Error: httpErr.Msg,
 				}
+
 				return &regReq, resp
 			}
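Note: this swaps the two-step errors.As idiom for the generic errors.AsType, which this PR's Go 1.26 upgrade makes available, so staticcheck no longer sees a maybe-unused target variable. A sketch of both forms side by side (HTTPError here is a stand-in for headscale's type):

package main

import (
	"errors"
	"fmt"
)

type HTTPError struct{ Msg string }

func (e HTTPError) Error() string { return e.Msg }

func main() {
	err := fmt.Errorf("register: %w", HTTPError{Msg: "forbidden"})

	// Before: pre-declared target variable plus errors.As.
	var httpErr HTTPError
	if errors.As(err, &httpErr) {
		fmt.Println("as:", httpErr.Msg)
	}

	// After: one expression, assuming Go 1.26's errors.AsType.
	if httpErr, ok := errors.AsType[HTTPError](err); ok {
		fmt.Println("astype:", httpErr.Msg)
	}
}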
@@ -278,8 +287,9 @@ func (ns *noiseServer) NoiseRegistrationHandler(
 	writer.Header().Set("Content-Type", "application/json; charset=utf-8")
 	writer.WriteHeader(http.StatusOK)

-	if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
-		log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
+	err := json.NewEncoder(writer).Encode(registerResponse)
+	if err != nil {
+		log.Error().Caller().Err(err).Msg("noise registration handler: failed to encode RegisterResponse")
 		return
 	}


@@ -32,8 +32,8 @@ const (
 var (
 	errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
-	errNoOIDCIDToken           = errors.New("could not extract ID Token for OIDC callback")
-	errNoOIDCRegistrationInfo  = errors.New("could not get registration info from cache")
+	errNoOIDCIDToken           = errors.New("extracting ID token")
+	errNoOIDCRegistrationInfo  = errors.New("registration info not in cache")
 	errOIDCAllowedDomains      = errors.New(
 		"authenticated principal does not match any allowed domain",
 	)
@@ -68,7 +68,7 @@ func NewAuthProviderOIDC(
 ) (*AuthProviderOIDC, error) {
 	var err error
 	// grab oidc config if it hasn't been already
-	oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
+	oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) //nolint:contextcheck
 	if err != nil {
 		return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
 	}
@@ -163,13 +163,14 @@ func (a *AuthProviderOIDC) RegisterHandler(
 	for k, v := range a.cfg.ExtraParams {
 		extras = append(extras, oauth2.SetAuthURLParam(k, v))
 	}
+
 	extras = append(extras, oidc.Nonce(nonce))

 	// Cache the registration info
 	a.registrationCache.Set(state, registrationInfo)

 	authURL := a.oauth2Config.AuthCodeURL(state, extras...)
-	log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL)
+	log.Debug().Caller().Msgf("redirecting to %s for authentication", authURL)

 	http.Redirect(writer, req, authURL, http.StatusFound)
 }
@@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	}

 	stateCookieName := getCookieName("state", state)
+
 	cookieState, err := req.Cookie(stateCookieName)
 	if err != nil {
 		httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
@@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 		httpError(writer, err)
 		return
 	}
+
 	if idToken.Nonce == "" {
 		httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
 		return
 	}
+
 	nonceCookieName := getCookieName("nonce", idToken.Nonce)
 	nonce, err := req.Cookie(nonceCookieName)
 	if err != nil {
 		httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
 		return
 	}
+
 	if idToken.Nonce != nonce.Value {
 		httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
 		return
@@ -231,7 +236,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	nodeExpiry := a.determineNodeExpiry(idToken.Expiry)

 	var claims types.OIDCClaims
-	if err := idToken.Claims(&claims); err != nil {
+	if err := idToken.Claims(&claims); err != nil { //nolint:noinlineerr
 		httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
 		return
 	}
@@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	// Fetch user information (email, groups, name, etc) from the userinfo endpoint
 	// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
 	var userinfo *oidc.UserInfo
+
 	userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
 	if err != nil {
 		util.LogErr(err, "could not get userinfo; only using claims from id token")
@@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 		claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
 		claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
 		claims.Name = cmp.Or(userinfo2.Name, claims.Name)
 		claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
+
 		if userinfo2.Groups != nil {
 			claims.Groups = userinfo2.Groups
@@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 			Msgf("could not create or update user")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		writer.WriteHeader(http.StatusInternalServerError)
+
 		_, werr := writer.Write([]byte("Could not create or update user"))
 		if werr != nil {
 			log.Error().
@@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	// Register the node if it does not exist.
 	if registrationId != nil {
 		verb := "Reauthenticated"
+
 		newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
 		if err != nil {
 			if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
@@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 				return
 			}

 			httpError(writer, err)
+
 			return
 		}
@@ -316,15 +327,12 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	}

 	// TODO(kradalby): replace with go-elem
-	content, err := renderOIDCCallbackTemplate(user, verb)
-	if err != nil {
-		httpError(writer, err)
-		return
-	}
+	content := renderOIDCCallbackTemplate(user, verb)

 	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 	writer.WriteHeader(http.StatusOK)
-	if _, err := writer.Write(content.Bytes()); err != nil {
+
+	if _, err := writer.Write(content.Bytes()); err != nil { //nolint:noinlineerr
 		util.LogErr(err, "Failed to write HTTP response")
 	}
@@ -370,6 +378,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
 	if !ok {
 		return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
 	}
+
 	if regInfo.Verifier != nil {
 		exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
 	}
@@ -377,7 +386,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
 	oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
 	if err != nil {
-		return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
+		return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("exchanging code for token: %w", err))
 	}

 	return oauth2Token, err
@@ -394,9 +403,10 @@ func (a *AuthProviderOIDC) extractIDToken(
 	}

 	verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
+
 	idToken, err := verifier.Verify(ctx, rawIDToken)
 	if err != nil {
-		return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
+		return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("verifying ID token: %w", err))
 	}

 	return idToken, nil
@@ -516,6 +526,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 		newUser bool
 		c       change.Change
 	)
+
 	user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
 	if err != nil && !errors.Is(err, db.ErrUserNotFound) {
 		return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
@@ -561,7 +572,7 @@ func (a *AuthProviderOIDC) handleRegistration(
 		util.RegisterMethodOIDC,
 	)
 	if err != nil {
-		return false, fmt.Errorf("could not register node: %w", err)
+		return false, fmt.Errorf("registering node: %w", err)
 	}

 	// This is a bit of a back and forth, but we have a bit of a chicken and egg
@@ -589,9 +600,9 @@ func (a *AuthProviderOIDC) handleRegistration(
 func renderOIDCCallbackTemplate(
 	user *types.User,
 	verb string,
-) (*bytes.Buffer, error) {
+) *bytes.Buffer {
 	html := templates.OIDCCallback(user.Display(), verb).Render()

-	return bytes.NewBufferString(html), nil
+	return bytes.NewBufferString(html)
 }

 // getCookieName generates a unique cookie name based on a cookie value.
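Note: renderOIDCCallbackTemplate could never fail, so the error return (and the dead error handling at its call site) is gone. A sketch of the simplified shape, where fmt.Sprintf stands in for templates.OIDCCallback(...).Render():

package main

import (
	"bytes"
	"fmt"
)

// render cannot fail: the underlying renderer returns a plain string, so
// the function returns only the buffer.
func render(name, verb string) *bytes.Buffer {
	return bytes.NewBufferString(fmt.Sprintf("<p>%s as %s</p>", verb, name))
}

func main() {
	fmt.Println(render("alice", "Authenticated").String())
}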


@@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
 ) {
 	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 	writer.WriteHeader(http.StatusOK)
-	writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
+	_, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
 }

 // AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
@@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage(
 ) {
 	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 	writer.WriteHeader(http.StatusOK)
-	writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
+	_, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
 }

 func (h *Headscale) ApplePlatformConfig(
@@ -37,6 +37,7 @@ func (h *Headscale) ApplePlatformConfig(
 	req *http.Request,
 ) {
 	vars := mux.Vars(req)
+
 	platform, ok := vars["platform"]
 	if !ok {
 		httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
@@ -64,17 +65,20 @@ func (h *Headscale) ApplePlatformConfig(
 	switch platform {
 	case "macos-standalone":
-		if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
+		err := macosStandaloneTemplate.Execute(&payload, platformConfig)
+		if err != nil {
 			httpError(writer, err)
 			return
 		}
 	case "macos-app-store":
-		if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
+		err := macosAppStoreTemplate.Execute(&payload, platformConfig)
+		if err != nil {
 			httpError(writer, err)
 			return
 		}
 	case "ios":
-		if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
+		err := iosTemplate.Execute(&payload, platformConfig)
+		if err != nil {
 			httpError(writer, err)
 			return
 		}
@@ -90,7 +94,7 @@ func (h *Headscale) ApplePlatformConfig(
 	}

 	var content bytes.Buffer
-	if err := commonTemplate.Execute(&content, config); err != nil {
+	if err := commonTemplate.Execute(&content, config); err != nil { //nolint:noinlineerr
 		httpError(writer, err)
 		return
 	}
@@ -98,7 +102,7 @@ func (h *Headscale) ApplePlatformConfig(
 	writer.Header().
 		Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
 	writer.WriteHeader(http.StatusOK)
-	writer.Write(content.Bytes())
+	_, _ = writer.Write(content.Bytes())
 }

 type AppleMobileConfig struct {
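Note: the switch above moves each Execute error out of the if-header, matching the noinlineerr style applied throughout this diff. A self-contained sketch of the pattern with an invented template:

package main

import (
	"bytes"
	"fmt"
	"text/template"
)

var iosTemplate = template.Must(template.New("ios").Parse("URL: {{.URL}}\n"))

func main() {
	var payload bytes.Buffer

	// Error assigned on its own line, then checked: the style the lint
	// change enforces in ApplePlatformConfig.
	err := iosTemplate.Execute(&payload, struct{ URL string }{URL: "https://headscale.example.com"})
	if err != nil {
		fmt.Println("template error:", err)
		return
	}

	fmt.Print(payload.String())
}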


@@ -16,15 +16,18 @@ type Match struct {
 	dests *netipx.IPSet
 }

-func (m Match) DebugString() string {
+func (m *Match) DebugString() string {
 	var sb strings.Builder

 	sb.WriteString("Match:\n")
 	sb.WriteString("  Sources:\n")
+
 	for _, prefix := range m.srcs.Prefixes() {
 		sb.WriteString("    " + prefix.String() + "\n")
 	}

 	sb.WriteString("  Destinations:\n")
+
 	for _, prefix := range m.dests.Prefixes() {
 		sb.WriteString("    " + prefix.String() + "\n")
 	}
@@ -42,7 +45,7 @@ func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match {
 }

 func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
-	dests := []string{}
+	dests := make([]string, 0, len(rule.DstPorts))
 	for _, dest := range rule.DstPorts {
 		dests = append(dests, dest.IP)
 	}
@@ -93,11 +96,24 @@ func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool {
 	return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix)
 }

-// DestsIsTheInternet reports if the destination is equal to "the internet"
+// DestsIsTheInternet reports if the destination contains "the internet"
 // which is a IPSet that represents "autogroup:internet" and is special
 // cased for exit nodes.
-func (m Match) DestsIsTheInternet() bool {
-	return m.dests.Equal(util.TheInternet()) ||
-		m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
-		m.dests.ContainsPrefix(tsaddr.AllIPv6())
+// This checks if dests is a superset of TheInternet(), which handles
+// merged filter rules where TheInternet is combined with other destinations.
+func (m *Match) DestsIsTheInternet() bool {
+	if m.dests.ContainsPrefix(tsaddr.AllIPv4()) ||
+		m.dests.ContainsPrefix(tsaddr.AllIPv6()) {
+		return true
+	}
+
+	// Check if dests contains all prefixes of TheInternet (superset check)
+	theInternet := util.TheInternet()
+	for _, prefix := range theInternet.Prefixes() {
+		if !m.dests.ContainsPrefix(prefix) {
+			return false
+		}
+	}
+
+	return true
 }
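Note: the old equality check broke once filter rules with identical sources were merged; the superset check keeps exit-node special-casing working when TheInternet is combined with other destinations. A runnable sketch of the relaxed comparison over netipx sets (the prefixes are stand-ins for util.TheInternet()):

package main

import (
	"fmt"
	"net/netip"

	"go4.org/netipx"
)

func mustSet(prefixes ...string) *netipx.IPSet {
	var b netipx.IPSetBuilder
	for _, p := range prefixes {
		b.AddPrefix(netip.MustParsePrefix(p))
	}

	s, err := b.IPSet()
	if err != nil {
		panic(err)
	}

	return s
}

// supersetOf reports whether dests contains every prefix of want, the
// relaxed comparison DestsIsTheInternet now performs.
func supersetOf(dests, want *netipx.IPSet) bool {
	for _, p := range want.Prefixes() {
		if !dests.ContainsPrefix(p) {
			return false
		}
	}

	return true
}

func main() {
	internet := mustSet("8.0.0.0/6")                // stand-in for util.TheInternet()
	merged := mustSet("8.0.0.0/6", "100.64.0.0/10") // rule merged with extra dests

	fmt.Println(supersetOf(merged, internet)) // true: still "the internet"
	fmt.Println(merged.Equal(internet))       // false: the old check missed this
}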


@@ -19,18 +19,18 @@ type PolicyManager interface {
 	MatchersForNode(node types.NodeView) ([]matcher.Match, error)
 	// BuildPeerMap constructs peer relationship maps for the given nodes
 	BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView

-	SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
-	SetPolicy([]byte) (bool, error)
+	SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
+	SetPolicy(pol []byte) (bool, error)
 	SetUsers(users []types.User) (bool, error)
 	SetNodes(nodes views.Slice[types.NodeView]) (bool, error)

 	// NodeCanHaveTag reports whether the given node can have the given tag.
-	NodeCanHaveTag(types.NodeView, string) bool
+	NodeCanHaveTag(node types.NodeView, tag string) bool
 	// TagExists reports whether the given tag is defined in the policy.
 	TagExists(tag string) bool
 	// NodeCanApproveRoute reports whether the given node can approve the given route.
-	NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
+	NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool

 	Version() int
 	DebugString() string
@@ -38,8 +38,11 @@ type PolicyManager interface {
 // NewPolicyManager returns a new policy manager.
 func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
-	var polMan PolicyManager
-	var err error
+	var (
+		polMan PolicyManager
+		err    error
+	)

 	polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
 	if err != nil {
 		return nil, err
@@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
 		if err != nil {
 			return nil, err
 		}
+
 		polMans = append(polMans, pm)
 	}
@@ -66,7 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
 }

 func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
-	var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
+	polmanFuncs := make([]func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error), 0, 1)

 	polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
 		return policyv2.NewPolicyManager(pol, u, n)


@@ -9,7 +9,6 @@ import (
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"github.com/samber/lo"
-	"tailscale.com/net/tsaddr"
 	"tailscale.com/types/views"
 )
@@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
 	}

 	// Sort and deduplicate
-	tsaddr.SortPrefixes(newApproved)
+	slices.SortFunc(newApproved, netip.Prefix.Compare)
 	newApproved = slices.Compact(newApproved)
 	newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
 		return route.IsValid()
@@ -120,12 +119,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
 	// Sort the current approved for comparison
 	sortedCurrent := make([]netip.Prefix, len(currentApproved))
 	copy(sortedCurrent, currentApproved)
-	tsaddr.SortPrefixes(sortedCurrent)
+	slices.SortFunc(sortedCurrent, netip.Prefix.Compare)

 	// Only update if the routes actually changed
 	if !slices.Equal(sortedCurrent, newApproved) {
 		// Log what changed
 		var added, kept []netip.Prefix
+
 		for _, route := range newApproved {
 			if !slices.Contains(sortedCurrent, route) {
 				added = append(added, route)
@@ -136,8 +136,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
 		if len(added) > 0 {
 			log.Debug().
-				Uint64("node.id", nv.ID().Uint64()).
-				Str("node.name", nv.Hostname()).
+				EmbedObject(nv).
 				Strs("routes.added", util.PrefixesToString(added)).
 				Strs("routes.kept", util.PrefixesToString(kept)).
 				Int("routes.total", len(newApproved)).


@@ -3,16 +3,16 @@ package policy

 import (
 	"fmt"
 	"net/netip"
+	"slices"
 	"testing"

 	policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"gorm.io/gorm"
-	"tailscale.com/net/tsaddr"
 	"tailscale.com/types/key"
-	"tailscale.com/types/ptr"
 	"tailscale.com/types/views"
 )
@@ -32,10 +32,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
 		MachineKey:     key.NewMachine().Public(),
 		NodeKey:        key.NewNode().Public(),
 		Hostname:       "test-node",
-		UserID:         ptr.To(user1.ID),
-		User:           ptr.To(user1),
+		UserID:         new(user1.ID),
+		User:           new(user1),
 		RegisterMethod: util.RegisterMethodAuthKey,
-		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
+		IPv4:           new(netip.MustParseAddr("100.64.0.1")),
 		Tags:           []string{"tag:test"},
 	}
@@ -44,10 +44,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
 		MachineKey:     key.NewMachine().Public(),
 		NodeKey:        key.NewNode().Public(),
 		Hostname:       "other-node",
-		UserID:         ptr.To(user2.ID),
-		User:           ptr.To(user2),
+		UserID:         new(user2.ID),
+		User:           new(user2),
 		RegisterMethod: util.RegisterMethodAuthKey,
-		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.2")),
+		IPv4:           new(netip.MustParseAddr("100.64.0.2")),
 	}

 	// Create a policy that auto-approves specific routes
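Note: the ptr.To calls become Go 1.26's new(expr) form, which allocates and initializes in one step (per the Go 1.26 upgrade in this PR). A sketch of the equivalence:

package main

import (
	"fmt"
	"net/netip"
)

func main() {
	// Pre-1.26 helper style:   ip := ptr.To(netip.MustParseAddr("100.64.0.1"))
	// Go 1.26 new(expr) style: pointer to a freshly allocated copy of the value.
	ip := new(netip.MustParseAddr("100.64.0.1"))
	id := new(uint(1))

	fmt.Println(*ip, *id) // 100.64.0.1 1
}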
@@ -76,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
 	}`

 	pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	tests := []struct {
 		name        string
@@ -194,7 +194,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
 			assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)

 			// Sort for comparison since ApproveRoutesWithPolicy sorts the results
-			tsaddr.SortPrefixes(tt.wantApproved)
+			slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
 			assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)

 			// Verify that all previously approved routes are still present
@@ -304,20 +304,23 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
 				MachineKey:     key.NewMachine().Public(),
 				NodeKey:        key.NewNode().Public(),
 				Hostname:       "testnode",
-				UserID:         ptr.To(user.ID),
-				User:           ptr.To(user),
+				UserID:         new(user.ID),
+				User:           new(user),
 				RegisterMethod: util.RegisterMethodAuthKey,
-				IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
+				IPv4:           new(netip.MustParseAddr("100.64.0.1")),
 				ApprovedRoutes: tt.currentApproved,
 			}
+
 			nodes := types.Nodes{&node}

 			// Create policy manager or use nil if specified
-			var pm PolicyManager
-			var err error
+			var (
+				pm  PolicyManager
+				err error
+			)

 			if tt.name != "nil_policy_manager" {
 				pm, err = pmf(users, nodes.ViewSlice())
-				assert.NoError(t, err)
+				require.NoError(t, err)
 			} else {
 				pm = nil
 			}
@@ -330,7 +333,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
 			if tt.wantApproved == nil {
 				assert.Nil(t, gotApproved, "expected nil approved routes")
 			} else {
-				tsaddr.SortPrefixes(tt.wantApproved)
+				slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
 				assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
 			}
 		})


@@ -13,7 +13,6 @@ import (
 	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
-	"tailscale.com/types/ptr"
 )

 func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
@@ -92,8 +91,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
 			announcedRoutes: []netip.Prefix{}, // No routes announced anymore
 			nodeUser:        "test",
 			wantApproved: []netip.Prefix{
-				netip.MustParsePrefix("172.16.0.0/16"),
 				netip.MustParsePrefix("10.0.0.0/24"),
+				netip.MustParsePrefix("172.16.0.0/16"),
 				netip.MustParsePrefix("192.168.0.0/24"),
 			},
 			wantChanged: false,
@@ -124,8 +123,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
 			nodeUser: "test",
 			nodeTags: []string{"tag:approved"},
 			wantApproved: []netip.Prefix{
-				netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
 				netip.MustParsePrefix("10.0.0.0/24"),   // Previous approval preserved
+				netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
 			},
 			wantChanged: true,
 		},
@@ -168,13 +167,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
 				MachineKey:     key.NewMachine().Public(),
 				NodeKey:        key.NewNode().Public(),
 				Hostname:       tt.nodeHostname,
-				UserID:         ptr.To(user.ID),
-				User:           ptr.To(user),
+				UserID:         new(user.ID),
+				User:           new(user),
 				RegisterMethod: util.RegisterMethodAuthKey,
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: tt.announcedRoutes,
 				},
-				IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
+				IPv4:           new(netip.MustParseAddr("100.64.0.1")),
 				ApprovedRoutes: tt.currentApproved,
 				Tags:           tt.nodeTags,
 			}
@@ -294,13 +293,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
 				MachineKey:     key.NewMachine().Public(),
 				NodeKey:        key.NewNode().Public(),
 				Hostname:       "testnode",
-				UserID:         ptr.To(user.ID),
-				User:           ptr.To(user),
+				UserID:         new(user.ID),
+				User:           new(user),
 				RegisterMethod: util.RegisterMethodAuthKey,
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: tt.announcedRoutes,
 				},
-				IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
+				IPv4:           new(netip.MustParseAddr("100.64.0.1")),
 				ApprovedRoutes: tt.currentApproved,
 			}
 			nodes := types.Nodes{&node}
@@ -331,6 +330,8 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
 		Name: "test",
 	}

+	userID := user.ID
+
 	currentApproved := []netip.Prefix{
 		netip.MustParsePrefix("10.0.0.0/24"),
 	}
@@ -343,13 +344,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
 		MachineKey:     key.NewMachine().Public(),
 		NodeKey:        key.NewNode().Public(),
 		Hostname:       "testnode",
-		UserID:         ptr.To(user.ID),
-		User:           ptr.To(user),
+		UserID:         &userID,
+		User:           &user,
 		RegisterMethod: util.RegisterMethodAuthKey,
 		Hostinfo: &tailcfg.Hostinfo{
 			RoutableIPs: announcedRoutes,
 		},
-		IPv4:           ptr.To(netip.MustParseAddr("100.64.0.1")),
+		IPv4:           new(netip.MustParseAddr("100.64.0.1")),
 		ApprovedRoutes: currentApproved,
 	}


@@ -14,7 +14,6 @@ import (
 	"github.com/stretchr/testify/require"
 	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
-	"tailscale.com/types/ptr"
 )

 var ap = func(ipStr string) *netip.Addr {
@@ -33,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
 		rules []tailcfg.FilterRule
 		node  *types.Node
 	}
+
 	tests := []struct {
 		name string
 		args args
@@ -783,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
 			for _, v := range gotViews.All() {
 				got = append(got, v.AsStruct())
 			}
+
 			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
 				t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
 				t.Log("Matchers: ")
+
 				for _, m := range matchers {
 					t.Log("\t+", m.DebugString())
 				}
@@ -796,7 +798,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
 	n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
-		var routes []netip.Prefix
+		routes := make([]netip.Prefix, 0, len(routess))
 		for _, route := range routess {
 			routes = append(routes, netip.MustParsePrefix(route))
 		}
@@ -891,11 +893,13 @@ func TestReduceNodesFromPolicy(t *testing.T) {
 				]
 			}`,
 			node: n(1, "100.64.0.1", "mobile", "mobile"),
+			// autogroup:internet does not generate packet filters - it's handled
+			// by exit node routing via AllowedIPs, not by packet filtering.
+			// Only server is visible through the mobile -> server:80 rule.
 			want: types.Nodes{
 				n(2, "100.64.0.2", "server", "server"),
-				n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
 			},
-			wantMatchers: 2,
+			wantMatchers: 1,
 		},
 		{
 			name: "2788-exit-node-0000-route",
@@ -938,7 +942,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
 				n(2, "100.64.0.2", "server", "server"),
 				n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
 			},
-			wantMatchers: 2,
+			wantMatchers: 1,
 		},
 		{
 			name: "2788-exit-node-::0-route",
@@ -981,7 +985,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
 				n(2, "100.64.0.2", "server", "server"),
 				n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"),
 			},
-			wantMatchers: 2,
+			wantMatchers: 1,
 		},
 		{
 			name: "2784-split-exit-node-access",
@@ -1032,8 +1036,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
 	for _, tt := range tests {
 		for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
 			t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
-				var pm PolicyManager
-				var err error
+				var (
+					pm  PolicyManager
+					err error
+				)
+
 				pm, err = pmf(nil, tt.nodes.ViewSlice())
 				require.NoError(t, err)
@@ -1051,9 +1058,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
 				for _, v := range gotViews.All() {
 					got = append(got, v.AsStruct())
 				}
+
 				if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
 					t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
 					t.Log("Matchers: ")
+
 					for _, m := range matchers {
 						t.Log("\t+", m.DebugString())
 					}
@@ -1074,21 +1083,21 @@ func TestSSHPolicyRules(t *testing.T) {
 	nodeUser1 := types.Node{
 		Hostname: "user1-device",
 		IPv4:     ap("100.64.0.1"),
-		UserID:   ptr.To(uint(1)),
-		User:     ptr.To(users[0]),
+		UserID:   new(uint(1)),
+		User:     new(users[0]),
 	}

 	nodeUser2 := types.Node{
 		Hostname: "user2-device",
 		IPv4:     ap("100.64.0.2"),
-		UserID:   ptr.To(uint(2)),
-		User:     ptr.To(users[1]),
+		UserID:   new(uint(2)),
+		User:     new(users[1]),
 	}

 	taggedClient := types.Node{
 		Hostname: "tagged-client",
 		IPv4:     ap("100.64.0.4"),
-		UserID:   ptr.To(uint(2)),
-		User:     ptr.To(users[1]),
+		UserID:   new(uint(2)),
+		User:     new(users[1]),
 		Tags:     []string{"tag:client"},
 	}
@@ -1096,8 +1105,8 @@ func TestSSHPolicyRules(t *testing.T) {
 	nodeTaggedServer := types.Node{
 		Hostname: "tagged-server",
 		IPv4:     ap("100.64.0.5"),
-		UserID:   ptr.To(uint(1)),
-		User:     ptr.To(users[0]),
+		UserID:   new(uint(1)),
+		User:     new(users[0]),
 		Tags:     []string{"tag:server"},
 	}
@@ -1231,7 +1240,7 @@ func TestSSHPolicyRules(t *testing.T) {
 				]
 			}`,
 			expectErr:    true,
-			errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
+			errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
 		},
 		{
 			name: "invalid-check-period",
@@ -1278,7 +1287,7 @@ func TestSSHPolicyRules(t *testing.T) {
 				]
 			}`,
 			expectErr:    true,
-			errorMessage: "autogroup \"autogroup:invalid\" is not supported",
+			errorMessage: "autogroup not supported for SSH user",
 		},
 		{
 			name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
@@ -1451,13 +1460,17 @@ func TestSSHPolicyRules(t *testing.T) {
 	for _, tt := range tests {
 		for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
 			t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
-				var pm PolicyManager
-				var err error
+				var (
+					pm  PolicyManager
+					err error
+				)
+
 				pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
 				if tt.expectErr {
 					require.Error(t, err)
 					require.Contains(t, err.Error(), tt.errorMessage)
+
 					return
 				}
@@ -1480,6 +1493,7 @@ func TestReduceRoutes(t *testing.T) {
 		routes []netip.Prefix
 		rules  []tailcfg.FilterRule
 	}
+
 	tests := []struct {
 		name string
 		args args
@@ -2101,6 +2115,7 @@ func TestReduceRoutes(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			matchers := matcher.MatchesFromFilterRules(tt.args.rules)
+
 			got := ReduceRoutes(
 				tt.args.node.View(),
 				tt.args.routes,


@@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
 	for _, rule := range rules {
 		// record if the rule is actually relevant for the given node.
 		var dests []tailcfg.NetPortRange
+
 	DEST_LOOP:
 		for _, dest := range rule.DstPorts {
 			expanded, err := util.ParseIPSet(dest.IP, nil)
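Note: the hunk only adds a blank line before the label, but the DEST_LOOP construct itself is worth a glance: a labeled loop lets deeply nested code abandon one destination iteration directly. A small sketch of the labeled-continue control flow with invented data:

package main

import "fmt"

func main() {
	rules := [][]string{{"10.0.0.0/8", "skip-me"}, {"192.168.0.0/16"}}

DEST_LOOP:
	for i, dests := range rules {
		for _, d := range dests {
			if d == "skip-me" {
				continue DEST_LOOP // jump straight to the next outer iteration
			}

			fmt.Println(i, d)
		}
	}
}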

Some files were not shown because too many files have changed in this diff.